While modifying a geometric network for the PPI problem, I found that `super()` cannot be called in some circumstances.
The following version fails with this error:

```
TypeError: super(type, obj): obj must be an instance or subtype of type
```
```python
def forward(self, batch, level='residue', **kwargs):
    # Fails: super() is called inside a list comprehension.
    out = torch.cat([super().forward(graph, scatter_mean=False, dense=True) for graph in batch], dim=-1)
    if level == 'atom':
        out = out[batch.ca_idx + batch.ptr[:-1]]
    return torch.sigmoid(out)
```
Note that `batch` contains two items, i.e., two torch_geometric graphs.
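The likely cause is how CPython implements zero-argument `super()`, not anything specific to torch_geometric. `super()` is shorthand for `super(__class__, <first argument of the current function>)`. Before Python 3.12, a list comprehension runs as a hidden nested function whose only argument is the iterator it loops over, so inside the comprehension `super()` receives that iterator instead of `self`, which produces exactly this TypeError. The unrolled version calls `super()` directly in `forward`, where the first argument is `self`. Python 3.12 inlines list comprehensions (PEP 709), so the list-comprehension case should work there, but generator expressions still run in their own scope. This minimal, torch-free sketch reproduces the behavior with hypothetical classes `Base` and `Child`:

```python
import sys


class Base:
    def forward(self, x):
        # Hypothetical stand-in for the parent network's forward().
        return x * 2


class Child(Base):
    def via_listcomp(self, xs):
        # Zero-argument super() inside a list comprehension
        # (raises TypeError before Python 3.12).
        return [super().forward(x) for x in xs]

    def via_genexpr(self, xs):
        # Zero-argument super() inside a generator expression
        # (still its own scope, so it raises TypeError on 3.12+ too).
        return list(super().forward(x) for x in xs)

    def via_unrolled(self, xs):
        # super() evaluated in the method body itself, as in the working forward().
        return [super().forward(xs[0]), super().forward(xs[1])]


def raises_type_error(fn, *args):
    """Return True if calling fn(*args) raises TypeError."""
    try:
        fn(*args)
    except TypeError:
        return True
    return False


child = Child()
print(sys.version_info[:2])
print("list comprehension raises:", raises_type_error(child.via_listcomp, [1, 2]))
print("generator expression raises:", raises_type_error(child.via_genexpr, [1, 2]))
print("unrolled calls:", child.via_unrolled([1, 2]))
```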
However, the following version works fine:
```python
def forward(self, batch, level='residue', **kwargs):
    # Works: super() is called directly in the method body.
    out1 = super().forward(batch[0], scatter_mean=False, dense=True)
    out2 = super().forward(batch[1], scatter_mean=False, dense=True)
    out = torch.cat([out1, out2], dim=-1)
    if level == 'atom':
        out = out[batch.ca_idx + batch.ptr[:-1]]
    return torch.sigmoid(out)
```
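To keep the comprehension (and support any number of graphs), two workarounds should avoid the problem on any Python 3 version: resolve `super().forward` once in the method body and call the bound method inside the comprehension, or pass both arguments explicitly with `super(PPINet, self)`, which takes `self` from the enclosing method instead of the comprehension's hidden argument. This sketch uses hypothetical classes `BaseNet` and `PPINet` and replaces `torch.cat(..., dim=-1)` with plain list concatenation so it runs without torch:

```python
class BaseNet:
    def forward(self, graph, scatter_mean=True, dense=False):
        # Hypothetical stand-in for the parent forward(); returns a one-item list
        # so per-graph outputs can be concatenated like tensors along dim=-1.
        return [graph]


class PPINet(BaseNet):
    def forward_bound(self, batch):
        # Workaround 1: evaluate super() in the method body (where it sees self),
        # then call the bound method inside the comprehension.
        parent_forward = super().forward
        outs = [parent_forward(g, scatter_mean=False, dense=True) for g in batch]
        # Stand-in for torch.cat(outs, dim=-1).
        return [v for out in outs for v in out]

    def forward_explicit(self, batch):
        # Workaround 2: two-argument super() does not look at the
        # comprehension's first argument, so it is safe inside the comprehension.
        outs = [super(PPINet, self).forward(g, scatter_mean=False, dense=True) for g in batch]
        return [v for out in outs for v in out]


net = PPINet()
print(net.forward_bound(["graph_a", "graph_b"]))     # ['graph_a', 'graph_b']
print(net.forward_explicit(["graph_a", "graph_b"]))  # ['graph_a', 'graph_b']
```

In the original `forward`, workaround 1 would read `parent_forward = super().forward` followed by `out = torch.cat([parent_forward(graph, scatter_mean=False, dense=True) for graph in batch], dim=-1)`.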