I'm working with a linear transformation of the form Y = Q(X + A), where X is the input tensor, Y is the output, and Q and A are two tensors to be learned. Q is an arbitrary tensor, so I can use nn.Linear for it. But A is a (differentiable) tensor that has a specific pattern. As a short example,
A = [[a0,a1,a2,a2,a2],
[a1,a0,a1,a2,a2],
[a2,a1,a0,a1,a2],
[a2,a2,a1,a0,a1],
[a2,a2,a2,a1,a0]].
So I cannot define such a pattern with nn.Linear. Is there any way to define such a tensor in PyTorch?
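
One possible way to express this pattern, given only as a rough sketch: store the free values a0, a1, a2 in a small nn.Parameter and expand them into the full matrix with a fixed index table, where entry (i, j) selects a[min(|i - j|, 2)]. Indexing is differentiable, so gradients flow back to the three shared values. The class name PatternedOffset, its constructor arguments, the random initialization, the bias-free nn.Linear for Q, and the input shape (batch, 5, 5) are all assumptions made for illustration and are not part of the original setup.

```python
import torch
import torch.nn as nn


class PatternedOffset(nn.Module):
    """Hypothetical module computing Y = Q(X + A) with a patterned A."""

    def __init__(self, n: int = 5, num_values: int = 3, out_features: int = 5):
        super().__init__()
        # Free parameters a0, a1, a2 (random init is an assumption).
        self.a = nn.Parameter(torch.randn(num_values))
        # Fixed index table: idx[i, j] = min(|i - j|, num_values - 1).
        pos = torch.arange(n)
        idx = (pos[:, None] - pos[None, :]).abs().clamp(max=num_values - 1)
        # A buffer moves with the module (e.g. to GPU) but is not trained.
        self.register_buffer("idx", idx)
        # Q is unconstrained, so a plain linear layer stands in for it.
        # Note that nn.Linear applies its weight on the last dimension.
        self.q = nn.Linear(n, out_features, bias=False)

    def build_a(self) -> torch.Tensor:
        # Expand the shared values into the full n x n matrix.
        return self.a[self.idx]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A broadcasts over the leading batch dimension of x.
        return self.q(x + self.build_a())


layer = PatternedOffset()
x = torch.randn(4, 5, 5)   # assumed input shape: (batch, 5, 5)
y = layer(x)
y.sum().backward()         # populates layer.a.grad with shape (3,)
```

Because A is rebuilt from the three parameters on every forward pass, an optimizer step on layer.a keeps the pattern intact automatically. No entry of A can drift away from the others that share its value.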