37

Is there any way to perform a dictionary lookup based on a string tensor in TensorFlow?

In plain Python, I'd do something like

value = dictionary[key]

Now I'd like to do the same thing at TensorFlow runtime, when my key is a string tensor. Something like

value_tensor = tf.dict_lookup(string_tensor)

would be nice.

mackcmillion

4 Answers

35

If you want to run this in TF 2.x, where eager execution is enabled by default, here is a quick code snippet:

import tensorflow as tf

# build a lookup table
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3]),
        values=tf.constant([10, 11, 12, 13]),
    ),
    default_value=tf.constant(-1),
    name="class_weight"
)

# now let us do a lookup
input_tensor = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
out = table.lookup(input_tensor)
print(out)

Output:

tf.Tensor([10 10 11 11 12 12 13 13], shape=(8,), dtype=int32)
Praveen Kulkarni
  • 2
    For strings, use `keys=['a', 'b', 'c']` and change `input_tensor` to be something like `tf.constant(['a', 'a', 'c', 'b'])`. – Samir Jul 03 '19 at 09:31
  • 1
    Any ideas on how to use arrays as values? Like `values = tf.constant([[0.1, 0.2, 0.3], [0.2, 0.5, 0.7], [0.3, 0.4, 0.5]])`. I'd like to map a string key to array values. Thanks in advance for going through this! – Amith Adiraju Jun 03 '21 at 22:49
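One possible way to handle the two comments above (string keys, and arrays as values) is to let the table map each string key to a row index and then pull the rows out with `tf.gather`. This is only a sketch under assumptions: the names `rows` and `lookup_rows` and the all-zero fallback row are made up for illustration, and it assumes the static table holds one scalar value per key, which is why the arrays live in a separate tensor.

```python
import tensorflow as tf

# One row of values per key; the last row is a fallback for unknown keys.
rows = tf.constant([
    [0.1, 0.2, 0.3],  # "a"
    [0.2, 0.5, 0.7],  # "b"
    [0.3, 0.4, 0.5],  # "c"
    [0.0, 0.0, 0.0],  # fallback row (an assumption of this sketch)
])

# The table maps each string key to the index of its row in `rows`.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(["a", "b", "c"]),
        values=tf.constant([0, 1, 2], dtype=tf.int64),
    ),
    default_value=tf.constant(3, dtype=tf.int64),  # index of the fallback row
)


def lookup_rows(string_keys):
    """Return one row of `rows` for each string key (hypothetical helper)."""
    row_ids = table.lookup(string_keys)
    return tf.gather(rows, row_ids)


result = lookup_rows(tf.constant(["a", "c", "zzz"]))
print(result)
```

Pointing the default value at a dedicated fallback row keeps `tf.gather` from ever seeing an out-of-range index for keys that are not in the table.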
29

You might find tensorflow.contrib.lookup helpful: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py

https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable

In particular, you can do:

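import tensorflow as tf  # TF 1.x only; tf.contrib was removed in TF 2.x

table = tf.contrib.lookup.HashTable(
  tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1
)
out = table.lookup(input_tensor)
with tf.Session():
  table.init.run()  # the table must be initialized before the first lookup
  print(out.eval())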
Saurfang
0

tf.gather can help, but it only picks values out of a tensor by integer index. You can split the dictionary into key and value lists, map the queried keys to indices in Python, and then apply tf.gather. Example:

import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.InteractiveSession)

# Your dict
dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
# concrete query
query_list = ['a', 'c']

# unpack key and value lists
key, value = list(zip(*dict_.items()))
# map each queried key to its position in `key`, preserving query order -> [0, 2]
query_idx = [key.index(s) for s in query_list]

# query as tensor
query = tf.placeholder(tf.int32, shape=[None])
# convert value list to tensor
vl_tf = tf.constant(value)
# get value
my_vl = tf.gather(vl_tf, query)

# session run
sess = tf.InteractiveSession()
print(sess.run(my_vl, feed_dict={query: query_idx}))
-12

TensorFlow is a data flow language with no support for data structures other than tensors. There is no map or dictionary type. However, depending on what you need, when you're using the Python wrapper it is possible to maintain a dictionary in the driver process, which executes in Python, and use it to interact with the TensorFlow graph execution. For example, you could execute one step of a TensorFlow graph within a session, return a string value to the Python driver, use it as a key into a dictionary in the driver, and use the retrieved value to determine the next computation to be requested from the session. This is probably not a good solution if the speed of these dictionary lookups is performance critical.
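The driver-side pattern described above could look roughly like the sketch below. It is only an illustration: `run_step` is a hypothetical stand-in for whatever session call returns the string key (it is not a TensorFlow API), and the dictionary contents and the loop in `drive` are invented for the example.

```python
# Ordinary Python dictionary kept in the driver process.
dictionary = {"a": 1.12, "b": 5.86, "c": 68.0}


def run_step(state):
    """Hypothetical stand-in for one graph step that returns a string key.

    In real code this would be something like a session run call.
    """
    return "b" if state > 2.0 else "a"


def drive(initial_state, num_steps):
    """Alternate between graph steps and dictionary lookups in the driver."""
    state = initial_state
    history = []
    for _ in range(num_steps):
        key = run_step(state)
        # String tensors typically come back to Python as bytes objects.
        if isinstance(key, bytes):
            key = key.decode("utf-8")
        value = dictionary[key]   # plain Python lookup in the driver
        history.append((key, value))
        state = state + value     # the value decides what is computed next
    return history


print(drive(0.0, 3))
```

Every lookup here costs a round trip between the graph and the Python driver, which is the performance concern raised in the answer.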

Paul Tucker