OK, it turns out the error came down to several things:
- `dZ` needed to be dilated according to the stride used in the forward propagation.
- The window function used for windowing `dZ` (done after the dilation of `dZ`) needed to be called with stride 1 (no matter the stride choice in the forward propagation) and with the output heights and widths of the padded input (not the original, unpadded input -- this was the main mistake that took me days to debug).
The relevant code is below, with comments explaining shapes and operations as well as some further sources for reading. I've also included the forward propagation.
I should note that after days of debugging, writing various functions, reading, etc., the variable names changed along the way. For ease of reading, here are the names of the variables as defined in my question and their equivalents in the code below:
- `A_prev` is `x`
- `dZ` is `dout_descendant`
- `Hout` is the height of `dout_descendant`
- `Wout` is the width of `dout_descendant`

(As one would expect, all references to `self` are to the class these functions are part of.)
def _pad(self, array, pad_size, pad_val):
    '''
    Pad the height and width axes (1 and 2) of an array with a constant value.

    only symmetric padding is implemented: pad_size entries of pad_val are added
    before and after axes 1 and 2; axes 0 and 3 are left unchanged.

    returns:
        new padded array (np.pad result)
    '''
    spatial = (pad_size, pad_size)
    untouched = (0, 0)
    pad_widths = (untouched, spatial, spatial, untouched)
    return np.pad(array, pad_widths, mode='constant', constant_values=(pad_val, pad_val))
def _dilate(self, array, stride_size, pad_size, symmetric_filter_shape, output_image_size):
    '''
    Dilate dZ along its height and width axes (1 and 2) for backprop through a strided convolution.

    stride_size - 1 zeros are inserted between neighbouring entries along axes 1 and 2,
    so an axis of length n becomes (n - 1) * stride_size + 1. Then
    (output_image_size + 2 * pad_size - symmetric_filter_shape) % stride_size zero rows
    and columns are appended at the bottom and right. These are the trailing rows/columns
    of the padded input that the forward pass skipped because (H + 2p - f) was not a
    multiple of the stride, so their gradient is zero.

    Only square geometry is supported: the same trailing padding is used for both axes.

    inputs:
        array: array to dilate; only axes 1 and 2 are changed
        stride_size: int >= 1: stride of the forward convolution
        pad_size: int: symmetric padding of the forward convolution
        symmetric_filter_shape: int: filter height (== width)
        output_image_size: int: height (== width) of the unpadded forward input
    returns:
        new array with array's dtype, zeros everywhere except at multiples of stride_size
        along axes 1 and 2, where the entries of array are placed
    raises:
        ValueError: if stride_size < 1
    '''
    # on dilation for backprop with stride>1,
    # see: https://medium.com/@mayank.utexas/backpropagation-for-convolution-with-strides-8137e4fc2710
    # see also: https://leimao.github.io/blog/Transposed-Convolution-As-Convolution/
    if stride_size < 1:
        raise ValueError(f"stride_size must be >= 1, got {stride_size}")
    array = np.asarray(array)
    pad_bottom_and_right = (output_image_size + 2 * pad_size - symmetric_filter_shape) % stride_size
    # length of an axis after inserting stride_size - 1 zeros between neighbours (0 stays 0)
    span_h = max((array.shape[1] - 1) * stride_size + 1, 0)
    span_w = max((array.shape[2] - 1) * stride_size + 1, 0)
    out_shape = list(array.shape)
    out_shape[1] = span_h + pad_bottom_and_right
    out_shape[2] = span_w + pad_bottom_and_right
    # allocate once and scatter with a strided slice instead of calling np.insert
    # repeatedly (every np.insert call copies the whole array)
    dilated = np.zeros(out_shape, dtype=array.dtype)
    dilated[:, :span_h:stride_size, :span_w:stride_size] = array
    return dilated
