While working on a question-answering (MRC) problem, I implemented two different architectures that independently produce two tensors (probability distributions over the tokens). Both tensors have shape (batch_size, 512). I want a final output of shape (batch_size, 512). How can I combine the two tensors using trainable weights and then train the model on the final prediction?

Edit (Additional Information):

In the forward function of my NN model, I use a BERT model to encode the 512 tokens. These encodings are 768-dimensional. They are then passed to a linear layer nn.Linear(768, 1), which outputs a tensor of shape (batch_size, 512, 1). Apart from this, I have another model built on top of the BERT encodings that also yields a tensor of shape (batch_size, 512, 1). I wish to combine these two tensors to get a final tensor of shape (batch_size, 512, 1), which can be trained against the output logits of the same shape using CrossEntropyLoss.

Please share the PyTorch code snippet if possible.

1 Answer

Assume your two vectors are V1 and V2. You need to combine them (ensembling) to get a new vector. You can use a weighted sum like this:

alpha = sigmoid(alpha)
V_final = alpha * V1 + (1 - alpha) * V2

where alpha is a learnable scalar. The sigmoid bounds the mixing weight between 0 and 1, and you can initialise alpha = 0 so that sigmoid(alpha) is 0.5, meaning V1 and V2 start out with equal weights.
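A minimal PyTorch sketch of this weighted sum might look as follows. The class name `WeightedSum` and the random stand-in inputs are assumptions for illustration; in practice V1 and V2 would be the outputs of the two architectures, with any trailing singleton dimension removed first (for example via `squeeze(-1)`).

```python
import torch
import torch.nn as nn


class WeightedSum(nn.Module):
    """Blend two same-shaped tensors with one learnable mixing weight (hypothetical helper)."""

    def __init__(self):
        super().__init__()
        # Raw, unbounded parameter; sigmoid(0) = 0.5, so both inputs start with equal weight.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, v1, v2):
        a = torch.sigmoid(self.alpha)  # mixing weight in (0, 1)
        return a * v1 + (1 - a) * v2


torch.manual_seed(0)
batch_size, seq_len = 4, 512
# Stand-ins for the two models' probability distributions over the tokens.
v1 = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)
v2 = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)

combiner = WeightedSum()
v_final = combiner(v1, v2)  # shape (batch_size, 512)
```

Because the two weights are non-negative and sum to 1, `v_final` remains a valid probability distribution whenever V1 and V2 are.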

This is a linear combination; non-linear versions are possible as well. For example, a layer can take the concatenation (V1;V2) and produce a softmaxed output, e.g. softmax(W * (V1;V2) + b).
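The concatenation variant could be sketched roughly like this. The class name `ConcatCombiner` and the choice of a single `nn.Linear` mapping 1024 inputs to 512 outputs are assumptions; the answer only specifies the general form softmax(W * (V1;V2) + b).

```python
import torch
import torch.nn as nn


class ConcatCombiner(nn.Module):
    """Compute softmax(W * (V1;V2) + b) over the token dimension (hypothetical helper)."""

    def __init__(self, seq_len=512):
        super().__init__()
        # W and b: map the concatenated 2 * seq_len scores back to seq_len scores.
        self.proj = nn.Linear(2 * seq_len, seq_len)

    def forward(self, v1, v2):
        stacked = torch.cat([v1, v2], dim=-1)  # (batch_size, 2 * seq_len)
        return torch.softmax(self.proj(stacked), dim=-1)


torch.manual_seed(0)
batch_size, seq_len = 4, 512
# Stand-ins for the two models' outputs.
v1 = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)
v2 = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)

combined = ConcatCombiner(seq_len)(v1, v2)  # shape (batch_size, 512)
```

This version has far more parameters than the single scalar alpha, so it may need more data or regularisation to train well.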

Ash
  • Okay, so if I work with PyTorch, I simply have to define self.alpha = nn.Parameter(torch.zeros(1)). Then in the forward function I take the sigmoid of alpha and use the above equations. Should I pass the final output directly to CrossEntropyLoss, or should I take its log before passing it to NLLLoss? – Kushagra Bhatia Jun 21 '20 at 10:53
  • If V_final holds probabilities, you'll need to take the log and then feed it to NLLLoss. – Ash Jun 22 '20 at 14:29
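The loss step discussed in these comments might be sketched as below, assuming V_final already holds probabilities and that `target` (a hypothetical name) contains one gold token index per example. `nn.CrossEntropyLoss` applies a log-softmax internally, so it expects raw scores rather than probabilities; for probabilities, taking the log and using the NLL loss is the matching choice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len = 4, 512

# Learnable mixing parameter, initialised to 0 so that sigmoid(alpha) = 0.5.
alpha = torch.zeros(1, requires_grad=True)

# Stand-ins for the two models' probability distributions over the tokens.
v1 = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)
v2 = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)

a = torch.sigmoid(alpha)
v_final = a * v1 + (1 - a) * v2  # probabilities, shape (batch_size, 512)

# Assumed target format: one gold token index per example.
target = torch.randint(0, seq_len, (batch_size,))

# NLL loss expects log-probabilities; clamp first to avoid log(0).
log_probs = torch.log(v_final.clamp_min(1e-12))
loss = F.nll_loss(log_probs, target)

loss.backward()  # gradients flow back to alpha
```

In a full model, `alpha` would live inside the module as an `nn.Parameter`, and the gradient would also reach the two sub-models that produce V1 and V2.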