My objective is to compute a MultiHead Attention output for a sparse matrix and a dense matrix. I understand that by default, the Keras MultiHeadAttention API expects two dense tensors, and returns the attention output after the softmax operation over the Queries, Keys and Values from the Vaswani et al. paper "Attention Is All You Need".
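For reference, the scaled dot-product attention from that paper is Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, where d_k is the key dimension, so the core of the computation is two matmuls around a softmax.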
However, I have a use case with one sparse matrix and one dense matrix, and I want to pass them to a MultiHeadAttention layer as the query and the value respectively.
By default this is not supported, and converting the sparse matrix to dense and back is not an option, since the time and memory costs grow too quickly.
Is there any way to override the internal operations that are not compatible with sparse-dense combinations, and replace them with mixed APIs such as tf.sparse.sparse_dense_matmul for the attention computation? However, the documentation states that the matrices must be of rank 2 for sparse_dense_matmul, which is why subclassing also does not seem directly feasible to me, unless I write my own sparse-dense computation block. Note: the rank for matmul is usually 3 in a transformer, as the shapes have the format (batch_size, sequence_length, dim).
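The closest thing I can come up with is to unstack along the batch dimension and apply the rank-2 op per example, roughly like the sketch below. Here batched_sparse_dense_matmul is my own hypothetical helper, and I am assuming tf.map_fn can unstack a SparseTensor along axis 0 (which I believe recent TF versions support); I am also unsure whether this would be efficient enough:

import tensorflow as tf

def batched_sparse_dense_matmul(sp_a, b):
    # sp_a: SparseTensor of shape (batch, m, k); b: dense Tensor of shape (batch, k, n).
    # tf.sparse.sparse_dense_matmul only takes rank-2 inputs, so apply it once per batch element.
    def per_example(elems):
        sp, dense = elems
        return tf.sparse.sparse_dense_matmul(sp, dense)

    return tf.map_fn(
        per_example,
        (sp_a, b),
        fn_output_signature=tf.TensorSpec(shape=None, dtype=b.dtype),
    )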
To give an example:
att = layers.MultiHeadAttention(num_heads=num_heads,
                                key_dim=embed_dim)
attn_output = att(query=inputs1, value=inputs2)  # I would like to pass this query as sparse and this value as dense.
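For concreteness, here is a self-contained version of the setup; the shapes and the sparsity pattern are just placeholders I made up:

import tensorflow as tf
from tensorflow.keras import layers

batch_size, seq_len, embed_dim, num_heads = 4, 10, 32, 2

# A mostly-zero query, stored as a SparseTensor.
dense_q = tf.random.uniform((batch_size, seq_len, embed_dim))
keep = tf.cast(tf.random.uniform((batch_size, seq_len, embed_dim)) > 0.9, tf.float32)
inputs1 = tf.sparse.from_dense(dense_q * keep)  # sparse query
inputs2 = tf.random.uniform((batch_size, seq_len, embed_dim))  # dense value

att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
attn_output = att(query=inputs1, value=inputs2)  # fails, since the layer does not support sparse inputs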
I appreciate any help.