
I'm working on a PyTorch project where I want to generate MNIST images using a U-Net architecture combined with a DDPM (Denoising Diffusion Probabilistic Model) approach. I'm encountering the following error:

  File "C:\Users\zzzz\miniconda3\envs\ddpm2\Lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (35840x28 and 10x10)

The error occurs inside the self-attention mechanism of my U-Net. Here is the relevant part of the code, from Model.py:

SelfAttention Class:
class SelfAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(SelfAttention, self).__init__()

        print("in_dim:",in_dim)
        print("out_dim:",out_dim)
        self.query = nn.Linear(in_dim, out_dim)
        self.key = nn.Linear(in_dim, out_dim)
        self.value = nn.Linear(in_dim, out_dim)
        
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # Calculate query, key, and value projections
        print("x_query:",x.shape)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # Calculate scaled dot-product attention scores
        print("query:",query.shape)
        print("key:",key.shape)
        print("value:",value.shape)
        print("key.transpose(-2, -1):",key.transpose(-2, -1).shape)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (key.size(-1) ** 0.5)
        
        # Apply softmax to get attention weights
        attention_weights = self.softmax(scores)
        
        # Calculate the weighted sum of values
        print("attention_weights:",attention_weights.shape)
        print("value:",value.shape)
        output = torch.matmul(attention_weights, value)
        
        return output
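
If I'm reading the shapes correctly, the tensor passed into SelfAttention is a Conv2d output of shape (N, C, H, W), and nn.Linear only projects the last dimension, so self.query sees a feature size of W = 28 instead of C = 10. That also matches the numbers in the error: 35840 = 128 × 10 × 28 (assuming a batch size of 128), i.e. all leading dimensions flattened together, multiplied against the 10×10 weight. A standalone sketch (not part of Model.py) that reproduces the same error:

import torch
import torch.nn as nn

# Minimal reproduction (the batch size of 128 is an assumption inferred from
# 35840 = 128 * 10 * 28 in the error message).
x = torch.randn(128, 10, 28, 28)   # conv1 output: (N, C, H, W)
query = nn.Linear(10, 10)          # same as self.query when out_c = 10
query(x)                           # RuntimeError: mat1 and mat2 shapes cannot be
                                   # multiplied (35840x28 and 10x10)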

class MyBlockWithAttention(nn.Module):
    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super(MyBlockWithAttention, self).__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.attention = SelfAttention(out_c, out_c)  # Add self-attention here
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        out = self.ln(x) if self.normalize else x
        out = self.conv1(out)
        print("before:",out.shape)
        out = self.attention(out)  # Apply self-attention
        print("after:",out.shape)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)
        return out
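
My understanding is that attention over a conv feature map normally treats each spatial position as a token, i.e. the (N, C, H, W) map gets reshaped to (N, H*W, C) before the query/key/value projections. A rough sketch of that layout (this is not what my code currently does, just the convention as I understand it):

import torch

feat = torch.randn(128, 10, 28, 28)        # (N, C, H, W) from conv1; batch size assumed
tokens = feat.flatten(2).transpose(1, 2)   # (N, H*W, C) = (128, 784, 10)
# in this layout, nn.Linear(10, 10) would project the channel dimension of each token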

U-Net:

class MyUNet(nn.Module):
    def __init__(self, n_steps=1000, time_emb_dim=100, in_c=1):
        super(MyUNet, self).__init__()

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlockWithAttention((1, 28, 28), 1, 10),
            MyBlockWithAttention((10, 28, 28), 10, 10),
            MyBlockWithAttention((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)
        
        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            MyBlockWithAttention((10, 14, 14), 10, 20),
            MyBlockWithAttention((20, 14, 14), 20, 20),
            MyBlockWithAttention((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            MyBlockWithAttention((20, 7, 7), 20, 40),
            MyBlockWithAttention((40, 7, 7), 40, 40),
            MyBlockWithAttention((40, 7, 7), 40, 40)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            MyBlockWithAttention((40, 3, 3), 40, 20),
            MyBlockWithAttention((20, 3, 3), 20, 20),
            MyBlockWithAttention((20, 3, 3), 20, 40)
        )

        # Second half
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            MyBlockWithAttention((80, 7, 7), 80, 40),
            MyBlockWithAttention((40, 7, 7), 40, 20),
            MyBlockWithAttention((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            MyBlockWithAttention((40, 14, 14), 40, 20),
            MyBlockWithAttention((20, 14, 14), 20, 10),
            MyBlockWithAttention((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            MyBlockWithAttention((20, 28, 28), 20, 10),
            MyBlockWithAttention((10, 28, 28), 10, 10),
            MyBlockWithAttention((10, 28, 28), 10, 10, normalize=False)
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        # x is (N, 1, 28, 28): a batch of (noisy) MNIST images
        t = self.time_embed(t)
        n = len(x)
        print("before reshape t:", self.te1(t).shape)
        print("this is x:",x.shape)
        print("this is on yeki:", self.te1(t).reshape(n, -1, 1, 1).shape)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))  # (N, 20, 14, 14)
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))  # (N, 40, 7, 7)

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)

        out = self.conv_out(out)

        return out

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )
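
For completeness, a call along these lines triggers the traceback above (the batch size of 128 is an assumption inferred from 35840 = 128 × 10 × 28; sinusoidal_embedding is defined elsewhere in my code):

import torch

model = MyUNet(n_steps=1000, time_emb_dim=100, in_c=1)
x = torch.randn(128, 1, 28, 28)      # a batch of (noisy) MNIST images
t = torch.randint(0, 1000, (128,))   # one diffusion timestep index per sample
out = model(x, t)                    # fails in the first SelfAttention with the RuntimeError above

What is the correct way to feed the Conv2d feature maps into this SelfAttention module so that the Linear projections operate on the channel dimension, and what do I need to change in SelfAttention or MyBlockWithAttention to fix the shape mismatch?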