I want to multiply two tensors, one sparse and the other dense. The sparse one is 3D and the dense one 2D. I cannot convert the sparse tensor to a dense tensor (i.e., I must avoid tf.sparse.to_dense(...)).
My multiplication is given by the following law:
C[i,j] = \sum_k A[i,k]*B[i,k,j]
where C = A*B, A is the dense 2D tensor and B is the sparse 3D tensor described above.
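Every entry of B that is not stored is zero, so the sum over k only needs to visit B's stored entries. A stored value v at position (i, k, j) contributes exactly A[i,k]*v to C[i,j] and nothing else. Below is a dependency-free sketch of that reading of the law. It assumes B is given as a plain list of (i, k, j, value) entries; the function name coo_contract and this list format are illustrative assumptions, not TensorFlow API.

```python
def coo_contract(a, entries, n_i, n_j):
    """Reference for C[i][j] = sum_k a[i][k] * B[i][k][j].

    a: dense matrix as a list of rows, shape [n_i][n_k]
    entries: stored entries of B as (i, k, j, value) tuples; missing entries are zero
    returns: dense C as a list of rows, shape [n_i][n_j]
    """
    c = [[0.0] * n_j for _ in range(n_i)]
    for i, k, j, v in entries:
        # Each stored entry of B contributes one product to C[i][j].
        c[i][j] += a[i][k] * v
    return c
```

The cost is one multiply-add per stored entry of B, independent of B's dense size, which is the property a TensorFlow version should preserve.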
Here is an example in TF:
import tensorflow as tf

# Example
# inputs
A = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]], tf.float32)
B = tf.sparse.SparseTensor(
    values=tf.constant([1, 1, 1, 1, 2, 1, 1, 1, 1, 1, -1], tf.float32),
    indices=[[0, 0, 1], [0, 1, 2], [0, 2, 0],
             [1, 0, 0], [1, 1, 1], [1, 1, 2], [1, 2, 2],
             [2, 0, 2], [2, 1, 1], [2, 2, 1], [2, 2, 2]],
    dense_shape=[3, 3, 3])
# expected output
C = tf.constant([[3, 1, 2],
                 [4, 10, 11],
                 [0, 17, -2]], tf.float32)
tf.einsum does not support sparse tensors.
I have a version where I slice the 3D sparse tensor B into a collection of 2D sparse matrices, B[0,:,:], B[1,:,:], B[2,:,:], ..., and multiply each row of the dense matrix A, A[i,:], with the corresponding 2D slice B[i,:,:] using tf.sparse.sparse_dense_matmul(A[i,:], B[i,:,:]) (reshaping after the slicing so that both arguments are 2D, as tf.sparse.sparse_dense_matmul requires). Then I stack all the resulting vectors to assemble the C matrix. This procedure is slow and breaks the tensor structure of B.

I want to perform the same operation using ONLY TensorFlow functions, without Python for loops that slice the sparse tensor apart and then reassemble the result by stacking. It should also work as a layer of a Keras neural network, where [A,B] is the batched list of inputs and C = A*B is the batched output of the layer. Breaking the tensors apart to compute the multiplications is crazy for training in a compiled graph!
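One loop-free formulation works directly on B's COO representation (b.indices, b.values, b.dense_shape): use tf.gather_nd to fetch the matching A[..., i, k] for every stored entry, multiply by the stored values, and sum the products into C[..., i, j] with tf.scatter_nd, which accumulates values written to duplicate indices. This is a sketch under stated assumptions, not a library routine: it assumes A and B share the same leading (e.g. batch) dimensions and the same dtype, and the function name sparse_dense_contract is hypothetical. No dense copy of B is created; work scales with the number of stored entries.

```python
import tensorflow as tf


def sparse_dense_contract(a, b):
    """Sketch: C[..., i, j] = sum_k A[..., i, k] * B[..., i, k, j] without densifying B.

    a: dense tensor, shape [..., I, K]
    b: tf.sparse.SparseTensor, shape [..., I, K, J], same leading dims and dtype as a
    returns: dense tensor, shape [..., I, J]
    """
    idx = b.indices  # [nnz, rank] int64, one row per stored entry of B
    # Position of each stored entry's partner in A: drop the trailing j coordinate.
    a_idx = idx[:, :-1]
    # Position of each product in C: drop the k coordinate (second to last).
    c_idx = tf.concat([idx[:, :-2], idx[:, -1:]], axis=1)
    # One product A[..., i, k] * B[..., i, k, j] per stored entry.
    products = tf.gather_nd(a, a_idx) * b.values
    # C has B's dense shape with the k dimension removed (int64, like c_idx).
    c_shape = tf.concat([b.dense_shape[:-2], b.dense_shape[-1:]], axis=0)
    # scatter_nd sums products that share the same (..., i, j) index.
    return tf.scatter_nd(c_idx, products, c_shape)


# Usage on the inputs from the example above.
A = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], tf.float32)
B = tf.sparse.SparseTensor(
    values=tf.constant([1, 1, 1, 1, 2, 1, 1, 1, 1, 1, -1], tf.float32),
    indices=[[0, 0, 1], [0, 1, 2], [0, 2, 0], [1, 0, 0], [1, 1, 1], [1, 1, 2],
             [1, 2, 2], [2, 0, 2], [2, 1, 1], [2, 2, 1], [2, 2, 2]],
    dense_shape=[3, 3, 3])
C = sparse_dense_contract(A, B)
# The same function can also be traced into a graph with tf.function.
C_graph = tf.function(sparse_dense_contract)(A, B)
```

tf.gather_nd and tf.scatter_nd both have gradients with respect to A and B.values, so training through this op should work. For Keras, one option (not tested here) is to call the function from the call method of a custom tf.keras.layers.Layer and feed B through tf.keras.Input(..., sparse=True); the batch dimension then simply becomes one of the leading dimensions the sketch already handles.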
Any ideas? Is there a tf.sparse.einsum-like function for sparse tensors?
If I converted B to a dense tensor, it would be straightforward: tf.einsum('ik,ikj->ij', A, tf.sparse.to_dense(B)). However, I cannot afford to lose the sparsity of B.
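The dense formulation is still handy as a correctness oracle on small test inputs, where densifying B is affordable. Note that tf.einsum takes the equation string first, followed by the operands. A minimal sketch, assuming B's indices may be unordered (hence tf.sparse.reorder); the helper name dense_reference is hypothetical, and the ellipsis lets the same equation cover a leading batch dimension:

```python
import tensorflow as tf


def dense_reference(a, b):
    """Dense oracle for C[..., i, j] = sum_k A[..., i, k] * B[..., i, k, j].

    Densifies b, so it is only meant for small inputs in tests.
    """
    # to_dense expects lexicographically ordered indices; reorder makes that safe.
    b_dense = tf.sparse.to_dense(tf.sparse.reorder(b))
    return tf.einsum('...ik,...ikj->...ij', a, b_dense)
```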
Thank you. Regards,