
I am doing research on batch normalization and need to make some modifications to PyTorch's BN code. I dug into the PyTorch code and got stuck at `torch.nn.functional.batch_norm`, which calls `torch.batch_norm`.

The problem is that I can't find where `torch.batch_norm` is defined anywhere in the torch library. Is there any way to find the source code of this built-in function so I can re-implement it? Thanks!

StoneFree

1 Answer


It's there, but it isn't defined in Python. It's implemented in C++, under the `aten/` directory.
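You can confirm from Python that `torch.batch_norm` has no Python source to find. The sketch below assumes a standard PyTorch install; it simply asks `inspect` about the function, which is bound from PyTorch's C++ extension and therefore shows up as a builtin.

```python
import inspect

import torch

# torch.batch_norm is bound from PyTorch's C++ extension, so Python sees it
# as a builtin rather than as a function with Python source code.
print(type(torch.batch_norm).__name__)
print(inspect.isbuiltin(torch.batch_norm))

try:
    inspect.getsource(torch.batch_norm)
except TypeError as err:
    # inspect cannot retrieve source for builtins; the real code lives in ATen (C++).
    print("no Python source:", err)
```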

For CPU, the implementation (one of several; which one runs depends on whether the input is contiguous) is here: https://github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization.cpp#L61-L126

For CUDA, the cuDNN-based implementation is here: https://github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143
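If the goal is to experiment with modified BN, it may be easier to skip the C++ entirely and re-implement the forward pass with ordinary tensor ops, which autograd can differentiate for you. Below is a minimal sketch, assuming input of shape `(N, C, *)` and a float `momentum` (the `momentum=None` cumulative-average mode is not handled). `batch_norm_reference` is a hypothetical name; this is a re-derivation of the same math, not the ATen code, and it will be slower than the fused kernels. It normalizes with the biased batch variance but updates `running_var` with the unbiased one, which is the behavior it is checked against.

```python
import torch
import torch.nn.functional as F


def batch_norm_reference(x, running_mean, running_var, weight=None, bias=None,
                         training=False, momentum=0.1, eps=1e-5):
    """Sketch of the batch-norm forward pass using plain tensor ops.

    x has shape (N, C, *). Statistics are computed per channel (dim 1) over
    all other dims. When training=True, running_mean / running_var are
    updated in place, like F.batch_norm does.
    """
    # Reduce over every dimension except the channel dimension.
    dims = [d for d in range(x.dim()) if d != 1]
    # Shape for broadcasting per-channel vectors against x: (1, C, 1, 1, ...).
    shape = [1, x.size(1)] + [1] * (x.dim() - 2)

    if training:
        mean = x.mean(dim=dims)
        # Biased variance (divide by n) is used to normalize the batch.
        var = x.var(dim=dims, unbiased=False)
        if running_mean is not None:
            n = x.numel() // x.size(1)  # elements per channel
            with torch.no_grad():
                # Running stats: exponential moving average with unbiased variance.
                running_mean.mul_(1 - momentum).add_(momentum * mean)
                running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        # Eval mode: normalize with the stored running statistics.
        mean, var = running_mean, running_var

    out = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
    if weight is not None:
        out = out * weight.view(shape)
    if bias is not None:
        out = out + bias.view(shape)
    return out
```

Comparing its output and running statistics against `F.batch_norm` on random data is a quick way to check that a modified version still matches the baseline before you change anything.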

JoshVarty
  • Thanks, that solved my problem perfectly. Though I'm still curious how you knew precisely where the code is. Is there some kind of mapping between them? – StoneFree Oct 03 '19 at 01:14
  • There's no mapping I know of. I just searched the repository on GitHub for `batch_norm` and looked through the results manually. I kept an eye out for the `/aten` folder since that's where most of the C++ source lives. – JoshVarty Oct 03 '19 at 02:17
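The same manual search can be scripted against a local clone of the repository. This is a minimal sketch, assuming a checkout whose root is passed as `repo_root` and that the relevant files use the listed extensions; `find_in_aten` is a hypothetical helper, not part of PyTorch.

```python
import os


def find_in_aten(repo_root, needle="batch_norm",
                 exts=(".cpp", ".cu", ".h", ".yaml")):
    """Return sorted paths (relative to repo_root) under aten/ that mention needle."""
    aten = os.path.join(repo_root, "aten")
    hits = []
    for dirpath, _dirnames, filenames in os.walk(aten):
        for name in filenames:
            # Only look at C++/CUDA sources, headers and YAML declarations.
            if not name.endswith(exts):
                continue
            path = os.path.join(dirpath, name)
            with open(path, encoding="utf-8", errors="ignore") as f:
                if needle in f.read():
                    hits.append(os.path.relpath(path, repo_root))
    return sorted(hits)
```

For example, `for p in find_in_aten("pytorch"): print(p)` lists candidate files in a clone named `pytorch`. One file likely to show up is `aten/src/ATen/native/native_functions.yaml`, where ATen operators are declared.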