Using the code below, I am implementing global self-attention for sparse input with MinkowskiEngine. I am getting slightly worse results than the model without attention and I wonder why. Since the last line of the code uses a skip connection, I would expect the result to be, at worst, the same as without attention, but that is not the case.
First I pass the input through linear layers to create the query, the key, and the value. Then I multiply the query and the key to generate the attention map, and finally I project the attention weights onto the value. What am I doing wrong in the code?
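In equation form, what I intend to compute is standard scaled dot-product attention followed by the skip connection (with d the feature dimension):

    out = softmax(Q · K^T / sqrt(d)) · V + x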
import math

import torch
import torch.nn as nn
import MinkowskiEngine as ME


class SelfAttention(nn.Module):
    """Global self-attention for sparse input."""

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.linear_query = ME.MinkowskiLinear(in_dim, in_dim)
        self.linear_key = ME.MinkowskiLinear(in_dim, in_dim)
        self.linear_value = ME.MinkowskiLinear(in_dim, in_dim)
        self.pooling = ME.MinkowskiGlobalPooling()
        self.normalized = nn.Softmax(dim=-1)

    def forward(self, x):
        identity = x
        # Query from every input point; key and value from the globally pooled features
        Q = self.linear_query(x)
        K = self.pooling(x)
        K = self.linear_key(K)
        V = self.pooling(x)
        V = self.linear_value(V)
        K_feat_t = K.features.T
        # KT = ME.SparseTensor(features=K_feat_t, device=a.device, coordinate_map_key=a.coordinate_map_key, coordinate_manager=a.coordinate_manager)
        # Scaled dot-product attention on the raw feature matrices
        attmap = torch.matmul(Q.F, K_feat_t) / math.sqrt(Q.F.size(1))
        out = self.normalized(attmap)
        out = torch.matmul(out, V.F)
        # Wrap the result back into a sparse tensor on the input's coordinates
        out = ME.SparseTensor(features=out, tensor_stride=x.tensor_stride, device=x.device,
                              coordinate_map_key=x.coordinate_map_key, coordinate_manager=x.coordinate_manager)
        # Skip connection
        out = out + identity
        return out
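For completeness, this is roughly how I exercise the module (a minimal sketch with toy coordinates and features; the 8-dimensional feature size is arbitrary):

    # Hypothetical toy input: the batch index goes in the first coordinate column
    coords = torch.IntTensor([[0, 0, 0, 0],
                              [0, 1, 0, 0],
                              [0, 0, 1, 1]])   # (batch, x, y, z)
    feats = torch.rand(3, 8)                   # one 8-dim feature vector per point
    x = ME.SparseTensor(features=feats, coordinates=coords)

    attn = SelfAttention(in_dim=8)
    y = attn(x)
    print(y.F.shape)                           # torch.Size([3, 8]), same shape as the input features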