
I have a generic network with no random elements in its structure (e.g. no dropout), so that if I forward a given image through the network, zero the gradient, and repeat the forward pass with the same input, I get the same result (same gradient vector, same output, etc.). Now let's say that we have a batch of N elements (data, label) and I perform the following experiment:

  1. Forward the whole batch and compute the loss (using reduction='sum' in my criterion), call backward() to generate the corresponding gradient, and save it in a separate object (that we'll refer to as Batch_Grad):
output = model(data)
loss = criterion(output, torch.reshape(label, (-1,)))
loss.backward()

# store a copy of each parameter's gradient
Batch_Grad = []
for p in model.parameters():
    Batch_Grad.append(p.grad.clone())
  2. Reset the gradient:
    optimizer.zero_grad()
  3. Repeat point 1, feeding the batch's elements one by one, and after each backward() collect the corresponding element's gradient (resetting the gradient each time afterwards):
    per_element_grads = []
    for i in range(0, len(label)):
        # procedure of point 1. applied to the single input data[i]
        loss = criterion(model(data[i:i+1]), torch.reshape(label[i:i+1], (-1,)))
        loss.backward()
        per_element_grads.append([p.grad.clone() for p in model.parameters()])
        optimizer.zero_grad()
  4. Sum together the gradient vectors of the previous point, corresponding to each element of the given batch, into a single object (that we'll refer to as Single_Grad; a sketch of this summation and the comparison follows the list).

  5. Compare the objects of points 4 and 1 (Batch_Grad and Single_Grad).
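
A sketch of points 4 and 5, assuming the per_element_grads list collected in point 3:

Single_Grad = [torch.zeros_like(g) for g in per_element_grads[0]]
for elem_grads in per_element_grads:
    for s, g in zip(Single_Grad, elem_grads):
        s += g  # accumulate the per-element gradients in place

# compare the two objects parameter by parameter
for bg, sg in zip(Batch_Grad, Single_Grad):
    print(torch.allclose(bg, sg))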

Following the above procedure I find that the tensors from points 1 and 4 (Batch_Grad and Single_Grad) are equal only if the batch size N is 1; they differ for N > 1.

With the method of points 3 and 4 I am manually summing the gradients associated with single-image propagation (which, as pointed out in the comment above, equal the ones computed automatically when N = 1). Since the automatic batched approach (point 1) is also expected to perform the same sum: why do I observe this difference?

user1172131

2 Answers


The difference you are trying to work out here is between mini-batch gradient descent and iterative updates at each training sample.

You can refer to this Wikipedia article for some background: Stochastic_gradient_descent#Iterative_method

In the mini-batch method (your point 1), you update the parameters only after you have computed the loss over the whole batch of N samples. This means you use the same model weights to compute the prediction loss for all N samples, since the next update happens only afterwards.

In contrast, with single-sample updates you keep updating the model parameters after each sample, producing slightly different loss values. These individual differences accumulate over the batch, yielding the difference you observe for the full batch of N samples.
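
A minimal sketch of the two schemes (with a toy linear model and arbitrary data, just for illustration; the model, data, and hyperparameters here are not from the question):

import torch

model = torch.nn.Linear(4, 2)
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(8, 4)
label = torch.randint(0, 2, (8,))

# mini-batch: one parameter update after the loss over all N samples
optimizer.zero_grad()
criterion(model(data), label).backward()
optimizer.step()

# iterative: one update after every sample, so each sample is
# evaluated with weights already modified by the previous samples
for i in range(len(label)):
    optimizer.zero_grad()
    criterion(model(data[i:i+1]), label[i:i+1]).backward()
    optimizer.step()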

jdsurya
  • Thank you for your answer! I'm using `torch.optim.SGD`; does it implement iterative updates? – user1172131 Dec 09 '21 at 23:11
  • Yes, of course; all training normally happens by iterating over batches for each epoch, regardless of the optimizer. If you need to update the weights after each sample, pick batch_size=1 (as you had already figured out). – jdsurya Dec 10 '21 at 01:00
  • On the other hand, the pipeline of a PyTorch learning step is: forward the whole batch (without touching the weights), compute the gradient using autograd, and **finally** update the weights (with the `optimizer.step()` call). So it seems very strange to me that a torch optimizer would perform an iterative update, because, as said, the weights are not touched during the batch forward but only after it. – user1172131 Dec 10 '21 at 08:18

The difference highlighted in jdsurya's answer is, in general, a good point to be aware of and to pay attention to.

On the other hand, in my specific case (the question above) this is not the source of the discrepancy, because in PyTorch the computation of the gradient and the update of the weights are performed in two distinct phases (see the loss.backward() and optimizer.step() methods for more details). In both cases of my experiment I only compute the gradient vector, without ever touching the weights.

As pointed out here, the difference is most likely due to a different order of operations combined with limited floating-point precision: summing the same per-sample gradients in a different order can produce slightly different results.
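
To see the effect in isolation (a minimal sketch; the tensor here is just arbitrary random data), compare summing a float32 tensor all at once against accumulating it element by element:

import torch

x = torch.rand(10000, dtype=torch.float32)

s_batch = x.sum()  # one reduction over the whole tensor
s_iter = torch.zeros((), dtype=torch.float32)
for v in x:        # strictly sequential accumulation
    s_iter += v

# the two results typically differ in the last bits of precision
print(s_batch.item(), s_iter.item(), (s_batch - s_iter).item())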

user1172131