I am trying to implement the research paper https://ieeexplore.ieee.org/document/9479786/ by training a monotone network with the following architecture:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, q, s):
        self.layer_s_list = [nn.Linear(5, s) for _ in range(q)]
        self.inv_w, self.inv_b = self.get_layer_weights()

    def forward(self, x):
        # print(inv_w[0].shape, inv_b[0].shape)
        output_lst = []
        for layer in self.layer_s_list:
            # max over the s linear units in each group
            v, id = torch.max(layer(x), 1)
            output_lst.append(v.detach().numpy())
        output_lst = np.array(output_lst)
        output_lst = torch.from_numpy(output_lst)
        # min across the q groups
        out, _ = torch.min(output_lst, 0)
        allo_out = F.softmax(out)
        pay_out = nn.ReLU(inplace = True)(out)
        inv_out_lst = []
        for q_idx in range(len(self.inv_w)):
            # print(inv_w[q_idx].shape, pay_out.shape, inv_b[q_idx].shape)
            y, _ = torch.min(torch.linalg.pinv(self.inv_w[q_idx]) * (pay_out - self.inv_b[q_idx]), 0)
            inv_out_lst.append(y.detach().numpy())
        final_out = np.array(inv_out_lst)
        final_out = torch.from_numpy(final_out)
        final_out, _ = torch.max(final_out, 1)
        return final_out, allo_out

    def get_layer_weights(self):
        weights_lst = []
        bias_lst = []
        for layer in self.layer_s_list:
            weights_lst.append(layer.state_dict()['weight'])
            bias_lst.append(layer.state_dict()['bias'])
        return weights_lst, bias_lst
When I initialise the network and run it on random inputs:
q = 5
s = 10
x = torch.rand((10, 5), requires_grad = True)
net = Model(q, s)
y, z = net(x)
It gives the following error:
AttributeError Traceback (most recent call last)
<ipython-input-3-aac6d239df1f> in <module>
1 x = torch.rand((10, 5), requires_grad = True)
2 net = Model(5, 10)
----> 3 y = net(x)
1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
1206 return modules[name]
1207 raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1208 type(self).__name__, name))
1209
1210 def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
AttributeError: 'Model' object has no attribute '_backward_hooks'
Please help me understand what this error is and how to fix it.
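Edit: from stepping through torch/nn/modules/module.py, my current understanding (please correct me if this is wrong): Model.__init__ never calls super().__init__(), so nn.Module's constructor never runs and the internal bookkeeping attributes it normally creates (_parameters, _buffers, _modules, _backward_hooks, ...) don't exist. The first attribute lookup that misses then falls through to nn.Module.__getattr__, which is what raises the AttributeError about _backward_hooks. A minimal sketch of the fix, which also swaps the plain Python list for nn.ModuleList so the sub-layers (and their parameters) are actually registered with the module:

    class Model(nn.Module):
        def __init__(self, q, s):
            super().__init__()  # creates _parameters, _backward_hooks, etc.
            # nn.ModuleList registers each nn.Linear as a submodule, so
            # net.parameters() will include their weights and biases.
            self.layer_s_list = nn.ModuleList([nn.Linear(5, s) for _ in range(q)])
            self.inv_w, self.inv_b = self.get_layer_weights()

With this change the AttributeError goes away, though I suspect the .detach().numpy() round-trips in forward will still cut the computation graph, so gradients would not flow during training.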