acc_loss = model.train_on_batch([support, smask, query], qmask)
This is part of some Keras code that I want to convert to PyTorch, but I'm stuck trying to find out whether PyTorch has a function equivalent to train_on_batch(),
or whether a custom function in PyTorch could serve the same purpose.
support, smask, query, qmask  # all rank-4 tensors of shape (batch_size, channels, height, width)
The model is a custom VGG16 encoder with a decoder for segmentation.
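As far as I know, core PyTorch has no built-in train_on_batch(). A single training step is usually written by hand: forward pass, loss, backward, optimizer step. The sketch below is a minimal version under several assumptions that are not in the question. The masks are taken to be binary, the loss is nn.BCEWithLogitsLoss, and the model is assumed to output one logit channel. "Accuracy" is computed as pixel accuracy, because the metrics the Keras model was compiled with are unknown. `train_on_batch` and `ToySegNet` are hypothetical names. `ToySegNet` is only a placeholder for the real VGG16 encoder-decoder.

```python
import torch
import torch.nn as nn


def train_on_batch(model, optimizer, loss_fn, inputs, target):
    """Hypothetical helper: one optimisation step on a single batch,
    loosely mirroring Keras' Model.train_on_batch.

    Returns (loss, pixel_accuracy) as Python floats. Pixel accuracy assumes
    a binary target mask and single-channel logits (an assumption).
    """
    model.train()                   # enable training behaviour (dropout, batch-norm updates)
    optimizer.zero_grad()           # clear gradients left over from the previous step
    logits = model(*inputs)         # forward pass with (support, smask, query)
    loss = loss_fn(logits, target)  # compare the prediction with the query mask
    loss.backward()                 # back-propagate
    optimizer.step()                # update the weights
    with torch.no_grad():           # metric only, no gradient tracking
        preds = (torch.sigmoid(logits) > 0.5).float()
        acc = (preds == target).float().mean()
    return loss.item(), acc.item()


class ToySegNet(nn.Module):
    """Hypothetical stand-in for the VGG16 encoder-decoder. It simply
    concatenates support, support mask and query along the channel axis."""

    def __init__(self, in_channels=3, mask_channels=1):
        super().__init__()
        c = 2 * in_channels + mask_channels
        self.encoder = nn.Sequential(nn.Conv2d(c, 16, 3, padding=1), nn.ReLU())
        self.decoder = nn.Conv2d(16, 1, 1)  # one logit per pixel

    def forward(self, support, smask, query):
        x = torch.cat([support, smask, query], dim=1)  # NCHW layout, as in the question
        return self.decoder(self.encoder(x))


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ToySegNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()
    support = torch.randn(4, 3, 32, 32)
    smask = torch.randint(0, 2, (4, 1, 32, 32)).float()
    query = torch.randn(4, 3, 32, 32)
    qmask = torch.randint(0, 2, (4, 1, 32, 32)).float()
    loss, acc = train_on_batch(model, optimizer, loss_fn, (support, smask, query), qmask)
    print(f"loss={loss:.4f} acc={acc:.4f}")
```

The call `train_on_batch(model, optimizer, loss_fn, (support, smask, query), qmask)` then plays the role of the Keras line. Unlike Keras, the optimizer and loss are passed explicitly because a PyTorch module has no compile() step. Returning `.item()` values detaches them from the autograd graph, so storing them across iterations does not keep graphs alive.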