I am using the Swish activation function with a trainable parameter, following the paper "SWISH: A Self-Gated Activation Function" by Prajit Ramachandran, Barret Zoph and Quoc V. Le. As a toy example, I am training a LeNet-5 CNN on MNIST and want to learn 'beta' instead of fixing beta = 1 as nn.SiLU() does. I am using PyTorch 2.0 and Python 3.10.
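For context, the Swish activation in the paper is defined as

    swish(x) = x * sigmoid(beta * x)

where beta can either be held constant or trained along with the rest of the network; with beta = 1 it reduces to the SiLU that nn.SiLU() implements. The example code is: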
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self, beta = 1.0):
        super(LeNet5, self).__init__()
        b = torch.tensor(data = beta, dtype = torch.float32)
        self.beta = torch.autograd.Variable(b, requires_grad = True)
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn1 = nn.BatchNorm2d(num_features = 6)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn2 = nn.BatchNorm2d(num_features = 16)
        self.fc1 = nn.Linear(
            in_features = 256, out_features = 120,
            bias = True
        )
        self.bn3 = nn.BatchNorm1d(num_features = 120)
        self.fc2 = nn.Linear(
            in_features = 120, out_features = 84,
            bias = True
        )
        self.bn4 = nn.BatchNorm1d(num_features = 84)
        self.fc3 = nn.Linear(
            in_features = 84, out_features = 10,
            bias = True
        )
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            # print(m)
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # Do not initialize bias (due to batchnorm)-
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # Standard initialization for batch normalization-
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def swish_fn(self, x):
        return x * torch.sigmoid(x * self.beta)

    def forward(self, x):
        '''
        x = nn.SiLU()(self.pool1(self.bn1(self.conv1(x))))
        x = nn.SiLU()(self.pool1(self.bn2(self.conv2(x))))
        x = x.view(-1, 256)
        x = nn.SiLU()(self.bn3(self.fc1(x)))
        x = nn.SiLU()(self.bn4(self.fc2(x)))
        '''
        x = self.pool(self.bn1(self.conv1(x)))
        x = self.swish_fn(x = x)
        x = self.pool(self.bn2(self.conv2(x)))
        x = self.swish_fn(x = x)
        x = x.view(-1, 256)
        x = self.bn3(self.fc1(x))
        x = self.swish_fn(x = x)
        x = self.bn4(self.fc2(x))
        x = self.swish_fn(x = x)
        x = self.fc3(x)
        return x
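To see what the model actually exposes for training, this is the kind of quick check I have been running after instantiating it (a minimal standalone sketch; the prints are only for inspection, and beta = 1.0 just mirrors the default above):

model = LeNet5(beta = 1.0)

# The tensors returned by named_parameters() are the only ones an
# optimizer built from model.parameters() will ever update.
for name, param in model.named_parameters():
    print(name, tuple(param.shape), param.requires_grad)

# beta is stored as a plain tensor attribute on the module.
print(type(model.beta), model.beta.requires_grad)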
While training the model, I print 'beta' after every epoch like this:
for epoch in range(1, num_epochs + 1):
    # One epoch of training-
    train_loss, train_acc = train_one_step(
        model = model, train_loader = train_loader,
        train_dataset = train_dataset
    )

    # Get validation metrics after 1 epoch of training-
    val_loss, val_acc = test_one_step(
        model = model, test_loader = test_loader,
        test_dataset = test_dataset
    )

    scheduler.step()
    current_lr = optimizer.param_groups[0]["lr"]

    print(f"Epoch: {epoch}; loss = {train_loss:.4f}, acc = {train_acc:.2f}%",
        f" val loss = {val_loss:.4f}, val acc = {val_acc:.2f}%,"
        f" beta = {model.beta:.6f} & LR = {current_lr:.5f}"
    )

    # Save training metrics to Python3 dict-
    train_history[epoch] = {
        'train_loss': train_loss, 'val_loss': val_loss,
        'train_acc': train_acc, 'val_acc': val_acc,
        'lr': current_lr
    }

    # Save model with best validation accuracy-
    if (val_acc > best_val_acc):
        best_val_acc = val_acc
        print(f"Saving model with highest val_acc = {val_acc:.2f}%\n")
        torch.save(model.state_dict(), "LeNet5_MNIST_best_val_acc.pth")
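In case it is relevant, I was also planning to record beta in train_history each epoch, roughly like this (a hypothetical extra line inside the loop above, using .item() to store a plain Python float):

    # Hypothetical addition to the metrics dict above: keep beta as a
    # plain float so train_history stays free of tensors.
    train_history[epoch]['beta'] = model.beta.item()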
What am I doing wrong? Why isn't beta training as expected?