Is there any way to convert JAX npz pre-trained weights into keras/tf.keras h5 format weights? I couldn't find anything online.
Thanks
The most straightforward way to convert from npz format to h5 format is to load the data into memory and then rewrite it.
Here is a brief example:
import jax.numpy as jnp
from jax import random
import h5py
# Create some random weights
key = random.PRNGKey(1701)
weights = random.normal(key, shape=(100,))
# Save to an npz file
jnp.savez('weights.npz', weights=weights)
# Load the npz and convert to h5
data = jnp.load('weights.npz')
with h5py.File('weights.h5', 'w') as hf:
    # Write the array as a dataset named 'weights'
    hf.create_dataset('weights', data=data['weights'])
Note that the details of this will depend on the content of the npz file and the desired structure of the resulting h5 file.
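If the npz file holds several arrays, the same idea can be applied to every key in a loop. The following is only a minimal sketch: the helper name npz_to_h5 is hypothetical, and it assumes each entry in the archive is a plain numeric array that should become one h5 dataset with the same name.

```python
import numpy as np
import h5py


def npz_to_h5(npz_path, h5_path):
    """Copy every array in an npz archive into an h5 file, one dataset per key.

    Hypothetical helper for illustration; assumes all entries are numeric arrays.
    """
    with np.load(npz_path) as data, h5py.File(h5_path, "w") as hf:
        for name in data.files:
            # A key containing slashes (e.g. "dense/kernel") is stored as a
            # dataset inside nested h5 groups.
            hf.create_dataset(name, data=data[name])
```

Calling npz_to_h5('weights.npz', 'weights.h5') would then mirror the npz contents in the h5 file. A file written this way is a generic h5 file, not necessarily one that Keras load_weights can read, since Keras expects its own layout. For Keras, it is usually simpler to build the model, pass the loaded arrays to model.set_weights in the matching order and shapes, and then call model.save_weights.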