I have a very simple GNN class (stripped to the bone to create an MRE):
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn.conv import GraphConv
from torch_geometric.nn.pool import global_add_pool


class GNN(torch.nn.Module):
    def __init__(
        self,
        num_classes,
        hidden_dim,
        node_features_dim,
    ):
        super(GNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.conv1 = GraphConv(node_features_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.fc = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_add_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(x)
        return x
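For context, this is how I expect the output shape to behave: the final linear layer has num_classes output units, so the model should return one row of num_classes logits per graph in the batch. A quick throwaway sanity check on random data (made-up dimensions, not part of the MRE, assuming the class above is in scope):

# Throwaway sanity check with made-up dimensions (hypothetical, not part of the MRE).
model = GNN(num_classes=2, hidden_dim=64, node_features_dim=14)
x = torch.randn(10, 14)                     # 10 nodes with 14 features each
edge_index = torch.randint(0, 10, (2, 20))  # 20 random edges between those nodes
batch = torch.zeros(10, dtype=torch.long)   # all 10 nodes belong to graph 0
out = model(x, edge_index, batch)
print(out.shape)                            # torch.Size([1, 2]) -> [num_graphs, num_classes]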
This is a graph classifier that I will apply to binary classification problems. I want to test the forward method, so I write:
import torch
from torch_geometric.datasets.tu_dataset import TUDataset
from torch_geometric.loader import DataLoader
from torch.nn.functional import binary_cross_entropy_with_logits

from .foobar import GNN

dataset = TUDataset(
    root=".",
    name="Mutagenicity",
).shuffle()

num_classes = dataset.num_classes
node_features_dim = dataset.num_features
hidden_dim = 64
batch_size = 128


def test_forward():
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = next(iter(loader))
    batch.to(device)
    loss_fun = binary_cross_entropy_with_logits
    model = GNN(
        num_classes,
        hidden_dim,
        node_features_dim,
    ).to(device)
    out = model.forward(batch.x, batch.edge_index, batch.batch)
    target = batch.y.unsqueeze(1).float()
    loss = loss_fun(out, target)
    assert loss > 0, "loss must be positive"
(Note that the test simply asserts that the loss must be positive; I'm open to suggestions for better tests.) When I run the test, I get an exceedingly long error message. I copy it all for completeness, but feel free to jump to the end, where I provide a short summary.
============================= test session starts ==============================
collecting ... collected 1 item
graph_classification/test_foobar.py::test_forward FAILED [100%]
graph_classification/test_foobar.py:17 (test_forward)
def test_forward():
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = next(iter(loader))
    batch.to(device)
    loss_fun = binary_cross_entropy_with_logits
    model = GNN(
        num_classes,
        hidden_dim,
        node_features_dim,
    ).to(device)
    out = model.forward(batch.x, batch.edge_index, batch.batch)
    target = batch.y.unsqueeze(1).float()
>   loss = loss_fun(out, target)
graph_classification/test_foobar.py:33:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
input = tensor([[ -5.1727, -0.4420],
[ -2.7096, 1.0948],
[ -2.0776, -0.5117],
[ -8.1472, -4.3884]...755, 9.7002],
[-14.5554, -8.8546],
[ -2.5473, -0.9506]], device='cuda:0', grad_fn=<AddmmBackward0>)
target = tensor([[0.],
[0.],
[1.],
[0.],
[0.],
[0.],
[0.],
[0.],
....],
[0.],
[0.],
[0.],
[1.],
[0.],
[0.],
[0.]], device='cuda:0')
weight = None, size_average = None, reduce = None, reduction = 'mean'
pos_weight = None
def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    pos_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Function that measures Binary Cross Entropy between target and input
    logits.

    See :class:`~torch.nn.BCEWithLogitsLoss` for details.

    Args:
        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
        target: Tensor of the same shape as input with values between 0 and 1
        weight (Tensor, optional): a manual rescaling weight
            if provided it's repeated to match input tensor shape
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        pos_weight (Tensor, optional): a weight of positive examples.
            Must be a vector with length equal to the number of classes.

    Examples::

        >>> input = torch.randn(3, requires_grad=True)
        >>> target = torch.empty(3).random_(2)
        >>> loss = F.binary_cross_entropy_with_logits(input, target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight, pos_weight):
        return handle_torch_function(
            binary_cross_entropy_with_logits,
            (input, target, weight, pos_weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=pos_weight,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if not (target.size() == input.size()):
>       raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
E ValueError: Target size (torch.Size([128, 1])) must be the same as input size (torch.Size([128, 2]))
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3095: ValueError
The main issue seems to be this error, raised by binary_cross_entropy_with_logits:

    if not (target.size() == input.size()):
>       raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
E       ValueError: Target size (torch.Size([128, 1])) must be the same as input size (torch.Size([128, 2]))
In other words, it looks like the issue is that target and out don't have the same shape. However, I don't understand why this is a problem: out should be of shape (128, 2), because for each of the 128 samples in the batch I return two logits (one for the positive class, one for the negative class), while I would expect target to be of shape (128, 1), since each sample has one and only one class. Clearly there's some mistake in my reasoning. Can you help me fix the error?
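For concreteness, here is a small self-contained sketch (random numbers, shapes mirroring my batch) of the shape conventions that binary_cross_entropy_with_logits and cross_entropy expect, in case it helps pin down where my reasoning goes wrong:

# Hypothetical shapes mirroring my batch; not a proposed fix, just to make the mismatch concrete.
import torch
import torch.nn.functional as F

two_logits = torch.randn(128, 2)      # what my model currently outputs (num_classes = 2)
one_logit = torch.randn(128, 1)       # a single logit per graph
labels = torch.randint(0, 2, (128,))  # one binary label per graph

# binary_cross_entropy_with_logits requires target and input to have the same shape:
loss_a = F.binary_cross_entropy_with_logits(one_logit, labels.unsqueeze(1).float())  # ok: (128, 1) vs (128, 1)
# F.binary_cross_entropy_with_logits(two_logits, labels.unsqueeze(1).float())        # raises the ValueError above

# cross_entropy instead takes (N, C) logits and (N,) integer class indices:
loss_b = F.cross_entropy(two_logits, labels)  # ok: (128, 2) vs (128,)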