I want to check whether any of a given set of values is contained in a sparse tensor. The sparse tensor is called labels and has a single dimension containing a list of ids.

This seems like a simple set-intersection problem, so I tried the following:

import tensorflow as tf

sparse_ids = load_ids_as_sparse_tensor()  # my own loader; returns a 1-D SparseTensor
wanted_ids = tf.constant([34, 56, 12])
intersection = tf.sets.set_intersection(
    wanted_ids,
    tf.cast(sparse_ids.values, tf.int32)
)
contains_any_wanted_ids = tf.not_equal(tf.size(intersection), 0)

However, I am getting this error:

ValueError: Shape must be at least rank 2 but is rank 1 for 'DenseToDenseSetOperation' (op: 'DenseToDenseSetOperation') with input shapes: [3], [?].

Any ideas?

jaime

1 Answer

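The set operations need inputs of rank at least 2: they treat the last dimension as the set and any leading dimensions as batch dimensions. Adding a leading dimension with [None, :] turns each 1-D tensor into a rank-2 tensor, and the following code works. However, I am not sure whether the result is what you want.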

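import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

a = tf.constant([34, 56, 12])
b = tf.constant([56])
# a[None, :] adds a leading batch dimension: shape [3] -> [1, 3],
# so each input holds one set along its last dimension.
intersection = tf.sets.set_intersection(a[None, :], b[None, :])
with tf.Session() as sess:
    print(sess.run(intersection))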

Output:

SparseTensorValue(indices=array([[0, 0]], dtype=int64), values=array([56]), dense_shape=array([1, 1], dtype=int64))
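Applied to the original question, the same trick lets you treat the values of the 1-D sparse tensor as one set and intersect it with wanted_ids. Below is a minimal sketch, assuming TensorFlow 2.x with eager execution, where the op is exposed as tf.sets.intersection and no Session is needed. The helper name contains_any and the sample labels tensor are hypothetical, for illustration only.

```python
import tensorflow as tf


def contains_any(sparse_ids: tf.SparseTensor, wanted_ids: tf.Tensor) -> tf.Tensor:
    """Hypothetical helper: scalar bool, True if any wanted id is in sparse_ids."""
    # Match dtypes: set ops require both inputs to have the same type.
    values = tf.cast(sparse_ids.values, wanted_ids.dtype)
    # Set ops work along the last dimension and need rank >= 2,
    # so add a leading batch dimension: [n] -> [1, n].
    intersection = tf.sets.intersection(values[None, :], wanted_ids[None, :])
    # The result is a SparseTensor; any stored value means a shared id.
    return tf.size(intersection.values) > 0


# Hypothetical 1-D sparse tensor of ids (int64 values, as is common for ids).
labels = tf.sparse.SparseTensor(
    indices=[[0], [1], [2]],
    values=tf.constant([7, 56, 99], dtype=tf.int64),
    dense_shape=[3],
)
print(bool(contains_any(labels, tf.constant([34, 56, 12]))))  # True
```

Casting the sparse values to the dtype of wanted_ids, rather than hard-coding tf.int32, keeps the helper usable if the wanted ids are int64 instead.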

Qin Heyang