
I followed this tutorial and tried to modify it a little to see whether I understand things correctly. However, when I tried to use `torch.optim.SGD`

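```python
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Falls back to CPU so the snippet also runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# dtype and the sizes were not shown in the post; these are the tutorial's
# values, the same ones used in the comments below (assumption).
dtype = torch.float
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
w1 = torch.nn.Parameter(torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True))
w2 = torch.nn.Parameter(torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True))
lr = 1e-6
optimizer = torch.optim.SGD([w1, w2], lr=lr)
for t in range(500):
    layer_1 = x.matmul(w1)
    layer_1 = F.relu(layer_1)
    y_pred = layer_1.matmul(w2)
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
    optimizer.zero_grad()  # clear gradients from the previous iteration
    loss.backward()
    optimizer.step()       # w -= lr * w.grad for each parameter
```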

, my loss blows up to `inf` at the third iteration and becomes `nan` afterwards, which is completely different from what happens when I update the weights manually. The code for the manual update is below (it is also in the tutorial).

```python
# device, dtype, N, D_in, H and D_out are the same as in the first snippet.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
    loss.backward()

    # Update the weights without recording the operations in autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Reset gradients so they do not accumulate across iterations.
        w1.grad.zero_()
        w2.grad.zero_()
```
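One way to narrow this down is to check that `optimizer.step()` and the manual update really apply the same change. Below is a minimal sketch, assuming small made-up sizes, a fixed seed and CPU tensors (none of these come from the post): it takes one step each way from identical starting weights and compares the results.

```python
import torch

# Hypothetical small sizes, chosen only so the check runs quickly on CPU.
torch.manual_seed(0)
N, D_in, H, D_out = 8, 5, 4, 3
lr = 1e-3

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
w1_init = torch.randn(D_in, H)
w2_init = torch.randn(H, D_out)


def loss_fn(w1, w2):
    # Same two-layer network and summed squared error as in the snippets.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    return (y_pred - y).pow(2).sum()


# 1) One step with torch.optim.SGD on copies of the initial weights.
a1 = torch.nn.Parameter(w1_init.clone())
a2 = torch.nn.Parameter(w2_init.clone())
optimizer = torch.optim.SGD([a1, a2], lr=lr)
optimizer.zero_grad()
loss_fn(a1, a2).backward()
optimizer.step()

# 2) One manual step (w -= lr * grad) on another copy of the same weights.
b1 = w1_init.clone().requires_grad_(True)
b2 = w2_init.clone().requires_grad_(True)
loss_fn(b1, b2).backward()
with torch.no_grad():
    b1 -= lr * b1.grad
    b2 -= lr * b2.grad

# Plain SGD (no momentum, no weight decay) should produce the same weights.
print(torch.allclose(a1, b1), torch.allclose(a2, b2))
```

If both comparisons are `True`, the two update rules agree, so a difference between the two loops more likely comes from something outside the update itself, such as the sizes, `dtype` or random initial weights in use.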

I wonder what is wrong with my modified version (the first snippet). When I replaced SGD with Adam, the results looked fine: the loss decreased at every iteration, with no `inf` or `nan`.
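The Adam observation is consistent with how the two optimizers size their steps. Below is a minimal sketch (the parameter, gradient values and `lr` are made up for illustration) that measures one update from each optimizer for a tiny and a huge gradient.

```python
import torch

lr = 1e-3  # hypothetical learning rate for the comparison
steps = {}
for grad_scale in (1e-3, 1e3):
    for opt_cls in (torch.optim.SGD, torch.optim.Adam):
        # Fresh parameter and optimizer each time so no state carries over.
        p = torch.nn.Parameter(torch.ones(3))
        opt = opt_cls([p], lr=lr)  # Adam keeps its defaults: betas=(0.9, 0.999), eps=1e-8
        p.grad = torch.full((3,), grad_scale)  # stand-in for what backward() would produce
        before = p.detach().clone()
        opt.step()
        # SGD moves p by lr * grad; Adam's first step is about lr * grad / |grad|.
        step = (before - p.detach()).abs().max().item()
        steps[(opt_cls.__name__, grad_scale)] = step
        print(f"{opt_cls.__name__:4s} grad={grad_scale:g}  step={step:.3g}")
```

SGD's step grows in proportion to the gradient, while Adam's first step stays close to `lr` whatever the gradient's size. With a summed loss, large gradients can therefore push SGD much further than Adam; this is one plausible reason for the different behaviour, not something the post establishes.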

casual-coder
  • You omitted `clamp(min=0)` from the modified version, which may be important. Otherwise, the modified version should behave the same as the manual one. – Sergii Dymchenko May 08 '19 at 22:13
  • Thank you very much for your comment. I thought `clamp(min=0)` would behave the same as `F.relu()`. – casual-coder May 08 '19 at 22:31
  • I think you're right, `clamp(min=0)` should be equivalent to relu. I ran your modified code several times (after adding `dtype = torch.float` and `N, D_in, H, D_out = 64, 1000, 100, 10` lines); it worked fine and every time the final loss was very small. – Sergii Dymchenko May 08 '19 at 23:21
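The point raised in the comments can be checked directly. Below is a minimal sketch on a random CPU tensor (the size and seed are arbitrary choices, not from the thread) that compares both the forward outputs and the gradients of `clamp(min=0)` and `F.relu`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(1000, requires_grad=True)

# Forward pass: both zero out negative entries and keep the rest unchanged.
out_clamp = z.clamp(min=0)
out_relu = F.relu(z)
print(torch.equal(out_clamp, out_relu))

# Backward pass: compare the gradients each one sends back to z.
(g_clamp,) = torch.autograd.grad(out_clamp.sum(), z)
(g_relu,) = torch.autograd.grad(out_relu.sum(), z)
print(torch.equal(g_clamp, g_relu))
```

Both comparisons are expected to be `True` for random inputs. The gradient chosen at an input of exactly 0 may differ between the two, but random values almost never land there, so swapping one for the other is unlikely to explain a divergence.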
