I implemented the policy gradient method to learn an unknown function (here, a 10-loop sum function), but the model does not update. Each training sample is an input and its target; func_mulsum wraps the MLP model that predicts the target number.
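For context, with n_loop = 10 the target produced by train_test_data is just the input plus sum(range(10)) = 45. A quick sanity check (x and label here are illustrative values, not part of the code below):

x = 7
label = x + sum(range(10))  # 7 + 45 = 52

The full code is the following: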
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 8)
        self.fc2 = nn.Linear(8, 8)
        self.fc3 = nn.Linear(8, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # output is squashed into (0, 1)
        return x
def assertEqual(label, pre_number):
    # reward is the negative squared error between prediction and target
    reward = -torch.sum((label - pre_number).pow(2))
    return reward
def func_mulsum(model, s, scalar, n_loop):
    # feed the input through the model n_loop times (scalar is currently unused)
    c = s
    for i in range(n_loop):
        c = model(c)
    return c
def train_test_data(n_loop):
    nums = 100
    rate = 0.7
    train_data = np.zeros((int(nums * rate), 2))
    test_data = np.zeros((int(nums * (1 - rate)), 2))
    data = random.sample(range(nums), nums)
    train = data[:int(nums * rate)]
    test = data[int(nums * rate):]
    train_data[:, 0] = train
    test_data[:, 0] = test
    # the target is the input plus 0 + 1 + ... + (n_loop - 1)
    for i, ans in enumerate(train):
        for j in range(n_loop):
            ans += j
        train_data[i, 1] = ans
    for i, ans1 in enumerate(test):
        for j in range(n_loop):
            ans1 += j
        test_data[i, 1] = ans1
    return train_data, test_data
if __name__ == '__main__':
    n_loop = 10
    iterations = 10
    learn_data, test_case = train_test_data(n_loop)
    learn_data = torch.FloatTensor(learn_data)  # convert once, outside the training loop
    model = MLP()
    optim = opt.SGD(model.parameters(), lr=0.05)
    for i in range(iterations):
        reward_sum = 0
        for j, data in enumerate(learn_data[:, 0]):
            data = data.unsqueeze(0)
            label = learn_data[j, 1]
            optim.zero_grad()
            pre = func_mulsum(model, data, 255, n_loop)
            # sample an "action" around the prediction with a tiny std
            p_norm = torch.normal(pre, std=1e-9)
            reward = assertEqual(p_norm, label)
            # print(p_norm, label)
            # REINFORCE-style loss: -reward * log(sampled output)
            loss = -reward * torch.log(p_norm)
            loss.backward()
            reward_sum += loss.item()  # .item() so the graph is not kept alive
            optim.step()
        for para in model.parameters():
            print(para)
        print('reward_mean............................', reward_sum)
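This is how I inspected the gradients, a small check added right after loss.backward() in the loop above:

for name, para in model.named_parameters():
    print(name, para.grad)  # every gradient comes back as all zeros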
I could not find the reason why the gradients are all 0. This question has confused me for two days. Can anyone help me?