I am using the following snippet for implementing convolution function in CNN.
def conv(x, in_channels, out_channels, kernel_size, stride, padding, weight, bias):
"""
Args:
x: torch tensor with size (N, C_in, H_in, W_in),
in_channels: number of channels in the input image, it is C_in;
out_channels: number of channels produced by the convolution;
kernel_size: size of onvolving kernel,
stride: stride of the convolution,
padding: implicit zero padding to be added on both sides of each dimension,
Return:
y: torch tensor of size (N, C_out, H_out, W_out)
"""
y = None
xKernShape = kernel_size
yKernShape = kernel_size
xImgShape = x.shape[2]
yImgShape = x.shape[3]
xOutput = int(((xImgShape - xKernShape + 2 * padding) / stride) + 1)
yOutput = int(((yImgShape - yKernShape + 2 * padding) / stride) + 1)
output = np.zeros((xOutput, yOutput))
if padding != 0:
imagePadded = np.zeros((x.shape[2] + padding*2, x.shape[3] + padding*2))
imagePadded[int(padding):int(-1 * padding), int(padding):int(-1 * padding)] = x
print(imagePadded)
else:
imagePadded = x
for i in range(x.shape[3]):
if i > x.shape[3] - yKernShape:
break
if i % stride == 0:
for j in range(x.shape[2]):
if j > x.shape[2] - xKernShape:
break
try:
if j % stride == 0:
output[j, i] = (kernel_size * imagePadded[j: j + xKernShape, i: i + yKernShape]).sum()
y = np.array(np.hsplit(output, 1)).reshape((x.shape[0], out_channels, j, i))
except:
break
return y
However, when I run this using call conv(x,in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0,weight=torch_conv.weight,bias=torch_conv.bias)
I am getting
AttributeError Traceback (most recent call last)
<ipython-input-71-f4d6ee7b5fd8> in <module>
6 padding=0,
7 weight=torch_conv.weight,
----> 8 bias=torch_conv.bias)
<ipython-input-70-1d0a9f0d9820> in my_conv(x, in_channels, out_channels, kernel_size, stride, padding, weight, bias)
33 if i % stride == 0:
34 for x in range(x.shape[2]):
---> 35 if x > x.shape[2] - xKernShape:
36 break
37 try:
AttributeError: 'int' object has no attribute 'shape'
My input is 2X3X32X32.
Can anyone please help?