I have downloaded a JAX/NumPy weight file with the `.npz` suffix, but when I tried to convert it to an `.h5` file with the following code, I received an error:
```python
import jax.numpy as jnp
import h5py
import tensorflow as tf
BASE_URL = "https://github.com/faustomorales/vit-keras/releases/download/dl"
size = "B_16"
weights = "imagenet21k"
fname = f"ViT-{size}_{weights}.npz"
origin = f"{BASE_URL}/{fname}"
# the weight file is saved to the local path "~/.keras/weights/"
local_filepath = tf.keras.utils.get_file(fname, origin, cache_subdir="weights")
# Load the npz and convert to h5
jax_file = jnp.load('ViT-B_16_imagenet21k.npz')
with h5py.File('ViT-B_16_imagenet21k.h5', 'w') as hf:
    hf.create_dataset('weights', data=jax_file)
```
The `.npz` file is downloaded to the local path `~/.keras/weights/`.
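The code above downloads to `~/.keras/weights/` but then calls `jnp.load` with a bare file name, which is resolved relative to the current working directory. A minimal sketch, assuming the default Keras cache location (`~/.keras`) and the same `cache_subdir="weights"`, that opens the cached archive explicitly; passing the `local_filepath` returned by `get_file` should point at the same file:

```python
import os

import numpy as np

# Assumption: get_file() used the default cache dir (~/.keras) with
# cache_subdir="weights", as in the question's code
fname = "ViT-B_16_imagenet21k.npz"
npz_path = os.path.join(os.path.expanduser("~"), ".keras", "weights", fname)

if os.path.exists(npz_path):
    # NpzFile supports the context-manager protocol, so the zip is closed afterwards
    with np.load(npz_path) as npz:
        print(f"{len(npz.files)} arrays in {npz_path}")
else:
    print(f"not found: {npz_path}")
```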
The error is:
```
*** TypeError: No conversion path for dtype: dtype('<U71')
Traceback (most recent call last):
  File "/home/media/m_env/lib/python3.8/site-packages/h5py/_hl/group.py", line 161, in create_dataset
    dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds)
  File "/home/media/m_env/lib/python3.8/site-packages/h5py/_hl/dataset.py", line 88, in make_new_dset
    tid = h5t.py_create(dtype, logical=1)
  File "h5py/h5t.pyx", line 1663, in h5py.h5t.py_create
  File "h5py/h5t.pyx", line 1687, in h5py.h5t.py_create
  File "h5py/h5t.pyx", line 1753, in h5py.h5t.py_create
```
My question is: how can I correctly convert a JAX `.npz` file into an `.h5` file?
Note 1: the output of `dir(jax_file)` is:
```
dir(jax_file)
['__abstractmethods__', '__class__', '__contains__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_files', 'allow_pickle', 'close', 'f', 'fid', 'files', 'get', 'items', 'keys', 'pickle_kwargs', 'values', 'zip']
```
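The `files`, `keys`, `items` and `get` attributes show that this object behaves like a read-only dictionary mapping array names to NumPy arrays. A small self-contained sketch, using a toy in-memory archive whose key names are hypothetical stand-ins for the real ViT parameter names, that lists each stored array. It also shows that iterating the archive yields only the names (strings); that is one plausible, unconfirmed source of the `dtype('<U71')` in the error, since h5py has no HDF5 mapping for NumPy's fixed-width unicode (`U`) dtype:

```python
import io

import numpy as np

# Toy archive in memory; these key names are hypothetical stand-ins for the
# ViT parameter names stored in the real .npz
buf = io.BytesIO()
np.savez(buf, **{
    "embedding/kernel": np.zeros((16, 16, 3, 8), dtype=np.float32),
    "head/bias": np.arange(4, dtype=np.float32),
})
buf.seek(0)

npz = np.load(buf)  # a NumPy NpzFile, the same type as jax_file

# Indexing by name returns one ndarray per stored parameter
summary = {name: (npz[name].shape, npz[name].dtype) for name in npz.files}
for name, (shape, dtype) in summary.items():
    print(name, shape, dtype)

# Iterating the archive itself yields the names (str), not the arrays,
# so turning it into an array gives a fixed-width unicode ('<U...') dtype
names = np.array(list(npz))
print(names.dtype)
npz.close()
```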
Note 2: the type of `jax_file` is:
```
type(jax_file)
<class 'numpy.lib.npyio.NpzFile'>
```
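Because `jax_file` is an `NpzFile` (a container of named arrays) rather than a single array, `create_dataset('weights', data=jax_file)` is not handed the numeric weights. A minimal sketch, assuming one HDF5 dataset per stored array is an acceptable layout, that iterates over the archive and writes each array under its own name. It uses `np.load`, since `jnp.load` returned the same `NpzFile` type here; the helper name `npz_to_h5` and the toy key names are hypothetical:

```python
import os
import tempfile

import h5py
import numpy as np


def npz_to_h5(npz_path, h5_path):
    """Copy every array in an .npz archive into its own HDF5 dataset.

    Names containing '/' become nested HDF5 groups, because h5py creates
    intermediate groups automatically when a dataset path has several parts.
    """
    with np.load(npz_path) as npz, h5py.File(h5_path, "w") as hf:
        for name in npz.files:
            # npz[name] is a plain numeric ndarray, which h5py can store
            hf.create_dataset(name, data=npz[name])


# Usage on a toy archive (hypothetical key names) in a temporary directory
tmp_dir = tempfile.mkdtemp()
src = os.path.join(tmp_dir, "toy.npz")
dst = os.path.join(tmp_dir, "toy.h5")
np.savez(src, **{"head/bias": np.arange(4, dtype=np.float32),
                 "head/kernel": np.ones((3, 4), dtype=np.float32)})
npz_to_h5(src, dst)

with h5py.File(dst, "r") as hf:
    restored_bias = hf["head/bias"][()]
    restored_kernel_shape = hf["head/kernel"].shape
    head_is_group = isinstance(hf["head"], h5py.Group)
```

The result is a plain HDF5 container holding the same arrays under the same names; it is not necessarily in the layout that Keras' `model.load_weights` expects for an `.h5` weights file, which depends on the model's layer structure.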
Note 3: my TensorFlow version is 2.9.1.
Any help will be appreciated.