I have recently started working with PyTorch and noticed that I was not getting repeatable/deterministic results when evaluating a pre-trained model on new inputs.
I have boiled the problem down to the following minimal example, which shows that repeatedly applying the same simple convolution model does not yield identical results:
import numpy as np
import matplotlib.pyplot as plt
import torch

device = torch.device('cpu')

# function to get all the params from a pytorch model as one flat numpy array
def getParams(model):
    params = [p.detach().cpu().numpy().flatten() for p in model.parameters()]
    return np.hstack(params)

# set up a simple model (9 params)
testModule = torch.nn.Conv2d(1, 1, kernel_size=(3, 3), bias=False, stride=1, padding=1).double()
torch.nn.init.normal_(testModule.weight, mean=0, std=1)
testModule = testModule.eval()

# set up a dummy input
patch = torch.from_numpy(np.random.randn(1, 1, 80, 80).astype('double')).to(device)

# apply the model 100 times, recording the parameters and the output each time
testParams = []
testModuleOut = []
for ii in range(100):
    testParams.append(getParams(testModule))
    testModuleOut.append(testModule(patch).cpu().detach()[0, :, :, :].numpy())
testParams = np.stack(testParams)
testModuleOut = np.stack(testModuleOut)

# view the variation of the model parameters and the output values
plt.figure()
plt.plot(np.std(testParams, axis=0))
plt.xlabel('Parameter index')
plt.ylabel('Standard deviation over runs')
plt.figure()
plt.plot(np.std(testModuleOut, axis=0).ravel())
plt.xlabel('Output index')
plt.ylabel('Standard deviation over runs')
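For reference, the same variation can be summarized numerically rather than graphically (this just prints the maxima of the arrays plotted above):

# print the largest run-to-run standard deviation for the parameters
# and for the output (same data as the plots above)
print('max param SD :', np.std(testParams, axis=0).max())
print('max output SD:', np.std(testModuleOut, axis=0).max())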
If re-running the network were repeatable, I would expect both standard deviation graphs to be flat lines at SD = 0. Instead I get random-looking traces that change with every run of the script (sometimes the module parameters have SD = 0, but the network output never seems to).
What is the problem with my code? The SDs seem to be around machine precision, but why would repeatedly pulling the parameters from the module cause them to change at all? Wouldn't we just be pulling the exact same values from memory?
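To make my expectation concrete, here is a minimal sketch of the check I assumed would always hold: two back-to-back forward passes through the same frozen module should be bit-for-bit identical.

# sketch of the determinism I expected: run the identical forward
# pass twice and compare the results bit-for-bit
out1 = testModule(patch).detach().cpu().numpy()
out2 = testModule(patch).detach().cpu().numpy()
print(np.array_equal(out1, out2))  # I expected this to always print True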