TL;DR
It is recommended to use Option C, defined below.
Full answer
The tf.data.Dataset object is designed so that the whole dataset does not have to be loaded into memory. Because of that, tf.gather will not work on it directly. There are three options you can go with:
Option A: Load ds into memory + tf.gather
If you want to use tf.gather, you have to load the entire dataset into memory and then select a subset:
m_X_ds = list(m_X_ds)  # Load into memory.
m_X_ds = tf.gather(m_X_ds, arr)  # Gather as usual.
print(m_X_ds)
# Example result: <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>
Note that this is not always possible, especially with huge datasets.
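For reference, here is a minimal self-contained sketch of Option A. The sample values and the contents of arr are made up for illustration and are not taken from the question.

```python
import tensorflow as tf

# Hypothetical sample data standing in for m_X_ds and arr.
m_X_ds = tf.data.Dataset.from_tensor_slices([5, 2, 8, 6, 1])
arr = [2, 3, 1]

# Materialise every element in memory, then index with tf.gather.
elements = tf.stack(list(m_X_ds))
subset = tf.gather(elements, arr)
print(subset.numpy())  # [8 6 2]
```

tf.gather returns the elements in the order given by arr, which is why 8 (index 2) comes first here.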
Option B: Iterate over the dataset, and filter undesired samples
You could also iterate over the dataset and manually pick the samples at the desired indices. This is possible with a combination of filter and tf.py_function:
m_X_ds = m_X_ds.enumerate() # Create index,value pairs in the dataset.
# Create filter function:
def filter_fn(idx, value):
    return idx in arr
# The above is not going to work in graph mode.
# We are wrapping it with py_function to execute it eagerly.
def py_function_filter(idx, value):
    return tf.py_function(filter_fn, (idx, value), tf.bool)
# Filter the dataset as usual:
filtered_ds = m_X_ds.filter(py_function_filter)
Option C: Combine Option B with tf.lookup.StaticHashTable
Option B works, but you can expect a performance hit from the graph tensor -> eager tensor -> graph tensor round trip. tf.py_function is useful, but it comes at a cost.
Instead, we could declare a dictionary in which the desired indices map to True and absent indices map to False. Such a dict could look like this:
my_table = {3: True, 4: True, 5: True}
We cannot use a tensor as a dictionary key, but we can declare a TensorFlow hash table that lets us check for "good" indices.
m_X_ds = m_X_ds.enumerate()  # Do not repeat this if executed in Option B.
# enumerate() yields int64 indices, so the keys must be int64 as well.
keys_tensor = tf.cast(tf.constant(arr), tf.int64)
vals_tensor = tf.ones_like(keys_tensor)  # Ones will be cast to True.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=0)  # If index not in table, return 0.
def hash_table_filter(index, value):
    table_value = table.lookup(index)  # 1 if index in arr, else 0.
    index_in_arr = tf.cast(table_value, tf.bool)  # 1 -> True, 0 -> False
    return index_in_arr
filtered_ds = m_X_ds.filter(hash_table_filter)
Regardless of whether you pick Option B or C, all that is left is to drop the index from the filtered dataset. A simple map with a lambda function does that:
final_ds = filtered_ds.map(lambda idx,value: value)
for entry in final_ds:
    print(entry)
# tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(13, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)
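Putting the Option C steps together, a minimal end-to-end sketch could look like the following. The names sample_ds, wanted and selected_ds, as well as the sample values, are hypothetical and only serve the illustration.

```python
import tensorflow as tf

# Hypothetical sample data; these names are illustrative only.
sample_ds = tf.data.Dataset.from_tensor_slices([10, 11, 12, 13, 14, 15])
wanted = [5, 3, 4]

# Dataset.enumerate() yields int64 indices, so the keys are cast to int64.
keys = tf.cast(wanted, tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, tf.ones_like(keys)),
    default_value=0,  # Indices that are not in the table map to 0.
)

selected_ds = (
    sample_ds.enumerate()  # (index, value) pairs
    .filter(lambda idx, value: tf.cast(table.lookup(idx), tf.bool))
    .map(lambda idx, value: value)  # drop the index again
)

selected = [int(x) for x in selected_ds]
print(selected)  # [13, 14, 15]
```

Unlike tf.gather, filtering keeps the elements in dataset order rather than in the order of the index list, so [5, 3, 4] yields 13, 14, 15 here.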
Best of luck.