I have a matrix A of dimension 1000x70000. My loss function depends on A, and I want to find the optimal value of A using gradient descent, subject to the constraint that each row of A stays in the probability simplex (i.e. every entry is non-negative and every row sums to 1). I have initialised A as follows:
import numpy as np
import torch
A = np.random.dirichlet(np.ones(70000), 1000)
A = torch.tensor(A, requires_grad=True)
My training loop looks like this:
for epoch in range(500):
    y_pred = forward(X)
    y = model(torch.mm(A.float(), X))
    l = loss(y, y_pred)
    l.backward()
    A.grad.data = -A.grad.data
    optimizer.step()
    optimizer.zero_grad()
    if epoch % 2 == 0:
        print("Loss", l, "\n")
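For reference, one common way to enforce this kind of constraint is projected gradient descent: take an ordinary optimizer step, then project every row of A back onto the probability simplex. The sketch below is only an illustration under assumptions, not the code from the question. The helper `project_rows_to_simplex` is a hypothetical name (it uses the standard sort-based Euclidean projection), the matrix sizes are toy values, and the squared-error loss and plain SGD optimizer are stand-ins for the real `model`, `loss` and `optimizer`.

```python
import torch


def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex.

    Hypothetical helper (not part of PyTorch); sort-based algorithm.
    """
    n = V.shape[1]
    # Sort each row in descending order
    U, _ = torch.sort(V, dim=1, descending=True)
    # Cumulative sums minus the target total of 1
    css = U.cumsum(dim=1) - 1.0
    k = torch.arange(1, n + 1, dtype=V.dtype, device=V.device)
    # The condition holds for a prefix of each sorted row; its length is rho
    rho = (U - css / k > 0).sum(dim=1, keepdim=True)
    # Per-row shift that makes the clipped row sum to 1
    tau = css.gather(1, rho - 1) / rho.to(V.dtype)
    return torch.clamp(V - tau, min=0.0)


torch.manual_seed(0)

# Toy sizes (assumption); the question uses a 1000 x 70000 matrix
A = torch.softmax(torch.randn(4, 6, dtype=torch.float64), dim=1).requires_grad_(True)
X = torch.randn(6, 3, dtype=torch.float64)
target = torch.randn(4, 3, dtype=torch.float64)

# Plain SGD as a stand-in for the optimizer in the question
optimizer = torch.optim.SGD([A], lr=0.1)

for epoch in range(50):
    optimizer.zero_grad()
    # Toy squared-error loss (assumption), standing in for loss(model(A @ X), y_pred)
    l = ((A @ X - target) ** 2).mean()
    l.backward()
    optimizer.step()
    # Projection step: pull every row back onto the simplex
    with torch.no_grad():
        A.copy_(project_rows_to_simplex(A))
```

After every iteration the rows of A are non-negative and sum to 1 up to floating-point error. A different option would be to optimize unconstrained logits and set A to their row-wise softmax, which keeps the rows in the interior of the simplex without any projection step.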