 acc_loss = model.train_on_batch([support, smask, query], qmask)

This is part of some Keras code that I want to convert to PyTorch, but I can't work out whether PyTorch has a function equivalent to train_on_batch(), or whether a custom function can serve the same purpose.

support, smask, query and qmask are all rank-4 tensors of shape (batch_size, channel, height, width).

The model is a custom vgg16 encoder and a decoder for segmentation.
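For context, as far as I understand it, train_on_batch() in Keras runs a single gradient update on one batch and returns the loss (plus any metrics the model was compiled with). I'm not aware of a built-in PyTorch call with that name; the usual pattern seems to be writing the step by hand. Below is a minimal sketch of what I think the equivalent looks like. TinySegNet is only a hypothetical stand-in for the real VGG16 encoder/decoder, and the mask channel count, the Adam optimizer and the BCEWithLogitsLoss loss are assumptions for illustration, not what the original model uses.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for the VGG16 encoder/decoder in the question."""

    def __init__(self, in_channels=3):
        super().__init__()
        # Support image, support mask and query image are concatenated
        # along the channel axis (an assumption made for this sketch).
        self.net = nn.Sequential(
            nn.Conv2d(2 * in_channels + 1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 1, kernel_size=1),  # one logit per pixel
        )

    def forward(self, support, smask, query):
        return self.net(torch.cat([support, smask, query], dim=1))


def train_on_batch(model, optimizer, loss_fn, inputs, target):
    """Run one gradient update on a single batch and return the loss as a float."""
    model.train()                  # enable training behaviour (dropout, batch norm)
    optimizer.zero_grad()          # clear gradients left over from the previous step
    output = model(*inputs)        # forward pass
    loss = loss_fn(output, target)
    loss.backward()                # backpropagate
    optimizer.step()               # update the weights
    return loss.item()


torch.manual_seed(0)
model = TinySegNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

# Dummy batch with shape (batch_size, channel, height, width)
support = torch.rand(2, 3, 16, 16)
smask = torch.randint(0, 2, (2, 1, 16, 16)).float()
query = torch.rand(2, 3, 16, 16)
qmask = torch.randint(0, 2, (2, 1, 16, 16)).float()

loss = train_on_batch(model, optimizer, loss_fn, [support, smask, query], qmask)
print(f"loss after one batch: {loss:.4f}")
```

This only returns the loss. If the Keras model was compiled with metrics (the name acc_loss suggests it might have been), those would have to be computed from the output separately.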

Codess

0 Answers