I am currently using SageMaker to train BERT and am trying to reduce the training time. I use PyTorch and Hugging Face on an AWS g4dn.12xlarge instance (a single node with 4 GPUs).
However, when I run parallel training across the GPUs, the speedup is far from linear. I'm looking for hints on distributed training that could improve BERT training time in SageMaker.
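For context, here is roughly the kind of launch configuration I mean. It is only a sketch: the entry point script, role, container versions, hyperparameters, and S3 path are placeholders rather than my exact setup, and the `distribution` setting is the part I am unsure about.

```python
# Minimal sketch of launching a Hugging Face training job on SageMaker.
# Names, versions, and paths below are placeholders, not my actual configuration.
import sagemaker
from sagemaker.huggingface import HuggingFace

role = sagemaker.get_execution_role()

estimator = HuggingFace(
    entry_point="train.py",            # fine-tuning script using transformers.Trainer
    instance_type="ml.g4dn.12xlarge",  # 4x NVIDIA T4 GPUs, single node
    instance_count=1,
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
    role=role,
    hyperparameters={
        "model_name_or_path": "bert-base-uncased",
        "per_device_train_batch_size": 32,
        "epochs": 3,
    },
    # Intended to enable data parallelism across the 4 GPUs; whether this
    # (PyTorch DDP) or another mechanism is the right choice here is part
    # of what I am asking about.
    distribution={"pytorchddp": {"enabled": True}},
)

estimator.fit({"train": "s3://my-bucket/bert-train-data"})
```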