
I have downloaded a JAX/NumPy weight file with the .npz extension, but when I try to convert it to an .h5 file with the code below, I receive an error:

import jax.numpy as jnp
import h5py
import tensorflow as tf


BASE_URL = "https://github.com/faustomorales/vit-keras/releases/download/dl"

size = "B_16"
weights = "imagenet21k"
fname = f"ViT-{size}_{weights}.npz"
origin = f"{BASE_URL}/{fname}"

# saved weight file in local path "~/.keras/weights/"
local_filepath = tf.keras.utils.get_file(fname, origin, cache_subdir="weights")

# Load the npz and convert to h5
jax_file = jnp.load(local_filepath)
with h5py.File('ViT-B_16_imagenet21k.h5', 'w') as hf:
    hf.create_dataset('weights', data=jax_file)  # this line raises the TypeError below

The .npz file is downloaded to the local path "~/.keras/weights/".

The error is:

*** TypeError: No conversion path for dtype: dtype('<U71')
Traceback (most recent call last):
  File "/home/media/m_env/lib/python3.8/site-packages/h5py/_hl/group.py", line 161, in create_dataset
    dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds)
  File "/home/media/m_env/lib/python3.8/site-packages/h5py/_hl/dataset.py", line 88, in make_new_dset
    tid = h5t.py_create(dtype, logical=1)
  File "h5py/h5t.pyx", line 1663, in h5py.h5t.py_create
  File "h5py/h5t.pyx", line 1687, in h5py.h5t.py_create
  File "h5py/h5t.pyx", line 1753, in h5py.h5t.py_create

My question is: how can I correctly convert a JAX .npz file into an .h5 file?

Note 1: the output of dir(jax_file):

dir(jax_file)
['__abstractmethods__', '__class__', '__contains__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_files', 'allow_pickle', 'close', 'f', 'fid', 'files', 'get', 'items', 'keys', 'pickle_kwargs', 'values', 'zip']

Note 2: the type of jax_file is:

type(jax_file)
<class 'numpy.lib.npyio.NpzFile'>

Note 3: my TensorFlow version is 2.9.1.
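Notes 1 and 2 show that jax_file is a NumPy NpzFile: a lazy, read-only mapping (it exposes files, keys, items) rather than an array. A minimal sketch for inspecting such a file before converting it, assuming an in-memory archive with made-up, slash-separated names standing in for the downloaded ViT weights:

```python
import io
import numpy as np

# Hypothetical stand-in for the downloaded .npz (names and shapes are made up).
buf = io.BytesIO()
np.savez(buf, **{"embedding/kernel": np.zeros((2, 3), dtype=np.float32),
                 "head/bias": np.zeros(5, dtype=np.float32)})
buf.seek(0)

# np.load on an .npz returns an NpzFile, a dict-like view of the archive.
with np.load(buf) as npz:
    print(type(npz).__name__)  # NpzFile
    # Arrays are read lazily, so collect what we need inside the with block.
    summary = {name: (arr.shape, arr.dtype) for name, arr in npz.items()}

for name, (shape, dtype) in summary.items():
    print(name, shape, dtype)
```

With the real file, np.load(local_filepath) can be used in place of the in-memory buffer.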

Any help will be appreciated.

MediaJ

1 Answer


The data argument of the create_dataset method should be a NumPy array; you are passing an NpzFile object, which is essentially a reference to a group of arrays stored on disk. To turn the contents into an HDF5 file with multiple datasets, you'll have to call create_dataset once per array.

For example, it might look like this:

import h5py
import numpy as np

# example npz file
np.savez('data.npz', x=np.arange(4), y=np.arange(10))

# load npz file and write to h5py file
np_file = np.load('data.npz')
with h5py.File('data.h5', 'w') as h5_file:
  for name, arr in np_file.items():
    h5_file.create_dataset(name, data=arr)

with h5py.File('data.h5') as f:
  print(f.keys())
  print(f['x'])
  print(f['y'])
<KeysViewHDF5 ['x', 'y']>
<HDF5 dataset "x": shape (4,), type "<i8">
<HDF5 dataset "y": shape (10,), type "<i8">

You can read more about how to interact with npz files in the numpy.savez docs.
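The dtype('<U71') in the original error can be reproduced with NumPy alone. A minimal sketch, assuming a recent NumPy and that h5py converts the data argument with NumPy before creating the dataset: NumPy treats the NpzFile as a sequence and iterates it, which yields its key names, so the result is a Unicode string array (h5py cannot store NumPy "U" strings, hence "No conversion path").

```python
import io
import numpy as np

# Small in-memory archive standing in for the ViT weight file.
buf = io.BytesIO()
np.savez(buf, x=np.arange(4), y=np.arange(10))
buf.seek(0)

with np.load(buf) as npz:
    # Coercing the NpzFile itself (not one of its arrays) iterates the
    # mapping, producing an array of key names instead of weights.
    coerced = np.asarray(npz)

print(coerced, coerced.dtype)
```

With the ViT checkpoint, the longest key name is presumably 71 characters, which would explain the <U71 width.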

jakevdp
  • You're right, the input should be a NumPy array. But for loading h5 weights, I need a single unified h5 file. How can I create one? – MediaJ Jan 26 '23 at 12:45
  • If you want multiple datasets in your h5py file, perhaps you should call `create_dataset` multiple times? See my edited answer above. – jakevdp Jan 26 '23 at 14:20
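Looping over the archive as in the answer already produces one unified .h5 file. A minimal sketch wrapping that loop in a helper, where npz_to_h5 is a hypothetical name and the key names are made up for illustration: if the keys contain '/', h5py treats them as paths and creates the intermediate groups automatically, so the file mirrors the archive's hierarchy.

```python
import io
import h5py
import numpy as np

def npz_to_h5(npz_source, h5_path):
    """Copy every array in an .npz archive into a single HDF5 file (hypothetical helper)."""
    with np.load(npz_source) as npz, h5py.File(h5_path, "w") as h5:
        for name, arr in npz.items():
            # A name like "a/b/c" creates groups "a" and "a/b", then dataset "c".
            h5.create_dataset(name, data=arr)

# Stand-in archive with slash-separated names (made up for illustration).
buf = io.BytesIO()
np.savez(buf, **{"Transformer/encoder_norm/bias": np.zeros(4, np.float32),
                 "head/kernel": np.ones((4, 2), np.float32)})
buf.seek(0)
npz_to_h5(buf, "weights.h5")

with h5py.File("weights.h5", "r") as f:
    f.visit(print)  # prints every group and dataset path in the file
```

For the question's file this would be npz_to_h5(local_filepath, "ViT-B_16_imagenet21k.h5"). The result is a plain copy of the arrays, so whether a particular loader (for example Keras load_weights) accepts it depends on the layout that loader expects.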