
My objective is to compute the multi-head attention output for a sparse matrix and a dense matrix. I understand that by default the Keras MultiHeadAttention API expects two dense tensors, and returns the attention value after the softmax operation over the Queries, Keys and Values from the Vaswani et al. paper "Attention Is All You Need".

However, my use-case requires passing a sparse matrix as the Query and a dense matrix as the Value to a MultiHeadAttention layer.

By default there is no support for this, and converting the sparse matrix to dense and back is not an option because the time cost grows too much. Is there any way to override the internal operations that are incompatible with sparse-dense combinations and replace them with mixed APIs such as `tf.sparse.sparse_dense_matmul` for the attention computation? The documentation states that the matrix rank must be 2 for `sparse_dense_matmul`, which is why simply overriding the class also does not seem feasible to me, unless I write my own sparse-dense computation block. Note: the rank of the matmul inputs is usually 3 for a Transformer, since the shapes have the format (batch size, sequence length, dim).
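For concreteness, this is the kind of rank-2 building block I have in mind, as a minimal sketch (the shapes, the random sparsity and the single-example/no-batch setting are only illustrative):

import tensorflow as tf

seq_len, dim = 4, 8

# Illustrative sparse query of shape (seq_len, dim) and dense keys of the same shape.
mask = tf.cast(tf.random.uniform((seq_len, dim)) > 0.8, tf.float32)
query = tf.sparse.from_dense(tf.random.normal((seq_len, dim)) * mask)
keys = tf.random.normal((seq_len, dim))

# Q @ K^T via the rank-2 sparse-dense matmul; the result is a dense (seq_len, seq_len) tensor.
scores = tf.sparse.sparse_dense_matmul(query, keys, adjoint_b=True)
weights = tf.nn.softmax(scores / tf.sqrt(tf.cast(dim, tf.float32)), axis=-1)

This is the step I would need to batch over examples and heads inside a custom layer; the question is whether something like it can be plugged into MultiHeadAttention instead of rewriting the whole layer.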

To give an example:

from tensorflow.keras import layers

att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
# I would like to pass `inputs1` (query) as sparse and `inputs2` (value) as dense.
attn_output = att(query=inputs1, value=inputs2)

I appreciate any help.

Arka Mukherjee
  • Partial/difficult solution that I found: the only way I found is using TensorFlow's CSR matrices and writing a custom Transformer using sparse-dense matmuls. CSR matrices support rank-3 sparse-dense matmuls, albeit slower than plain `tf.sparse` (rough sketch below). – Arka Mukherjee Jul 21 '22 at 20:33
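Rough sketch of the CSR route mentioned in the comment above (it relies on the low-level `tf.raw_ops` CSR kernels, so treat it as illustrative; op and argument names may vary across TF versions, and the shapes and sparsity are made up):

import tensorflow as tf

batch, seq_len, dim = 2, 4, 8

# Illustrative rank-3 sparse query and dense keys of shape (batch, seq_len, dim).
mask = tf.cast(tf.random.uniform((batch, seq_len, dim)) > 0.8, tf.float32)
sp_query = tf.sparse.from_dense(tf.random.normal((batch, seq_len, dim)) * mask)
keys = tf.random.normal((batch, seq_len, dim))

# Convert the SparseTensor to a (batched) CSR sparse matrix.
csr_query = tf.raw_ops.SparseTensorToCSRSparseMatrix(
    indices=sp_query.indices,
    values=sp_query.values,
    dense_shape=sp_query.dense_shape)

# Batched sparse-dense matmul Q @ K^T; returns a dense (batch, seq_len, seq_len) tensor.
scores = tf.raw_ops.SparseMatrixMatMul(a=csr_query, b=keys, transpose_b=True)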

1 Answer


You can take a look at official repos that publish implementations of sparse attention, such as the Sparse Transformer.

Arij Aladel