
I have a data tensor of shape [batch_size, 512] and a constant matrix of shape [256, 512] whose values are only 0 and 1.

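For each batch element, I would like to efficiently compute the following: for every row of the constant matrix, take the product of the entries of my vector (the second dimension of the data tensor) at the positions where that row is 1, ignoring the positions where it is 0, and then sum these products over all rows.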

An example: say I have a batch of size 1, the data tensor has the values [5,4,3,7,8,2], and my constant matrix has the values:

[0,1,1,0,0,0]
[1,0,0,0,0,0]
[1,1,1,0,0,1]

This means that I would like to compute 4*3 for the first row, 5 for the second, and 5*4*3*2 for the third. In total, for this batch I get 4*3 + 5 + 5*4*3*2, which equals 137. Currently I iterate over all the rows, compute the elementwise product of my data and each constant-matrix row, and then sum, which runs pretty slowly.
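For reference, the row-by-row computation described above can be sketched in plain Python. This is only an illustration of the intended arithmetic, not the actual code in use: the function name `masked_product_sum_loop` is hypothetical, and it assumes that a row containing only zeros contributes 1 (the empty product). It needs Python 3.8+ for `math.prod`.

```python
from math import prod


def masked_product_sum_loop(vector, mask_rows):
    """Sum, over all mask rows, the product of the entries selected by 1s."""
    total = 0
    for row in mask_rows:
        # Keep only the entries whose mask value is 1.
        selected = [v for v, keep in zip(vector, row) if keep == 1]
        # prod([]) is 1, so an all-zero row adds 1 (an assumption, see above).
        total += prod(selected)
    return total


data = [5, 4, 3, 7, 8, 2]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
print(masked_product_sum_loop(data, mask))  # 4*3 + 5 + 5*4*3*2 = 137
```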

Codevan

1 Answer


How about something like this:

import tensorflow as tf

# Two-element batch
data = [[5, 4, 3, 7, 8, 2],
        [4, 2, 6, 1, 6, 8]]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
    # Data as tensors
    d = tf.constant(data, dtype=tf.int32)
    m = tf.constant(mask, dtype=tf.int32)
    # Tile data as needed
    dd = tf.tile(d[:, tf.newaxis], (1, tf.shape(m)[0], 1))
    mm = tf.tile(m[tf.newaxis, :], (tf.shape(d)[0], 1, 1))
    # Replace values with 1 wherever the mask is 0
    w = tf.where(tf.cast(mm, tf.bool), dd, tf.ones_like(dd))
    # Multiply row-wise and sum
    result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
    print(sess.run(result))

Output:

[137 400]
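The same idea can be checked outside a TensorFlow session with a small NumPy sketch. This is not the code above, just an equivalent under the assumption that NumPy is available; the function name `masked_product_sum_broadcast` is made up for illustration. It relies on broadcasting instead of explicit tiling, so the [batch, rows, n] copies of the data and the mask are not built by hand.

```python
import numpy as np


def masked_product_sum_broadcast(data, mask):
    """data: [batch, n]; mask: [rows, n] of 0/1. Returns one sum per batch element."""
    data = np.asarray(data)
    mask = np.asarray(mask).astype(bool)
    # data[:, None, :] is [batch, 1, n] and mask[None, :, :] is [1, rows, n];
    # np.where broadcasts both to [batch, rows, n], putting 1 where the mask is 0.
    w = np.where(mask[None, :, :], data[:, None, :], 1)
    # Product along each row, then sum over the rows.
    return w.prod(axis=-1).sum(axis=-1)


data = [[5, 4, 3, 7, 8, 2],
        [4, 2, 6, 1, 6, 8]]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
print(masked_product_sum_broadcast(data, mask))  # [137 400]
```

Note that the intermediate array still has batch × rows × n elements, so memory use grows with all three sizes.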

EDIT:

If your input data is a single vector, then you would just have:

import tensorflow as tf

# Single vector (no batch dimension)
data = [5, 4, 3, 7, 8, 2]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
    # Data as tensors
    d = tf.constant(data, dtype=tf.int32)
    m = tf.constant(mask, dtype=tf.int32)
    # Tile data as needed
    dd = tf.tile(d[tf.newaxis], (tf.shape(m)[0], 1))
    # Replace values with 1 wherever the mask is 0
    w = tf.where(tf.cast(m, tf.bool), dd, tf.ones_like(dd))
    # Multiply row-wise and sum
    result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
    print(sess.run(result))

Output:

137
jdehesa
  • What should I change in your solution for a batch size bigger than 1? @jdehesa – Codevan Mar 27 '18 at 11:58
  • @Codevan In that case, would you need the result per batch element or the total aggregated sum? – jdehesa Mar 27 '18 at 13:14
  • @Codevan I've edited for the per-element result, you can aggregate it later if needed. – jdehesa Mar 27 '18 at 13:16
  • Yes, now it's good. Another issue: I receive -1 instead of 1 in the w matrix, maybe because my mask (and data) are not integers but floats. How can I fix this? – Codevan Mar 27 '18 at 16:00
  • @Codevan If the mask really has only ones and zeros, I don't think it should make a difference... The issue could be in the data itself, not in the mask, it's hard to tell from the info. – jdehesa Mar 27 '18 at 16:04
  • Hi again, the current solution works for batch size > 1, but for batch size = 1 I still encounter a problem on the line `w = tf.where(tf.cast(mm, tf.bool), dd, tf.ones_like(dd))`. What am I missing? – Codevan Sep 16 '18 at 13:51
  • @Codevan I added a slightly modified version for one-dimensional input. – jdehesa Sep 17 '18 at 12:57
  • Thanks. Is there something more generic that works whether the batch size is 1 or more? – Codevan Sep 18 '18 at 13:56
  • @Codevan The problem is that the number of dimensions of the input is different in each case. You can use the first code in all cases adding something like `d = tf.reshape(d, (-1, tf.shape(d)[-1]))` before the tiling. – jdehesa Sep 18 '18 at 14:09
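
The reshape suggested in the last comment can be illustrated with a NumPy sketch (an assumption-laden stand-in, not the TensorFlow code from the answer; the function name `masked_product_sum_any_rank` is hypothetical). Reshaping to (-1, n) turns a single vector of length n into a batch of size 1 and leaves a [batch, n] input unchanged, so the same batched code path handles both cases.

```python
import numpy as np


def masked_product_sum_any_rank(data, mask):
    """Accepts data of shape [n] or [batch, n]; returns one sum per batch element."""
    data = np.asarray(data)
    mask = np.asarray(mask).astype(bool)
    # Force a leading batch dimension: [n] -> [1, n], [batch, n] stays as is.
    data = data.reshape(-1, data.shape[-1])
    # Put 1 wherever the mask is 0, broadcasting to [batch, rows, n].
    w = np.where(mask[None, :, :], data[:, None, :], 1)
    return w.prod(axis=-1).sum(axis=-1)


mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
print(masked_product_sum_any_rank([5, 4, 3, 7, 8, 2], mask))    # [137]
print(masked_product_sum_any_rank([[5, 4, 3, 7, 8, 2],
                                   [4, 2, 6, 1, 6, 8]], mask))  # [137 400]
```

With this approach a single vector yields a one-element result rather than a scalar, so the caller may want to index `[0]` in that case.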