def _windows(self, array, stride_size, filter_shapes, out_height, out_width):
    '''
    Build a read-only sliding-window view of a 4-D array without copying data.

    inputs:
        array: array to create windows of, indexed as (batch, height, width, channel)
        stride_size: int: step between consecutive windows along height and width
        filter_shapes: tuple(int): tuple of filter height and width
        out_height and out_width: int, respectively: number of windows along height and width
    returns:
        read-only view with shape (excl. dilation):
        array.shape[0], out_height, out_width, filter_shapes[0], filter_shapes[1], array.shape[3]
        element [n, i, j, a, b, c] is array[n, i*stride_size + a, j*stride_size + b, c]
    raises:
        ValueError: if the requested windows would reach past the height or width of array
        (as_strided does no bounds checking and would otherwise read arbitrary memory)
    '''
    filter_h, filter_w = filter_shapes[0], filter_shapes[1]
    if out_height > 0 and out_width > 0:
        needed_h = (out_height - 1) * stride_size + filter_h
        needed_w = (out_width - 1) * stride_size + filter_w
        if needed_h > array.shape[1] or needed_w > array.shape[2]:
            raise ValueError(
                f"windows of size ({filter_h}, {filter_w}) with stride {stride_size} and output size "
                f"({out_height}, {out_width}) need an input of at least ({needed_h}, {needed_w}), "
                f"got ({array.shape[1]}, {array.shape[2]})")
    # outer window axes step by stride_size elements; inner filter axes step by 1 element
    strides = (array.strides[0], array.strides[1] * stride_size, array.strides[2] * stride_size,
               array.strides[1], array.strides[2], array.strides[3])
    shape = (array.shape[0], out_height, out_width, filter_h, filter_w, array.shape[3])
    return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides, writeable=False)
def forward(self, x):
    '''
    Run the convolution forward pass.

    expects inputs to be of shape: [batchsize, height, width, channel in]
    after init, filter_shapes are: [fh, fw, channel in, channel out]

    side effects: stores self.input_shape, self.input_pad_shape, self.Hout, self.Wout,
    self.inputs (the strided windows of the padded input, reused by backward for dw)
    and self.out.

    returns:
        self.out, shape (batchsize, Hout, Wout, channel out)
    '''
    self.input_shape = x.shape
    x_pad = self._pad(x, self.pad_size, self.pad_val)
    self.input_pad_shape = x_pad.shape
    # get the shapes (the unpacking also rejects inputs that are not 4-D)
    batch_size, h, w, Cin = self.input_shape
    # calculate output sizes; only symmetric padding is possible
    self.Hout = (h + 2 * self.pad_size - self.fh) // self.stride + 1
    self.Wout = (w + 2 * self.pad_size - self.fw) // self.stride + 1
    # read-only strided view with shape (batch_size, Hout, Wout, fh, fw, Cin)
    x_windows = self._windows(array=x_pad, stride_size=self.stride, filter_shapes=(self.fh, self.fw),
                              out_height=self.Hout, out_width=self.Wout)
    # contract the (fh, fw, Cin) window axes with the (fh, fw, Cin) filter axes
    # -> shape (batch_size, Hout, Wout, Cout)
    self.out = np.tensordot(x_windows, self.w, axes=([3, 4, 5], [0, 1, 2])) + self.b
    self.inputs = x_windows
    ## alternative 1: einsum approach, slower than the tensordot above
    # self.out = np.einsum('noufvc,fvck->nouk', x_windows, self.w) + self.b
    ## alternative 2: column approach with a simple 2D dot product
    # z = x_windows.reshape(-1, self.fh * self.fw * Cin) @ self.w.reshape(self.fh * self.fw * Cin, self.Cout) + self.b.reshape(-1)  # shape (batch_size * Hout * Wout, Cout)
    # self.out = z.reshape(batch_size, self.Hout, self.Wout, self.Cout)
    return self.out
