I'd like to train a small neural network in PyTorch that takes an 8-dimensional vector as input and predicts one of three possible categories. The first hidden layer should contain 6 neurons, where each neuron takes the activations of only 3 consecutive dimensions of the input layer. The second hidden layer should also contain 6 neurons and be fully connected, and the last layer should be the output layer with 3 neurons. Thus the topology is 8 → 6 → 6 → 3, where only the first hidden layer is locally connected.
Let's say that a mini batch consists of 64 (8-dimensional) data points.
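To make the connectivity of the first hidden layer explicit, here is a small sketch of which input dimensions each of the 6 neurons would see. It assumes a stride-1 sliding window, and the names `n_inputs`, `window`, `n_hidden` and `receptive_fields` are only for illustration:

```python
# Which input dimensions feed each first-layer neuron,
# assuming a sliding window of width 3 with stride 1.
n_inputs, window = 8, 3
n_hidden = n_inputs - window + 1  # 8 - 3 + 1 = 6 neurons

# Neuron i is connected to inputs i, i+1, ..., i+window-1.
receptive_fields = [list(range(i, i + window)) for i in range(n_hidden)]

for neuron, dims in enumerate(receptive_fields):
    print(f"neuron {neuron}: inputs {dims}")
```

So the first neuron sees inputs 0 to 2 and the last one sees inputs 5 to 7.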
I tried to implement the first layer using 1D convolution. Since a 1D convolution filter assumes the input is a sequence of points, I thought a good approach is to define 6 filters operating on 8 1-dimensional points:
import torch
import torch.nn as nn
import torch.nn.functional as functional

class ExampleNet(nn.Module):
    def __init__(self, batch_size, input_channels, output_channels):
        super(ExampleNet, self).__init__()
        # 6 filters of width 3 sliding over the 8 input values
        self._layer1 = nn.Conv1d(in_channels=1, out_channels=input_channels - 2, kernel_size=3, stride=1)
        self._layer2 = nn.Linear(in_features=input_channels - 2, out_features=input_channels - 2)
        self._layer3 = nn.Linear(in_features=input_channels - 2, out_features=output_channels)

    def forward(self, x):
        x = functional.relu(self._layer1(x))
        x = functional.relu(self._layer2(x))
        x = functional.softmax(self._layer3(x))
        return x

net = ExampleNet(64, 8, 3)
As far as I understand, PyTorch expects arrays of size 64 x 8 x 1 when training the network. However, since I apply the 1D convolutional filters in an unconventional way, I think my input arrays should have size 64 x 1 x 8, and I am expecting an output of size 64 x 3. I use the following mini-batch of random points to run through the network:
# Generate a mini-batch of 64 samples
input = torch.randn(64, 1, 8)
out = net(input)
print(out.size())
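To narrow down where the shapes go wrong, I also traced the layers one at a time. This is only a sketch with freshly initialized layers of the same sizes as in `ExampleNet` (the names `conv`, `h` and `y` are just for this check):

```python
import torch
import torch.nn as nn

# Same layer sizes as in ExampleNet, built standalone to inspect shapes.
conv = nn.Conv1d(in_channels=1, out_channels=6, kernel_size=3, stride=1)

x = torch.randn(64, 1, 8)  # (batch, channels, length)
h = conv(x)
print(h.shape)  # torch.Size([64, 6, 6]): 6 filters, each applied at 6 positions

# nn.Linear only transforms the last dimension, so the extra
# channel dimension is carried through to the output.
y = nn.Linear(6, 3)(nn.Linear(6, 6)(h))
print(y.shape)  # torch.Size([64, 6, 3])
```

If I read this correctly, the convolution produces 6 x 6 = 36 activations per sample rather than 6, and each filter reuses the same 3 weights at every position, which is not the layer I described.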
The output size I get tells me that I defined the wrong topology. How would you advise me to define the layers I need? Is using Conv1d a good approach in my case? I saw that another approach is to use a masked layer, but I don't know how to define it.
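My rough guess at what a masked layer might look like is sketched below, though I am unsure whether it is the right or idiomatic way. `MaskedLinear` and `MaskedNet` are names I made up, not PyTorch classes. The idea is to multiply the weight matrix of an ordinary linear layer by a fixed 0/1 mask, so that neuron i only sees inputs i to i+2. The sketch also returns raw logits instead of applying softmax, on the assumption that the loss (for example `nn.CrossEntropyLoss`) handles that:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Module):
    """Hypothetical helper: linear layer whose weights are multiplied by a fixed 0/1 mask."""

    def __init__(self, in_features, window):
        super().__init__()
        out_features = in_features - window + 1
        self.linear = nn.Linear(in_features, out_features)
        mask = torch.zeros(out_features, in_features)
        for i in range(out_features):
            mask[i, i:i + window] = 1.0  # neuron i sees inputs i .. i+window-1
        # A buffer is saved and moved with the module but is not trained.
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Masked-out weights contribute nothing and receive zero gradient.
        return F.linear(x, self.linear.weight * self.mask, self.linear.bias)


class MaskedNet(nn.Module):
    def __init__(self, n_inputs=8, n_classes=3, window=3):
        super().__init__()
        n_hidden = n_inputs - window + 1  # 6 for the sizes in this question
        self.layer1 = MaskedLinear(n_inputs, window)
        self.layer2 = nn.Linear(n_hidden, n_hidden)
        self.layer3 = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)  # raw logits, no softmax here


net = MaskedNet()
out = net(torch.randn(64, 8))  # plain 64 x 8 input, no channel dimension
print(out.shape)  # torch.Size([64, 3])
```

Unlike the convolution, each of the 6 neurons here has its own 3 weights, and the input stays a plain 64 x 8 batch.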