
Is there any way to convert pre-trained JAX weights stored as .npz into Keras/tf.keras .h5 weights? I couldn't find anything online.

Thanks

Innat
craft

1 Answer


The most straightforward way to convert from npz to h5 is to load the arrays into memory and then write them out again, for example with h5py.

Here is a brief example:

```python
import numpy as np
import jax.numpy as jnp
from jax import random
import h5py

# Create some random weights
key = random.PRNGKey(1701)
weights = random.normal(key, shape=(100,))

# Save to an npz file
jnp.savez('weights.npz', weights=weights)

# Load the npz and copy every array into an h5 dataset with the same name;
# both files are closed automatically when the block exits
with np.load('weights.npz') as data, h5py.File('weights.h5', 'w') as hf:
    for name in data.files:
        hf.create_dataset(name, data=data[name])
```
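To confirm that nothing was lost in the copy, the h5 file can be read back and compared array by array with the npz file. The following is a minimal sketch using NumPy and h5py; the helper name `npz_matches_h5` is hypothetical and not part of either library, and the conversion loop is repeated inline so the snippet runs on its own.

```python
import os
import tempfile

import h5py
import numpy as np


def npz_matches_h5(npz_path, h5_path):
    """Return True if every array in the npz file has a dataset with the
    same name, shape, dtype and values in the h5 file (hypothetical helper)."""
    with np.load(npz_path) as npz, h5py.File(h5_path, "r") as hf:
        for name in npz.files:
            if name not in hf:
                return False
            stored = hf[name][()]  # read the whole dataset into memory
            original = npz[name]
            if stored.shape != original.shape or stored.dtype != original.dtype:
                return False
            if not np.array_equal(stored, original):
                return False
    return True


# Usage: build a small npz, copy it to h5, then verify the copy.
tmp = tempfile.mkdtemp()
npz_path = os.path.join(tmp, "weights.npz")
h5_path = os.path.join(tmp, "weights.h5")
np.savez(npz_path, weights=np.arange(100, dtype=np.float32))

with np.load(npz_path) as data, h5py.File(h5_path, "w") as hf:
    for name in data.files:
        hf.create_dataset(name, data=data[name])

print(npz_matches_h5(npz_path, h5_path))
```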

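Note that the details depend on the contents of the npz file and on the structure you want in the resulting h5 file. In particular, a file written this way is plain HDF5; for Keras to load it with `load_weights`, the layer and weight names generally have to match the layout Keras itself writes for the target model.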
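Pre-trained JAX checkpoints saved as npz often flatten a nested parameter tree into '/'-separated keys. The sketch below assumes keys of that form (the names used here are made up) and relies on h5py creating intermediate groups for a path-like dataset name, so the flat keys become a nested group hierarchy. The `rename` hook and both helper functions are hypothetical. Mapping names to what a particular Keras model expects is model-specific and is not attempted here; `rename` is where such a mapping would go.

```python
import os
import tempfile

import h5py
import numpy as np


def npz_to_h5(npz_path, h5_path, rename=None):
    """Copy every array in an npz file into an h5 file.

    Keys are assumed to be '/'-separated paths such as 'block_0/dense/kernel';
    h5py turns each path into nested groups holding the dataset.
    `rename` (hypothetical) maps an npz key to the desired h5 path,
    or returns None to skip that array.
    """
    with np.load(npz_path) as npz, h5py.File(h5_path, "w") as hf:
        for key in npz.files:
            target = rename(key) if rename is not None else key
            if target is None:
                continue  # dropped on purpose by the rename hook
            hf.create_dataset(target, data=npz[key])


def list_datasets(h5_path):
    """Return sorted (path, shape) pairs for every dataset in an h5 file."""
    found = []
    with h5py.File(h5_path, "r") as hf:
        def visit(name, obj):
            if isinstance(obj, h5py.Dataset):
                found.append((name, obj.shape))
        hf.visititems(visit)
    return sorted(found)


# Usage with made-up parameter names standing in for a real checkpoint.
tmp = tempfile.mkdtemp()
npz_path = os.path.join(tmp, "params.npz")
h5_path = os.path.join(tmp, "params.h5")
np.savez(
    npz_path,
    **{
        "block_0/dense/kernel": np.zeros((4, 8), dtype=np.float32),
        "block_0/dense/bias": np.zeros((8,), dtype=np.float32),
        "head/kernel": np.zeros((8, 2), dtype=np.float32),
    },
)

# Hypothetical renaming: place every array under a top-level 'model' group.
npz_to_h5(npz_path, h5_path, rename=lambda key: "model/" + key)
for name, shape in list_datasets(h5_path):
    print(name, shape)
```

Keeping the renaming in a single callable makes it easy to swap in a different mapping (or drop unwanted arrays by returning None) without touching the copy loop.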

jakevdp