I would like to use PyTorch to optimize an objective function that makes use of an operation that cannot be tracked by torch.autograd. I wrapped that operation in a custom forward() of a torch.autograd.Function subclass (as suggested here and here). Since I know the gradient of the operation, I can also write the backward(). Everything looks like this:
class Projector(torch.autograd.Function):
    # non_torch_var1/2/3 are constant values needed by the operation
    @staticmethod
    def forward(ctx, vertices, non_torch_var1, non_torch_var2, non_torch_var3):
        ctx.save_for_backward(vertices)
        vertices2 = vertices.detach().cpu().numpy()
        ctx.non_torch_var1 = non_torch_var1
        ctx.non_torch_var2 = non_torch_var2
        ctx.non_torch_var3 = non_torch_var3
        # project_mesh is the operation autograd cannot track (it works on numpy arrays)
        out = project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        out = torch.tensor(out, requires_grad=True)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        vertices = ctx.saved_tensors[0]
        vertices2 = vertices.detach().cpu().numpy()
        non_torch_var1 = ctx.non_torch_var1
        non_torch_var2 = ctx.non_torch_var2
        non_torch_var3 = ctx.non_torch_var3
        # grad_project_mesh computes the known analytic gradient w.r.t. the vertices
        grad_vertices = grad_project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        grad_vertices = torch.tensor(grad_vertices, requires_grad=True)
        return grad_vertices, None, None, None
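For context, the vertices and input variables used in the snippet below are set up roughly like this; the shapes and the constant values are just placeholders, not my actual data:

import torch

# Placeholder mesh: N vertices with 3 coordinates each (my real data differs)
vertices = torch.rand(100, 3, requires_grad=True)
# Dummy constants standing in for the real non-torch arguments
non_torch_var1, non_torch_var2, non_torch_var3 = 1.0, 2.0, 3.0

# Packed arguments passed to Projector.apply as *input below
input = (vertices, non_torch_var1, non_torch_var2, non_torch_var3)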
This implementation, however, does not seem to work. I used the torchviz package to visualize what is going on with the following lines:
import torchviz
out = Projector.apply(*input)
grad_x, = torch.autograd.grad(out.sum(), vertices, create_graph=True)
torchviz.make_dot((grad_x, vertices, out), params={"grad_x": grad_x, "vertices": vertices, "out": out}).render("attached", format="png")
and I got this graph, which shows that grad_x is not connected to anything.
Do you have an idea of what is going wrong with this code?