I'm trying to train a Wave-U-Net for mixing multitrack audio (8 mono stems to a stereo mixture), following the methodology of this paper:
Each input consists of 121,843 samples (2.76 seconds) and the output corresponds to the center part of the input, consisting of 89,093 samples (2.02 seconds).
My net is:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Waveunet(nn.Module):
    def __init__(self):
        super(Waveunet, self).__init__()
        # 10 encoder / 10 decoder layers, 24 extra filters per layer,
        # encoder kernel size 15, decoder kernel size 5, 8 input channels
        enc_channel_in = [8] + [min(10, (i + 1)) * 24 for i in range(9)]
        enc_channel_out = [min(10, (i + 1)) * 24 for i in range(10)]
        dec_channel_out = enc_channel_out[:10][::-1]
        dec_channel_in = [enc_channel_out[-1] * 2 + 24] + \
                         [enc_channel_out[-i - 1] + dec_channel_out[i - 1] for i in range(1, 10)]

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for i in range(10):
            self.encoder.append(nn.Conv1d(enc_channel_in[i], enc_channel_out[i], 15))
        for i in range(10):
            self.decoder.append(nn.Conv1d(dec_channel_in[i], dec_channel_out[i], 5))

        self.middle_layer = nn.Sequential(
            nn.Conv1d(enc_channel_out[-1], enc_channel_out[-1] + 24, 15),
            nn.LeakyReLU(0.2)
        )
        self.output_layer = nn.Sequential(
            nn.Conv1d(32, 2, kernel_size=1),  # 24 decoder channels + 8 input channels
            nn.Tanh()
        )

    def forward(self, x):
        encoder = list()
        input = x

        # Downsampling: valid conv -> LeakyReLU -> save skip -> decimate by 2
        for i in range(10):
            x = self.encoder[i](x)
            x = F.leaky_relu(x, 0.2)
            encoder.append(x)
            x = x[:, :, ::2]

        x = self.middle_layer(x)

        # Upsampling: linear interpolation -> concat cropped skip -> valid conv -> LeakyReLU
        for i in range(10):
            x = F.interpolate(x, size=x.shape[-1] * 2 - 1, mode='linear', align_corners=True)
            x = self.crop_and_concat(x, encoder[10 - i - 1])
            x = self.decoder[i](x)
            x = F.leaky_relu(x, 0.2)

        # Concat with the (center-cropped) original input
        x = self.crop_and_concat(x, input)

        # Output prediction
        output = self.output_layer(x)
        return output

    def crop_and_concat(self, x1, x2):
        crop_x2 = self.crop(x2, x1.shape[-1])
        x = torch.cat([x1, crop_x2], dim=1)
        return x

    def crop(self, tensor, target_shape):
        # Center crop along the time axis; slicing to crop_start + target_shape
        # also handles the no-crop case (diff == 0) safely
        shape = tensor.shape[-1]
        diff = shape - target_shape
        crop_start = diff // 2
        return tensor[:, :, crop_start:crop_start + target_shape]
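As a sanity check on the context arithmetic (my own sketch, not code from the paper): each valid convolution shrinks the signal by kernel_size - 1, the decimation keeps every other sample, and the linear upsampling maps n to 2n - 1, which reproduces the 121,843 -> 89,093 mapping:

def output_length(n, layers=10, enc_k=15, dec_k=5, mid_k=15):
    for _ in range(layers):
        n -= enc_k - 1       # valid conv, kernel 15
        n = (n + 1) // 2     # x[:, :, ::2] keeps ceil(n / 2) samples
    n -= mid_k - 1           # bottleneck conv, kernel 15
    for _ in range(layers):
        n = 2 * n - 1        # F.interpolate(..., size=2n - 1)
        n -= dec_k - 1       # valid conv, kernel 5
    return n

print(output_length(121843))  # 89093, matching the paper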
Checking the summary with my input size:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Waveunet()
model = model.to(device)
from torchsummary import summary
summary(model, input_size=(8, 121843))
gives the correct output size (according to the paper):
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv1d-1           [-1, 24, 121829]           2,904
            Conv1d-2            [-1, 48, 60901]          17,328
            Conv1d-3            [-1, 72, 30437]          51,912
            Conv1d-4            [-1, 96, 15205]         103,776
            Conv1d-5            [-1, 120, 7589]         172,920
            Conv1d-6            [-1, 144, 3781]         259,344
            Conv1d-7            [-1, 168, 1877]         363,048
            Conv1d-8             [-1, 192, 925]         484,032
            Conv1d-9             [-1, 216, 449]         622,296
           Conv1d-10             [-1, 240, 211]         777,840
           Conv1d-11              [-1, 264, 92]         950,664
        LeakyReLU-12              [-1, 264, 92]               0
           Conv1d-13             [-1, 240, 179]         605,040
           Conv1d-14             [-1, 216, 353]         492,696
           Conv1d-15             [-1, 192, 701]         391,872
           Conv1d-16            [-1, 168, 1397]         302,568
           Conv1d-17            [-1, 144, 2789]         224,784
           Conv1d-18            [-1, 120, 5573]         158,520
           Conv1d-19            [-1, 96, 11141]         103,776
           Conv1d-20            [-1, 72, 22277]          60,552
           Conv1d-21            [-1, 48, 44549]          28,848
           Conv1d-22            [-1, 24, 89093]           8,664
           Conv1d-23             [-1, 2, 89093]              66
             Tanh-24             [-1, 2, 89093]               0
================================================================
Total params: 6,183,450
Trainable params: 6,183,450
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.72
Forward/backward pass size (MB): 156.46
Params size (MB): 23.59
Estimated Total Size (MB): 183.77
----------------------------------------------------------------
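A quick forward pass with random data (my own check, not from the paper) confirms the same output shape:

with torch.no_grad():
    out = model(torch.randn(1, 8, 121843, device=device))
print(out.shape)  # torch.Size([1, 2, 89093])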
However, on training with Adam and L1 loss I get the following broadcasting error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [75], in <cell line: 17>()
13 optimiser = torch.optim.Adam(net.parameters(),
14 lr=0.001)
16 # train model
---> 17 train(net, train_loader, loss_fn, optimiser, device, 10)
Input In [72], in train(model, data_loader, loss_fn, optimiser, device, epochs)
20 for i in range(epochs):
21 print(f"Epoch {i+1}")
---> 22 train_single_epoch(model, data_loader, loss_fn, optimiser, device)
23 for *vinputs, vtarget in data_loader:
24 *vinputs, vtarget = vinputs[0].to(device), vinputs[1].to(device), vinputs[2].to(device), vinputs[3].to(device), vinputs[4].to(device), vinputs[5].to(device), vinputs[6].to(device), vinputs[7].to(device), vtarget.to(device)
Input In [72], in train_single_epoch(model, data_loader, loss_fn, optimiser, device)
7 # calculate loss
8 prediction = model(cat)
----> 9 loss = loss_fn(prediction, target)
11 # backpropagate error and update weights
12 optimiser.zero_grad()
File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\modules\loss.py:96, in L1Loss.forward(self, input, target)
95 def forward(self, input: Tensor, target: Tensor) -> Tensor:
---> 96 return F.l1_loss(input, target, reduction=self.reduction)
File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\functional.py:3248, in l1_loss(input, target, size_average, reduce, reduction)
3245 if size_average is not None or reduce is not None:
3246 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3248 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3249 return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\functional.py:73, in broadcast_tensors(*tensors)
71 if has_torch_function(tensors):
72 return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 73 return _VF.broadcast_tensors(tensors)
RuntimeError: The size of tensor a (89093) must match the size of tensor b (121843) at non-singleton dimension 2
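The shapes at loss time make the problem reproducible with random tensors (a minimal sketch of the mismatch; my loader returns full-length targets):

import torch
import torch.nn as nn

pred = torch.randn(1, 2, 89093)     # what the network emits (center part)
target = torch.randn(1, 2, 121843)  # full-length target from my data loader
nn.L1Loss()(pred, target)           # RuntimeError: The size of tensor a (89093)
                                    # must match the size of tensor b (121843)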