Looking to efficiently calculate the per-channel mean and standard deviation (STD) over a batch.
Details:
- batch size: 128
- images: 32x32
- 3 channels (RGB)
So each batch is of size [128, 32, 32, 3].
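For a single batch in this channels-last layout, NumPy can reduce over the batch, height and width axes (0, 1, 2) in one call, leaving one value per channel. This is a minimal sketch that assumes the batch is a NumPy array. The random data below is a hypothetical stand-in for a real batch.

```python
import numpy as np

# Hypothetical stand-in for one real batch: shape [128, 32, 32, 3]
rng = np.random.default_rng(0)
batch = rng.random((128, 32, 32, 3))

# Reduce over batch, height and width (axes 0, 1, 2) and keep the channel axis
batch_mean = batch.mean(axis=(0, 1, 2))  # shape (3,): (meanR, meanG, meanB)
batch_std = batch.std(axis=(0, 1, 2))    # shape (3,): (stdR, stdG, stdB)
```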
There are many batches (the naive method takes ~4 min over all of them).
I would like to output 2 arrays: (meanR, meanG, meanB) and (stdR, stdG, stdB).
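Averaging the per-batch standard deviations does not give the standard deviation of the whole dataset. One way to get the dataset-wide values in a single pass is to accumulate, per channel, the pixel count, the sum, and the sum of squares. The sketch below does that, assuming each batch arrives as a NumPy array shaped [N, H, W, C]. `channel_stats` is a hypothetical helper name. It computes the population STD (NumPy's default `ddof=0`) with the simple E[x²] − E[x]² formula. If the variance is very small relative to the mean, a Welford/Chan-style update is numerically safer.

```python
import numpy as np


def channel_stats(batches):
    """Per-channel mean and std over every pixel of every batch.

    `batches` is any iterable of arrays shaped [N, H, W, C] (channels last).
    Sums are accumulated in float64 so they stay accurate over many batches.
    """
    total = None     # running per-channel sum of x
    total_sq = None  # running per-channel sum of x**2
    count = 0        # number of pixels seen per channel
    for batch in batches:
        b = np.asarray(batch, dtype=np.float64)
        s = b.sum(axis=(0, 1, 2))              # shape (C,)
        sq = np.square(b).sum(axis=(0, 1, 2))  # shape (C,)
        if total is None:
            total, total_sq = s, sq
        else:
            total += s
            total_sq += sq
        count += b.shape[0] * b.shape[1] * b.shape[2]
    mean = total / count
    # Population variance E[x^2] - E[x]^2, clipped at 0 to absorb rounding error
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


# Usage with hypothetical random batches in place of the real data
rng = np.random.default_rng(0)
batches = [rng.random((128, 32, 32, 3)) for _ in range(5)]
mean, std = channel_stats(batches)  # each has shape (3,)
```

Because `batches` can be any iterable, a generator that loads one batch at a time also works, so the whole dataset never has to sit in memory.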
(Also, if there is an efficient way to perform arithmetic operations on the batches after calculating these values, that would be helpful. For example, subtracting the mean of the whole dataset from each image.)
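For the follow-up, NumPy broadcasting handles per-channel arithmetic without a Python loop. An array of shape (3,) lines up with the last (channel) axis of a [128, 32, 32, 3] batch, so each channel is shifted and scaled by its own statistic. This sketch uses placeholder mean/STD values that are purely hypothetical. In practice they would come from a pass over all batches.

```python
import numpy as np

# Hypothetical placeholder statistics, one value per channel (R, G, B)
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.25, 0.25, 0.25])

rng = np.random.default_rng(0)
batch = rng.random((128, 32, 32, 3))  # stand-in for one real batch

# (128, 32, 32, 3) op (3,) broadcasts along the trailing channel axis
centered = batch - mean             # subtract the dataset mean from every image
normalized = (batch - mean) / std   # standardize each channel
```

If memory matters, `batch -= mean` and `batch /= std` do the same work in place on a float array. A `uint8` image batch needs converting to float first, because NumPy will not cast the float result back into `uint8` in place.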