
I am interested in training a neural network with JAX. I had a look at tf.data.Dataset, but it provides exclusively tf tensors. I looked for a way to convert the dataset into JAX numpy arrays, and I found a lot of implementations that use Dataset.as_numpy_iterator() to turn the tf tensors into numpy arrays. However, I wonder if this is good practice, as numpy arrays are stored in CPU memory and that is not what I want for my training (I use the GPU). The last idea I found is to manually recast the arrays by calling jnp.array, but that is not really elegant (I am worried about the copy into GPU memory). Does anyone have a better idea?

Quick code to illustrate:

import jax.numpy as jnp
import tensorflow as tf

# Toy dataset: two batches, each a single random value of shape (1,)
def generator():
    for _ in range(2):
        yield tf.random.uniform((1,))

ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
                                    output_shapes=tf.TensorShape([1]))

ds1 = ds.take(1).as_numpy_iterator()  # yields numpy arrays (host memory)
ds2 = ds.skip(1)

for i, batch in enumerate(ds1):
    print(type(batch))

for i, batch in enumerate(ds2):
    print(type(jnp.array(batch)))  # manual recast to a JAX array

# returns:

<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant
asked by Valentin Goldité

2 Answers


Both TensorFlow and JAX can convert arrays to DLPack tensors without copying memory, so one way to create a JAX array from a TensorFlow tensor without copying the underlying data buffer is to go through DLPack. Since the DLPack capsule carries the device information, a tensor that already lives on the GPU stays there:

import numpy as np
import tensorflow as tf
import jax.dlpack

tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)  # TF tensor -> DLPack capsule
jax_arr = jax.dlpack.from_dlpack(dl_arr)           # DLPack capsule -> JAX array

np.testing.assert_array_equal(tf_arr, jax_arr)

By doing a round-trip from JAX to TensorFlow and back, you can compare unsafe_buffer_pointer() to confirm that the arrays point at the same buffer, rather than the buffer being copied along the way:

import jax.numpy as jnp

def tf_to_jax(arr):
  return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))

def jax_to_tf(arr):
  return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))

jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)

print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True
answered by jakevdp

Comments:

  • Thank you very much! Do you know if it is possible to run the function once over the whole dataset? I tried the .map() method, but it fails with ```The argument to `to_dlpack` must be a TF tensor, not Python object``` even though my dataset is composed of tf.Tensors... – Valentin Goldité Oct 31 '21 at 23:32
  • I don't know what you mean by "run the function once over the whole dataset" – jakevdp Nov 01 '21 at 15:17
  • Something like ```dataset.map(tf_to_jax)``` to avoid calling the function at each iteration of the dataset – Valentin Goldité Nov 03 '21 at 00:34
  • No, I don't think TensorFlow has support for anything like that. – jakevdp Nov 03 '21 at 03:14
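
A note on the comment thread (a minimal sketch, not from either answer; jax_batches is a hypothetical helper): Dataset.map traces its function with symbolic graph-mode tensors, which is why to_dlpack rejects them, so the conversion has to happen on the Python side while iterating over eager tensors, e.g. over the question's ds:

# Hypothetical helper, assuming tf_to_jax as defined above: Dataset.map
# traces its function with symbolic (graph-mode) tensors, so to_dlpack
# fails inside it; plain Python iteration yields eager tf.Tensors instead.
def jax_batches(ds):
  for batch in ds:          # eager tf.Tensor batches
    yield tf_to_jax(batch)  # per-batch zero-copy DLPack conversion

for batch in jax_batches(ds):
  print(type(batch))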

From the Flax ImageNet example:

https://github.com/google/flax/blob/6ae22681ef6f6c004140c3759e7175533bda55bd/examples/imagenet/train.py#L183

import jax
from flax import jax_utils

def prepare_tf_data(xs):
  local_device_count = jax.local_device_count()
  def _prepare(x):
    x = x._numpy()  # use _numpy() for zero-copy conversion between TF and NumPy
    return x.reshape((local_device_count, -1) + x.shape[1:])
  return jax.tree_util.tree_map(_prepare, xs)

it = map(prepare_tf_data, ds)
it = jax_utils.prefetch_to_device(it, 2)
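
For context, a minimal sketch of how this might be wired up end to end (the dataset construction below is illustrative, not part of the Flax example; it assumes the batch size is divisible by jax.local_device_count()):

import jax
import tensorflow as tf
from flax import jax_utils

# Illustrative stand-in for the real input pipeline: 128 examples of
# size 8, batched so each batch splits evenly across local devices.
ds = tf.data.Dataset.from_tensor_slices(tf.random.uniform((128, 8)))
ds = ds.batch(jax.local_device_count() * 16, drop_remainder=True)

it = map(prepare_tf_data, ds)             # TF tensors -> sharded numpy arrays
it = jax_utils.prefetch_to_device(it, 2)  # overlap host-to-device copies

for batch in it:
  print(batch.shape)  # (jax.local_device_count(), 16, 8)
  break

Unlike the DLPack route above, this path goes through host memory, but prefetch_to_device overlaps the host-to-device copies with computation.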
answered by Mutlu Simsek