0

I have a very simple GNN class (stripped to the bone to create a MRE):

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn.conv import GraphConv
from torch_geometric.nn.pool import global_add_pool


class GNN(torch.nn.Module):
    def __init__(
            self,
            num_classes,
            hidden_dim,
            node_features_dim,
    ):
        super(GNN, self).__init__()
        self.hidden_dim = hidden_dim

        self.conv1 = GraphConv(node_features_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)

        self.fc = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_add_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(x)

        return x

This is a graph classifier, which I will apply to binary classification problems. I want to test the forward method. I thus write:

import torch
from torch_geometric.datasets.tu_dataset import TUDataset
from torch_geometric.loader import DataLoader
from torch.nn.functional import binary_cross_entropy_with_logits
from .foobar import GNN


dataset = TUDataset(
    root=".",
    name="Mutagenicity",
).shuffle()
num_classes = dataset.num_classes
node_features_dim = dataset.num_features
hidden_dim = 64
batch_size = 128


def test_forward():
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = next(iter(loader))
    batch.to(device)
    loss_fun = binary_cross_entropy_with_logits

    model = GNN(
        num_classes,
        hidden_dim,
        node_features_dim,
    ).to(device)

    out = model.forward(batch.x, batch.edge_index, batch.batch)
    target = batch.y.unsqueeze(1).float()
    loss = loss_fun(out, target)

    assert loss > 0, "loss must be positive"

(note that the test simply asserts that the loss must be positive: I'm open to suggestions for better tests). When I run the test, I get an exceedingly long error message. I copy it all for completeness, but feel free to jump at the end, where I provide a short summary.

============================= test session starts ==============================
collecting ... collected 1 item

graph_classification/test_foobar.py::test_forward FAILED                 [100%]
graph_classification/test_foobar.py:17 (test_forward)
def test_forward():
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        batch = next(iter(loader))
        batch.to(device)
        loss_fun = binary_cross_entropy_with_logits
    
        model = GNN(
            num_classes,
            hidden_dim,
            node_features_dim,
        ).to(device)
    
        out = model.forward(batch.x, batch.edge_index, batch.batch)
        target = batch.y.unsqueeze(1).float()
>       loss = loss_fun(out, target)

graph_classification/test_foobar.py:33: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

input = tensor([[ -5.1727,  -0.4420],
        [ -2.7096,   1.0948],
        [ -2.0776,  -0.5117],
        [ -8.1472,  -4.3884]...755,   9.7002],
        [-14.5554,  -8.8546],
        [ -2.5473,  -0.9506]], device='cuda:0', grad_fn=<AddmmBackward0>)
target = tensor([[0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      ....],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.]], device='cuda:0')
weight = None, size_average = None, reduce = None, reduction = 'mean'
pos_weight = None

    def binary_cross_entropy_with_logits(
        input: Tensor,
        target: Tensor,
        weight: Optional[Tensor] = None,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
        pos_weight: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Function that measures Binary Cross Entropy between target and input
        logits.
    
        See :class:`~torch.nn.BCEWithLogitsLoss` for details.
    
        Args:
            input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
            target: Tensor of the same shape as input with values between 0 and 1
            weight (Tensor, optional): a manual rescaling weight
                if provided it's repeated to match input tensor shape
            size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
                the losses are averaged over each loss element in the batch. Note that for
                some losses, there multiple elements per sample. If the field :attr:`size_average`
                is set to ``False``, the losses are instead summed for each minibatch. Ignored
                when reduce is ``False``. Default: ``True``
            reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
                losses are averaged or summed over observations for each minibatch depending
                on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
                batch element instead and ignores :attr:`size_average`. Default: ``True``
            reduction (string, optional): Specifies the reduction to apply to the output:
                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
                ``'mean'``: the sum of the output will be divided by the number of
                elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
                and :attr:`reduce` are in the process of being deprecated, and in the meantime,
                specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
            pos_weight (Tensor, optional): a weight of positive examples.
                    Must be a vector with length equal to the number of classes.
    
        Examples::
    
             >>> input = torch.randn(3, requires_grad=True)
             >>> target = torch.empty(3).random_(2)
             >>> loss = F.binary_cross_entropy_with_logits(input, target)
             >>> loss.backward()
        """
        if has_torch_function_variadic(input, target, weight, pos_weight):
            return handle_torch_function(
                binary_cross_entropy_with_logits,
                (input, target, weight, pos_weight),
                input,
                target,
                weight=weight,
                size_average=size_average,
                reduce=reduce,
                reduction=reduction,
                pos_weight=pos_weight,
            )
        if size_average is not None or reduce is not None:
            reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
        else:
            reduction_enum = _Reduction.get_enum(reduction)
    
        if not (target.size() == input.size()):
>           raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
E           ValueError: Target size (torch.Size([128, 1])) must be the same as input size (torch.Size([128, 2]))

/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3095: ValueError

The main issue seems to be this error raised by binary_cross_entropy_with_logits

        if not (target.size() == input.size()):
>           raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
E           ValueError: Target size (torch.Size([128, 1])) must be the same as input size (torch.Size([128, 2]))

In other words, it looks like the issue is that target and out are not the same shape. However, I don't understand: out should be of shape (128,2), because for each sample in the batch of size 2, I should return two logits (one for the positive class, one for the negative class). Instead, I would expect target to be of shape (128,1), since each sample will have one and only one class. Clearly there's some mistake in my reasoning. Can you help me fix the error?

DeltaIV
  • 4,773
  • 12
  • 39
  • 86

0 Answers0