
I'm working with activity recognition models implemented in GluonCV/MXNet, and I want to try replacing the model's fully connected layer with global average pooling.

I have 4 output classes, so the idea was to:

  • remove the current model head;
  • add a convolutional layer with 4 filters of size 1x1x1, to obtain 4 output feature maps;
  • run global average pooling on each of them, to obtain a final 4-dim vector.
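The three steps above can be illustrated without MXNet. This NumPy sketch is framework-agnostic and assumes the backbone emits a 5D feature map of shape (N, C, T, H, W); the batch size, channel count, feature-map extent and random weights are made up for illustration, and the function name conv1x1x1_gap_head is hypothetical.

```python
import numpy as np

def conv1x1x1_gap_head(features, weight, bias):
    """Hypothetical sketch of the proposed head: 1x1x1 convolution, then global average pooling.

    features: (N, C, T, H, W) feature map from the backbone
    weight:   (K, C), i.e. a (K, C, 1, 1, 1) Conv3D kernel with the unit axes dropped
    bias:     (K,)
    returns:  (N, K) class scores
    """
    # A 1x1x1 convolution is a linear map over channels applied at every (t, h, w) position.
    maps = np.einsum("kc,ncthw->nkthw", weight, features) + bias[None, :, None, None, None]
    # Global average pooling: average each of the K maps over time, height and width.
    return maps.mean(axis=(2, 3, 4))

rng = np.random.default_rng(0)
feats = rng.standard_normal((2, 2048, 4, 7, 7))  # assumed backbone output
w = rng.standard_normal((4, 2048)) * 0.01        # 4 filters, one per class
b = np.zeros(4)
scores = conv1x1x1_gap_head(feats, w, b)
print(scores.shape)  # (2, 4)
```

Each of the 4 output maps collapses to one number, so the result is one 4-dim score vector per clip.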

I tried to modify the init method in environment_folder/lib/python3.8/site-packages/gluoncv/model_zoo/action_recognition/i3d_resnet.py. Currently, the model head with the FC layer is defined this way:

self.feat_dim = self.block.expansion * 64 * 2**(len(self.stage_blocks) - 1)

# We use ``GlobalAvgPool3D`` here for simplicity. Otherwise the input size must be fixed.
# You can also use ``AvgPool3D`` and specify the arguments on your own, e.g.
# self.st_avg = nn.AvgPool3D(pool_size=(4, 7, 7), strides=1, padding=0)
# ``AvgPool3D`` is 10% faster, but ``GlobalAvgPool3D`` makes the code cleaner.
self.st_avg = nn.GlobalAvgPool3D()

self.head = nn.HybridSequential(prefix='')
self.head.add(nn.Dropout(rate=self.dropout_ratio))
self.fc = nn.Dense(in_units=self.feat_dim, units=nclass, weight_initializer=init.Normal(sigma=self.init_std))
self.head.add(self.fc)
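The feat_dim expression and the Dense head can be checked numerically. This NumPy sketch assumes a ResNet-50-style backbone (bottleneck expansion 4, four stages), which gives 2048 channels; the feature-map size and random weights are illustrative only. It also shows that, because a 1x1x1 convolution and global average pooling are both linear, pooling followed by Dense yields the same scores as the same weights applied as a 1x1x1 convolution followed by pooling.

```python
import numpy as np

# Assumption: ResNet-50-style backbone, i.e. bottleneck blocks (expansion = 4)
# and four stages with stage_blocks = (3, 4, 6, 3).
expansion = 4
stage_blocks = (3, 4, 6, 3)
feat_dim = expansion * 64 * 2 ** (len(stage_blocks) - 1)
print(feat_dim)  # 2048

rng = np.random.default_rng(0)
n, nclass = 2, 4
features = rng.standard_normal((n, feat_dim, 4, 7, 7))   # assumed (N, C, T, H, W) map
weight = rng.standard_normal((nclass, feat_dim)) * 0.01  # Dense weight, shape (units, in_units)
bias = np.zeros(nclass)

# Original head at inference time (dropout is the identity): global average pooling, then Dense.
pooled = features.mean(axis=(2, 3, 4))   # (N, feat_dim)
dense_scores = pooled @ weight.T + bias  # (N, nclass)

# The same weights used as a 1x1x1 convolution, followed by global average pooling.
conv_maps = np.einsum("kc,ncthw->nkthw", weight, features) + bias[:, None, None, None]
conv_scores = conv_maps.mean(axis=(2, 3, 4))

print(dense_scores.shape)                       # (2, 4)
print(np.allclose(dense_scores, conv_scores))   # True
```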

So I removed the last 4 lines and replaced them with:

self.head = nn.HybridSequential(prefix='')
self.head.add(nn.Dropout(rate=self.dropout_ratio))
self.head.add(nn.Conv3D(4, 1, strides=(1, 1, 1), padding=(0, 0, 0)))
self.head.add(nn.GlobalAvgPool3D())
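To see which tensor shapes this head expects and produces, here is a pure-Python walk-through of the shape rules (trace_head_shapes is a hypothetical helper, not GluonCV or MXNet code). Assuming NCDHW layout, the head only works if it receives the 5D feature map, and GlobalAvgPool3D returns (N, 4, 1, 1, 1) rather than a flat (N, 4) vector, so a flattening step (for example Gluon's nn.Flatten) would still be needed. If the model's forward pass reduces the features to a lower-rank tensor before calling self.head, which a Dense head tolerates, the Conv3D layer may receive an input of the wrong rank.

```python
def trace_head_shapes(in_shape, num_filters=4):
    """Hypothetical shape walk-through of Dropout -> Conv3D(1x1x1) -> GlobalAvgPool3D (NCDHW)."""
    in_shape = tuple(in_shape)
    steps = [("Dropout", in_shape)]  # dropout never changes the shape
    if len(in_shape) != 5:
        raise ValueError(
            f"Conv3D expects a 5D (N, C, D, H, W) input, got {len(in_shape)}D {in_shape}"
        )
    n, _, d, h, w = in_shape
    # Kernel 1, stride 1, padding 0: only the channel count changes.
    steps.append(("Conv3D", (n, num_filters, d, h, w)))
    # Global pooling keeps the reduced axes as size 1; it does not flatten.
    steps.append(("GlobalAvgPool3D", (n, num_filters, 1, 1, 1)))
    return steps

for layer, shape in trace_head_shapes((2, 2048, 4, 7, 7)):
    print(layer, shape)
# Dropout (2, 2048, 4, 7, 7)
# Conv3D (2, 4, 4, 7, 7)
# GlobalAvgPool3D (2, 4, 1, 1, 1)

try:
    trace_head_shapes((2, 2048))  # features that were already pooled and flattened
except ValueError as err:
    print(err)
```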

However, I'm getting the error:

File "/Users/cdemasi/opt/anaconda3/envs/bshot/lib/python3.8/site-packages/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "../src/operator/tensor/./broadcast_reduce_op.h", line 401
MXNetError: Check failed: ishape.ndim() == param.shape.ndim() (3 vs. 5) : Operand of shape [1000,2048,1] cannot be broadcasted to [0,0,1,0,0]
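The message appears to come from a shape check in MXNet's broadcast operators: the operand has 3 axes while the requested target shape has 5. The pure-Python helper below (infer_broadcast_to_shape is hypothetical, not MXNet's implementation) illustrates the rule as documented for broadcast_to: a 0 in the target shape keeps the input's size for that axis, every other input axis must be 1 or already equal to the target, and the ranks must match.

```python
def infer_broadcast_to_shape(in_shape, target):
    """Hypothetical illustration of the shape rule checked by MXNet's broadcast_to."""
    if len(in_shape) != len(target):
        raise ValueError(
            f"Operand of shape {list(in_shape)} cannot be broadcasted to {list(target)} "
            f"({len(in_shape)} vs. {len(target)} axes)"
        )
    out = []
    for size, want in zip(in_shape, target):
        if want == 0:
            out.append(size)        # 0 means "copy the input size for this axis"
        elif size in (1, want):
            out.append(want)        # size-1 axes are stretched; equal sizes pass through
        else:
            raise ValueError(f"cannot broadcast an axis of size {size} to {want}")
    return tuple(out)

# The 5D version of the operand is accepted ...
print(infer_broadcast_to_shape((1000, 2048, 1, 1, 1), (0, 0, 1, 0, 0)))  # (1000, 2048, 1, 1, 1)

# ... while the 3D operand from the traceback fails on the rank check alone.
try:
    infer_broadcast_to_shape((1000, 2048, 1), (0, 0, 1, 0, 0))
except ValueError as err:
    print(err)
```

Under this reading, some tensor that should carry five axes (N, C, D, H, W) reached the operator with only three.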

Any suggestions on how to fix this?

Carlo
