
I have a tensor of lengths in tensorflow, let's say it looks like this:

[4, 3, 5, 2]

I wish to create a mask of 1s and 0s in which the number of 1s in each row corresponds to the matching entry of this tensor, padded with 0s to a total length of 8. I.e. I want to create this tensor:

[[1,1,1,1,0,0,0,0],
 [1,1,1,0,0,0,0,0],
 [1,1,1,1,1,0,0,0],
 [1,1,0,0,0,0,0,0]
]

How might I do this?

Evan Pu

3 Answers


This can now be achieved with `tf.sequence_mask`. See the `tf.sequence_mask` API documentation for more details.
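A minimal sketch of this for the lengths in the question, assuming TensorFlow 2.x with eager execution. `maxlen` fixes the row width at 8, and passing `dtype=tf.int32` yields 1s and 0s instead of the default booleans:

```python
import tensorflow as tf

lengths = tf.constant([4, 3, 5, 2])

# maxlen=8 pads every row out to 8 columns;
# dtype=tf.int32 returns 1/0 integers instead of True/False.
mask = tf.sequence_mask(lengths, maxlen=8, dtype=tf.int32)
print(mask.numpy())
```

If `maxlen` is omitted, the width defaults to the largest value in `lengths` (5 here), so pass it explicitly when you need a fixed width of 8.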

Sonal Gupta

This can be achieved using a variety of TensorFlow transformations:

import tensorflow as tf

# Make a 4 x 1 column containing the lengths.
lengths = [4, 3, 5, 2]
lengths_transposed = tf.expand_dims(lengths, 1)

# Make a 1 x 8 row containing [0, 1, ..., 7] (named range_ to avoid shadowing the builtin).
range_ = tf.range(0, 8, 1)
range_row = tf.expand_dims(range_, 0)

# tf.less broadcasts the 1 x 8 row against the 4 x 1 column into a 4 x 8 boolean mask.
mask = tf.less(range_row, lengths_transposed)

# Select 1 or 0 for each value (tf.select was replaced by tf.where in TensorFlow 1.0).
result = tf.where(mask, tf.ones([4, 8]), tf.zeros([4, 8]))
mrry
  • Nice, but wouldn't it be more efficient to just cast the boolean values of `mask` to ints or use them directly instead of using `tf.select` on the last line? – Styrke May 10 '16 at 15:05
  • Sure! I guess my past experiences with C programming mean I never expect bool-to-int casting to work :), but I believe that's well defined in TensorFlow. – mrry May 13 '16 at 15:42
  • There is also `tf.zeros_like(mask)` which will create a zero-initialized tensor of shape of the tensor `mask`. – Lenar Hoyt Jun 20 '16 at 08:16
  • Do you need the tiling? Wouldn't `tf.less` automatically do the broadcasting for you? – Albert Dec 12 '16 at 15:26
  • Apparently you don't. Thanks for the suggestion! – mrry Dec 12 '16 at 15:54
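Following the suggestions in these comments, here is a sketch that relies on broadcasting (no tiling) and casts the boolean mask straight to integers instead of selecting between ones and zeros. It assumes TensorFlow 2.x with eager execution; `max_len` is just an illustrative name for the target width of 8 from the question:

```python
import tensorflow as tf

lengths = tf.constant([4, 3, 5, 2])
max_len = 8  # illustrative name for the target width from the question

# A 1 x 8 row [0, 1, ..., 7] compared against a 4 x 1 column of lengths;
# tf.less broadcasts them to a 4 x 8 boolean mask, so no tf.tile is needed.
bool_mask = tf.less(tf.range(max_len)[tf.newaxis, :], lengths[:, tf.newaxis])

# Cast the booleans directly to 0/1 integers.
int_mask = tf.cast(bool_mask, tf.int32)
print(int_mask.numpy())
```

Casting avoids materializing the two constant 4 x 8 tensors that the select-based version needs.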

I've got a slightly shorter version than the previous answer. I'm not sure whether it is more efficient.

import tensorflow as tf

def mask(seq_length, max_seq_length):
    # For each length x, build x ones and pad them with (max_seq_length - x) zeros.
    return tf.map_fn(
        lambda x: tf.pad(tf.ones([x], dtype=tf.int32), [[0, max_seq_length - x]]),
        seq_length)
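A quick usage check, assuming TensorFlow 2.x with eager execution. The function is reproduced here as a standalone function (no `self`) so the snippet runs on its own, and its output is compared with `tf.sequence_mask`:

```python
import tensorflow as tf

def mask(seq_length, max_seq_length):
    # For each length x, build x ones and pad them with (max_seq_length - x) zeros.
    return tf.map_fn(
        lambda x: tf.pad(tf.ones([x], dtype=tf.int32), [[0, max_seq_length - x]]),
        seq_length)

lengths = tf.constant([4, 3, 5, 2])
padded = mask(lengths, 8)

# Cross-check against the built-in vectorized op.
reference = tf.sequence_mask(lengths, maxlen=8, dtype=tf.int32)
print(bool(tf.reduce_all(tf.equal(padded, reference))))
```

On efficiency: `tf.map_fn` applies the lambda once per row, whereas the `tf.less`/`tf.sequence_mask` approaches compute the whole mask in one vectorized op. The vectorized versions would generally be expected to be at least as fast, especially for long batches, though this has not been benchmarked here.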
viktortnk