I'm working on a PyTorch project where I want to generate MNIST images using a U-Net architecture combined with a DDPM (Denoising Diffusion Probabilistic Model) approach. I'm encountering the following error:
File "C:\Users\zzzz\miniconda3\envs\ddpm2\Lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (35840x28 and 10x10)
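From what I can tell, the mismatch can be reproduced outside the model: nn.Linear projects the last dimension of its input, and the feature maps coming out of conv1 are shaped (N, C, 28, 28), so the layer sees the width 28 rather than the 10 channels. A minimal sketch of that (the batch size of 128 is an assumption, chosen because 128 * 10 * 28 = 35840 matches the row count in the error):

    import torch
    import torch.nn as nn

    x = torch.randn(128, 10, 28, 28)  # (N, C, H, W) feature map; batch size assumed
    linear = nn.Linear(10, 10)        # nn.Linear projects the LAST dimension, which here is 28, not the 10 channels
    linear(x)                         # RuntimeError: mat1 and mat2 shapes cannot be multiplied (35840x28 and 10x10)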
This error is happening inside the self-attention mechanism of my U-Net. Here is the relevant part of the code, from Model.py:
SelfAttention class:

    import torch
    import torch.nn as nn


    class SelfAttention(nn.Module):
        def __init__(self, in_dim, out_dim):
            super(SelfAttention, self).__init__()
            print("in_dim:", in_dim)
            print("out_dim:", out_dim)
            self.query = nn.Linear(in_dim, out_dim)
            self.key = nn.Linear(in_dim, out_dim)
            self.value = nn.Linear(in_dim, out_dim)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x):
            # Calculate query, key, and value projections
            print("x_query:", x.shape)
            query = self.query(x)
            key = self.key(x)
            value = self.value(x)
            # Calculate scaled dot-product attention scores
            print("query:", query.shape)
            print("key:", key.shape)
            print("value:", value.shape)
            print("key.transpose(-2, -1):", key.transpose(-2, -1).shape)
            scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(key.size(-1))
            # Apply softmax to get attention weights
            attention_weights = self.softmax(scores)
            # Calculate the weighted sum of values
            print("attention_weights:", attention_weights.shape)
            print("value:", value.shape)
            output = torch.matmul(attention_weights, value)
            return output
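For reference, here is a minimal sketch of the direction I'm considering: flatten H*W into a sequence first, so the linear layers project the channel dimension instead of the image width. The name SpatialSelfAttention is just a placeholder (this is not wired into the U-Net yet), and it uses math.sqrt instead of torch.sqrt because torch.sqrt expects a tensor while key.size(-1) is a plain Python int:

    import math
    import torch
    import torch.nn as nn


    class SpatialSelfAttention(nn.Module):
        """Sketch: treat each of the H*W positions as a token with C features."""
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.query = nn.Linear(in_dim, out_dim)
            self.key = nn.Linear(in_dim, out_dim)
            self.value = nn.Linear(in_dim, out_dim)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x):
            n, c, h, w = x.shape
            seq = x.flatten(2).transpose(1, 2)                # (N, H*W, C): last dim is now the channels
            query = self.query(seq)
            key = self.key(seq)
            value = self.value(seq)
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key.size(-1))
            attention_weights = self.softmax(scores)          # (N, H*W, H*W)
            out = torch.matmul(attention_weights, value)      # (N, H*W, out_dim)
            return out.transpose(1, 2).reshape(n, -1, h, w)   # back to (N, out_dim, H, W)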
    class MyBlockWithAttention(nn.Module):
        def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
            super(MyBlockWithAttention, self).__init__()
            self.ln = nn.LayerNorm(shape)
            self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
            self.attention = SelfAttention(out_c, out_c)  # Add self-attention here
            self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
            self.activation = nn.SiLU() if activation is None else activation
            self.normalize = normalize

        def forward(self, x):
            out = self.ln(x) if self.normalize else x
            out = self.conv1(out)
            print("before:", out.shape)
            out = self.attention(out)  # Apply self-attention
            print("after:", out.shape)
            out = self.activation(out)
            out = self.conv2(out)
            out = self.activation(out)
            return out
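To confirm the failure happens inside a single block, before any of the U-Net wiring is involved, a standalone check like this (the batch size of 4 is arbitrary) hits the same error:

    # Standalone check of one block from the first stage (batch size 4 is arbitrary).
    block = MyBlockWithAttention((10, 28, 28), 10, 10)
    x = torch.randn(4, 10, 28, 28)
    out = block(x)  # fails inside SelfAttention: nn.Linear(10, 10) sees a last dimension of 28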
U-Net:
    class MyUNet(nn.Module):
        def __init__(self, n_steps=1000, time_emb_dim=100, in_c=1):
            super(MyUNet, self).__init__()

            # Sinusoidal embedding
            self.time_embed = nn.Embedding(n_steps, time_emb_dim)
            self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
            self.time_embed.requires_grad_(False)

            # First half
            self.te1 = self._make_te(time_emb_dim, 1)
            self.b1 = nn.Sequential(
                MyBlockWithAttention((1, 28, 28), 1, 10),
                MyBlockWithAttention((10, 28, 28), 10, 10),
                MyBlockWithAttention((10, 28, 28), 10, 10)
            )
            self.down1 = nn.Conv2d(10, 10, 4, 2, 1)

            self.te2 = self._make_te(time_emb_dim, 10)
            self.b2 = nn.Sequential(
                MyBlockWithAttention((10, 14, 14), 10, 20),
                MyBlockWithAttention((20, 14, 14), 20, 20),
                MyBlockWithAttention((20, 14, 14), 20, 20)
            )
            self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

            self.te3 = self._make_te(time_emb_dim, 20)
            self.b3 = nn.Sequential(
                MyBlockWithAttention((20, 7, 7), 20, 40),
                MyBlockWithAttention((40, 7, 7), 40, 40),
                MyBlockWithAttention((40, 7, 7), 40, 40)
            )
            self.down3 = nn.Sequential(
                nn.Conv2d(40, 40, 2, 1),
                nn.SiLU(),
                nn.Conv2d(40, 40, 4, 2, 1)
            )

            # Bottleneck
            self.te_mid = self._make_te(time_emb_dim, 40)
            self.b_mid = nn.Sequential(
                MyBlockWithAttention((40, 3, 3), 40, 20),
                MyBlockWithAttention((20, 3, 3), 20, 20),
                MyBlockWithAttention((20, 3, 3), 20, 40)
            )

            # Second half
            self.up1 = nn.Sequential(
                nn.ConvTranspose2d(40, 40, 4, 2, 1),
                nn.SiLU(),
                nn.ConvTranspose2d(40, 40, 2, 1)
            )
            self.te4 = self._make_te(time_emb_dim, 80)
            self.b4 = nn.Sequential(
                MyBlockWithAttention((80, 7, 7), 80, 40),
                MyBlockWithAttention((40, 7, 7), 40, 20),
                MyBlockWithAttention((20, 7, 7), 20, 20)
            )

            self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
            self.te5 = self._make_te(time_emb_dim, 40)
            self.b5 = nn.Sequential(
                MyBlockWithAttention((40, 14, 14), 40, 20),
                MyBlockWithAttention((20, 14, 14), 20, 10),
                MyBlockWithAttention((10, 14, 14), 10, 10)
            )

            self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
            self.te_out = self._make_te(time_emb_dim, 20)
            self.b_out = nn.Sequential(
                MyBlockWithAttention((20, 28, 28), 20, 10),
                MyBlockWithAttention((10, 28, 28), 10, 10),
                MyBlockWithAttention((10, 28, 28), 10, 10, normalize=False)
            )
            self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)
        def forward(self, x, t):
            # x is (N, 1, 28, 28); the time embedding is added to the features below
            t = self.time_embed(t)
            n = len(x)
            print("before reshape t:", self.te1(t).shape)
            print("this is x:", x.shape)
            print("this is on yeki:", self.te1(t).reshape(n, -1, 1, 1).shape)
            out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))                          # (N, 10, 28, 28)
            out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))           # (N, 20, 14, 14)
            out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))           # (N, 40, 7, 7)
            out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (N, 40, 3, 3)

            out4 = torch.cat((out3, self.up1(out_mid)), dim=1)                            # (N, 80, 7, 7)
            out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))                       # (N, 20, 7, 7)

            out5 = torch.cat((out2, self.up2(out4)), dim=1)                               # (N, 40, 14, 14)
            out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))                       # (N, 10, 14, 14)

            out = torch.cat((out1, self.up3(out5)), dim=1)                                # (N, 20, 28, 28)
            out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))                   # (N, 10, 28, 28)

            out = self.conv_out(out)                                                      # (N, 1, 28, 28)
            return out

        def _make_te(self, dim_in, dim_out):
            return nn.Sequential(
                nn.Linear(dim_in, dim_out),
                nn.SiLU(),
                nn.Linear(dim_out, dim_out)
            )
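For completeness, this is roughly how I exercise the model while debugging (a sketch: the batch size of 8 is arbitrary, and sinusoidal_embedding is defined elsewhere in Model.py):

    # Sketch of a forward-pass check (batch size 8 is arbitrary;
    # sinusoidal_embedding must already be defined for MyUNet to build).
    model = MyUNet(n_steps=1000, time_emb_dim=100)
    x = torch.randn(8, 1, 28, 28)      # MNIST-sized noisy images
    t = torch.randint(0, 1000, (8,))   # one diffusion timestep index per sample
    eps = model(x, t)                  # expected shape (8, 1, 28, 28); currently raises the shape error above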