I'm working through the PyTorch tutorial on Defining new autograd functions. The autograd function I want to implement is a wrapper around torch.nn.functional.max_pool1d. Here is what I have so far:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
                return_indices=False, ceil_mode=False):
        ctx.save_for_backward(input)
        inputC = input.clone()  # copy the input
        inputC *= inputC        # square it in place
        output = F.max_pool1d(inputC, kernel_size, stride=stride,
                              padding=padding, dilation=dilation,
                              return_indices=return_indices,
                              ceil_mode=ceil_mode)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = get_max_pool1d_grad_somehow(grad_output)
        return 2.0 * input * grad_input
My question is: how do I get the gradient of the wrapped function? I know that there are probably other ways to do this given how simple my example is, but what I want to do fits this framework and requires me to implement an autograd function.
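To make it concrete: outside of a custom Function, autograd already computes the gradient I am after, for example with something along these lines (my own toy check, the xCheck/yCheck names are just for illustration; this relies on the imports above):

xCheck = torch.randn(1, 1, 5, dtype=torch.float64, requires_grad=True)
yCheck = F.max_pool1d(xCheck * xCheck, kernel_size=2)  # square, then max-pool
yCheck.sum().backward()                                # built-in autograd handles this composite op
print(xCheck.grad)  # the gradient my custom backward() should reproduce

What I am missing is how to get the same thing from inside backward.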
Edit: After examining this blog post, I decided to try the following for backward:
def backward(ctx, grad_output):
    input, output = ctx.saved_tensors
    grad_input = output.backward(grad_output)
    return 2.0 * input * grad_input
with output added to the saved variables. I then run the following code:
x = np.random.randn(1, 1, 5)
xT = torch.from_numpy(x)
xT.requires_grad = True
f = SquareAndMaxPool1d.apply
s = torch.sum(f(xT, 2))
s.backward()
and I get Bus error: 10.
Say xT is tensor([[[ 1.69533562, -0.21779421, 2.28693953, -0.86688095, -1.01033497]]], dtype=torch.float64). Then I would expect to find that xT.grad is tensor([[[ 3.39067124, -0. , 9.14775812, -0. , -2.02066994]]], dtype=torch.float64) after calling s.backward() (that is, 2*x*grad_of_max_pool, with grad_of_max_pool containing tensor([[[1., 0., 2., 0., 1.]]], dtype=torch.float64)).
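(Working the chain rule entrywise on those numbers: 2 * 1.69533562 * 1 = 3.39067124, 2 * 2.28693953 * 2 = 9.14775812, and 2 * (-1.01033497) * 1 = -2.02066994, which is where the expected values come from.)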
I've figured out why I get a Bus error: 10. It appears that the above code leads to a recursive call of my backward at the line grad_input = output.backward(grad_output). So I need to find some other way to get the gradient of max_pool1d. I know how to implement this in pure Python, but the result would be much slower than if I could wrap the library code.
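One direction I am considering (sketched below, untested, and the bookkeeping is my own guess): have forward ask max_pool1d for its argmax indices via return_indices=True, save them, and have backward route grad_output back through them with scatter_add_. I have dropped the return_indices argument from the wrapper's signature here, and the Idx suffix on the name is only to distinguish this sketch from the class above:

class SquareAndMaxPool1dIdx(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
                ceil_mode=False):
        squared = input * input
        # ask max_pool1d for the argmax positions so backward can reuse them
        output, indices = F.max_pool1d(squared, kernel_size, stride=stride,
                                       padding=padding, dilation=dilation,
                                       return_indices=True, ceil_mode=ceil_mode)
        ctx.save_for_backward(input, indices)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, indices = ctx.saved_tensors
        # send each output gradient back to the input position that won the max,
        # accumulating in case pooling windows overlap
        grad_squared = torch.zeros_like(input)
        grad_squared.scatter_add_(2, indices, grad_output)
        # chain rule through the squaring; one return value per argument of
        # forward, None for the non-tensor ones
        return 2.0 * input * grad_squared, None, None, None, None, None

Whether this reproduces the library gradient in the padding/dilation/ceil_mode corner cases is exactly the part I am unsure about, which is why I would rather wrap whatever max_pool1d itself uses.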