
I am trying to understand how PyTorch works and want to replicate training of a simple CNN on CIFAR-10. My CNTK script gets to 0.76 accuracy after 168 seconds of training (10 epochs), which is similar to my MXNet script (0.75 accuracy after 153 seconds).

However, my PyTorch script is lagging far behind at 0.71 accuracy and 354 seconds. I appreciate I will get some differences in accuracy due to stochastic weight initialisation, etc., but the difference across frameworks is much greater than the run-to-run difference within a single framework when initialising randomly.

The reasons I can think of:

  • MXNet and CNTK are initialised with Xavier/Glorot uniform; I am not sure how to do this in PyTorch, so perhaps the weights are initialised to 0
  • CNTK does gradient clipping by default; I am not sure whether PyTorch has an equivalent
  • Perhaps the bias is dropped in PyTorch by default
  • I use SGD with momentum; perhaps the PyTorch implementation of momentum is a bit different (see the sketch after this list)
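
On the last point, PyTorch's SGD documentation does note that its momentum formulation differs subtly from the Sutskever-style update some frameworks use: PyTorch computes v = mu * v + g and then p = p - lr * v, rather than folding the learning rate into the velocity. A minimal sketch of the optimiser setup (`model` and the hyper-parameter values are placeholders, not my actual settings):

import torch.optim as optim

# PyTorch momentum update: v = mu * v + g;  p = p - lr * v
# (some frameworks instead use: v = mu * v + lr * g;  p = p - v)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)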

Edit:

I have tried specifying the weight initialisation; however, it seems to have no big effect:

import numpy as np
import torch.nn as nn
from torch.nn import init

self.conv1 = nn.Conv2d(3, 50, kernel_size=3, padding=1)  # inside the model's __init__
init.xavier_uniform(self.conv1.weight, gain=np.sqrt(2.0))
init.constant(self.conv1.bias, 0)
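
(For reference, the same initialisation can be applied to every layer at once with Module.apply; the weights_init helper below is only an illustration, not part of my actual script:)

import numpy as np
import torch.nn as nn
from torch.nn import init

def weights_init(m):
    # Xavier/Glorot uniform for conv/linear weights, zero biases
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_uniform(m.weight, gain=np.sqrt(2.0))
        if m.bias is not None:
            init.constant(m.bias, 0)

model.apply(weights_init)  # `model` is a placeholder for the nn.Module instance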

1 Answer

Let me try to answer your first two questions:
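
On initialisation: PyTorch layers are not zero-initialised by default. nn.Conv2d and nn.Linear draw their weights from a uniform distribution scaled by fan-in and have bias=True by default, so your third guess is also unlikely; torch.nn.init (as in your edit) is the way to override the defaults.

On gradient clipping: PyTorch does not clip by default, but torch.nn.utils.clip_grad_norm (renamed clip_grad_norm_ in later releases) is the equivalent. A minimal sketch of where it goes in the training loop; model, loss, and the max norm of 5.0 are placeholders:

from torch.nn.utils import clip_grad_norm

# after computing the loss for a batch:
optimizer.zero_grad()
loss.backward()
clip_grad_norm(model.parameters(), max_norm=5.0)  # rescale gradients in place
optimizer.step()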

In addition, I am curious why you don't use torchvision.transforms, torch.utils.data.DataLoader, and torchvision.datasets.CIFAR10 to load and preprocess your data.
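
The standard pipeline looks roughly like this; the batch size and normalisation constants are illustrative choices, not taken from your script:

import torch
import torchvision
import torchvision.transforms as transforms

# convert to tensors and normalise each channel to roughly [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)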

There is a similar CIFAR-10 image-classification tutorial for PyTorch: http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py. Hope this helps.

  • Thanks a lot! I have tried playing around with weight initialisation and gradient clipping; however, I still can't get the same accuracy and training time as CNTK. In fact, when I compare it to [Chainer](https://github.com/ilkarman/Blog/blob/master/DL-Examples/Chainer_CIFAR.ipynb) the accuracy is lower by 7 percentage points and it is 2 minutes slower. I don't use the default dataset and data loader because I am creating this example across 9 different DL frameworks, and my goal is a series of scripts that are very similar and easy to compare (same data), so I take the data as given/exogenous. – Ilia Aug 22 '17 at 12:39