I am trying to implement a multitask neural network from a paper, but I am unsure how to code the multitask part because the authors did not provide code for it.

The network architecture from the paper looks like this:

[Figure: Network architecture]

To simplify, the architecture can be generalized as follows (for this demo I replaced their more complicated operation on the pair of individual embeddings with concatenation):

[Figure: A simpler version]

The authors sum the losses from the individual tasks and the pairwise tasks, and use the total loss to optimize the parameters of all three networks (encoder, MLP-1, MLP-2) in each batch. What I don't understand is how different types of data are combined in a single batch and fed into two different networks that share an initial encoder. I searched for other networks with a similar structure but did not find any. I would appreciate any thoughts!

user48867

1 Answer

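This is actually a common pattern: run both inputs through the shared encoder, apply MLP-1 to each encoding, and apply MLP-2 to every pair of encodings. Code like the following would do it.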

import torch
from torch import nn

class Network(nn.Module):
   def __init__(self, ...):
      super().__init__()  # must run before submodules are assigned
      self.encoder = DrugTargetInteractiongNetwork()
      self.mlp1 = ClassificationMLP()
      self.mlp2 = PairwiseMLP()

   def forward(self, data_a, data_b):
      a_encoded = self.encoder(data_a)
      b_encoded = self.encoder(data_b)

      a_classified = self.mlp1(a_encoded)
      b_classified = self.mlp1(b_encoded)

      # Assume data_a and data_b are of shape
      # [batch_size, n_molecules, n_features], that their encodings
      # are 3-D as well, and that the n_molecules of a and b
      # are not necessarily equal.
      # This can be generalized to more dimensions.
      a_broadcast, b_broadcast = torch.broadcast_tensors(
         a_encoded[:, None, :, :],  # [B, 1, N_a, F]
         b_encoded[:, :, None, :],  # [B, N_b, 1, F]
      )  # both are now [B, N_b, N_a, F]

      # this will work if your mlp2 accepts an arbitrary number of
      # leading dimensions and just broadcasts over them. That's true,
      # for example, if it uses just Linear and pointwise
      # operations, but may fail if it makes specific assumptions
      # about the number of dimensions of the inputs
      pairwise_classified = self.mlp2(a_broadcast, b_broadcast)

      # if that is a problem, you have to reshape it such that it
      # works. Most torch models accept at least a leading batch dimension
      # for vectorization, so we can "fold" the pairwise dimension
      # into the batch dimension, presenting it as
      # [batch*n_mol_1*n_mol_2, n_features]
      # to mlp2 and then recover it back
      # after broadcasting, both tensors are 4-D and share one shape
      B, N1, N2, N_feat = a_broadcast.shape
      a_batched = a_broadcast.reshape(B*N1*N2, N_feat)
      b_batched = b_broadcast.reshape(B*N1*N2, N_feat)
      # above, -1 would suffice instead of B*N1*N2, just being explicit
      batch_output = self.mlp2(a_batched, b_batched)

      # this should be exactly the same as `pairwise_classified`
      alternative_classified = batch_output.reshape(B, N1, N2, -1)

      return a_classified, b_classified, pairwise_classified
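
The question also asks how the summed loss updates all three networks in one batch. Below is a minimal sketch of a single training step under stated assumptions: `PairMLP`, `MultiTaskNet`, the layer sizes, the random placeholder data, and the use of binary cross-entropy for every task are all hypothetical choices made for illustration, not the paper's. The pairwise head uses concatenation, as in the simplified architecture from the question.

```python
import torch
from torch import nn


class PairMLP(nn.Module):
    """Hypothetical stand-in for MLP-2: concatenates two embeddings, then scores the pair."""

    def __init__(self, n_features: int, n_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * n_features, 32), nn.ReLU(), nn.Linear(32, n_out)
        )

    def forward(self, a, b):
        # nn.Linear acts on the last dimension only, so leading dims pass through.
        return self.net(torch.cat([a, b], dim=-1))


class MultiTaskNet(nn.Module):
    """Hypothetical stand-in for the whole model: shared encoder plus two heads."""

    def __init__(self, n_in: int = 8, n_features: int = 16):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_in, n_features), nn.ReLU())
        self.mlp1 = nn.Linear(n_features, 1)  # individual-task head
        self.mlp2 = PairMLP(n_features, 1)    # pairwise-task head

    def forward(self, data_a, data_b):
        a_enc, b_enc = self.encoder(data_a), self.encoder(data_b)
        a_bc, b_bc = torch.broadcast_tensors(
            a_enc[:, None, :, :],  # [B, 1, N_a, F]
            b_enc[:, :, None, :],  # [B, N_b, 1, F]
        )                          # both become [B, N_b, N_a, F]
        return self.mlp1(a_enc), self.mlp1(b_enc), self.mlp2(a_bc, b_bc)


torch.manual_seed(0)
B, N_a, N_b, n_in = 4, 3, 5, 8
model = MultiTaskNet(n_in=n_in)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random placeholder batch: inputs plus binary labels for each task.
data_a, data_b = torch.randn(B, N_a, n_in), torch.randn(B, N_b, n_in)
y_a = torch.randint(0, 2, (B, N_a, 1)).float()
y_b = torch.randint(0, 2, (B, N_b, 1)).float()
y_pair = torch.randint(0, 2, (B, N_b, N_a, 1)).float()

bce = nn.BCEWithLogitsLoss()
out_a, out_b, out_pair = model(data_a, data_b)

# One scalar loss: the sum of the individual and pairwise terms.
loss = bce(out_a, y_a) + bce(out_b, y_b) + bce(out_pair, y_pair)

optimizer.zero_grad()
loss.backward()   # gradients from all three terms flow into the shared encoder
optimizer.step()  # a single step updates encoder, mlp1 and mlp2 together
```

Because the three terms are added into one scalar before `backward()`, autograd accumulates the encoder's gradient from both heads automatically, so a single optimizer over `model.parameters()` is enough. If the tasks need different weights, each term can be scaled before summing.
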
Jatentaki
  • Thank you for the very detailed demonstration (and helpful comments)! I was perhaps overthinking it too much. Will try it later. – user48867 Nov 11 '21 at 23:03