def backward(self, dout_descendant):
    '''
    Backpropagate through the convolution.

    dout_descendant has shape (batch_size, Hout, Wout, Cout): the gradient of the loss
    with respect to this layer's output (dZ).

    reads self.input_shape and self.inputs (stored by forward), plus self.fh, self.fw,
    self.stride, self.pad_size and self.w.

    side effects: sets
        self.db   shape (1, 1, 1, Cout)
        self.dw   shape (fh, fw, Cin, Cout)
        self.dout shape (batch_size, h, w, Cin): gradient w.r.t. the unpadded input x

    returns:
        self.dout
    '''
    # get shapes
    batch_size, h, w, Cin = self.input_shape
    _, Hout, Wout, Cout = dout_descendant.shape
    fh, fw, stride, pad = self.fh, self.fw, self.stride, self.pad_size

    # we want to sum everything but the filters for b
    self.db = np.sum(dout_descendant, axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, Cout)

    # for dw we use the column approach with an ordinary dot product;
    # tensordot does the same without all the reshaping
    dout_flat = dout_descendant.reshape(-1, Cout)  # shape (batch_size * Hout * Wout, Cout)
    x_flat = self.inputs.reshape(-1, fh * fw * Cin)  # shape (batch_size * Hout * Wout, fh * fw * Cin)
    self.dw = (x_flat.T @ dout_flat).reshape(fh, fw, Cin, Cout)

    # for dinputs: full convolution (stride 1) of the dilated dZ with the 180-degree rotated filters.
    # for details, see https://medium.com/@mayank.utexas/backpropagation-for-convolution-with-strides-8137e4fc2710 ;
    # also: https://pavisj.medium.com/convolutions-and-backpropagations-46026a8f5d2c ; also: https://youtu.be/Lakz2MoHy6o?t=835
    w_rot180 = np.rot90(self.w, 2, axes=(0, 1))  # same as self.w[::-1, ::-1, :, :]

    # Dilate dZ by the stride and zero-pad it by (fh - 1, fw - 1) in a single allocation,
    # separately per axis so non-square inputs/filters work. The gradient is always
    # padded with zeros, regardless of the forward pad_val. Trailing rows/columns the
    # forward pass never reached stay zero.
    full_h = h + 2 * pad + fh - 1
    full_w = w + 2 * pad + fw - 1
    dout_full = np.zeros((batch_size, full_h, full_w, Cout), dtype=dout_descendant.dtype)
    dout_full[:, fh - 1:fh - 1 + (Hout - 1) * stride + 1:stride,
              fw - 1:fw - 1 + (Wout - 1) * stride + 1:stride, :] = dout_descendant

    # stride-1 windows with the padded input's height and width:
    # shape (batch_size, h + 2p, w + 2p, fh, fw, Cout)
    dout_windows = self._windows(array=dout_full, stride_size=1, filter_shapes=(fh, fw),
                                 out_height=h + 2 * pad, out_width=w + 2 * pad)
    dx_padded = np.tensordot(dout_windows, w_rot180, axes=([3, 4, 5], [0, 1, 3]))  # (batch_size, h + 2p, w + 2p, Cin)
    # crop away the forward padding; explicit end indices keep this correct when pad == 0
    self.dout = dx_padded[:, pad:pad + h, pad:pad + w, :]
    ## einsum alternative, but slower:
    # dx_padded = np.einsum('nhwfvk,fvck->nhwc', dout_windows, w_rot180)
    return self.dout
I've left this answer here because all the other sources I could find on Stack Overflow or GitHub that used NumPy stride tricks fell into one of two groups. Either they were implemented for convolutions with stride 1 (which doesn't require dilation of `dZ`), or they used very complex fancy-indexing operations that were extremely hard to follow (e.g. https://sgugger.github.io/convolution-in-depth.html#convolution-in-depth or https://github.com/parasdahal/deepnet/blob/51a9e61c351138b7dc637f4b748a0e6ca2e15595/deepnet/im2col.py).