
In k-fold cross-validation, why do we need to reset the weights after each fold? We use this function:

    def reset_weights(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            m.reset_parameters()

We reset the weights of the model so that each cross-validation fold starts from a random initial state, rather than learning from the previous folds.

Why is that important? I would think that if we don't do that, it would be better: the model would learn from all folds and update its parameters from all of them, instead of each fold on its own.
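For context, here is a minimal sketch (the model is illustrative, not from the original post) of how such a reset is typically invoked: nn.Module.apply walks every submodule recursively, so each Conv2d and Linear layer is re-initialized.

    import torch.nn as nn

    def reset_weights(m):
        # Re-initialize Conv2d and Linear layers to fresh random weights.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            m.reset_parameters()

    model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 2))  # illustrative
    model.apply(reset_weights)  # recursively resets every Linear/Conv2d submodule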

adel_hany1
You want each fold to be independent of the others, so you get an idea of how consistent your algorithm is when the training set changes. If you allow folds to influence each other, you lose this robustness information. You can build a system that trains on different training sets and learns from each of them, but then it's not cross-validation you are doing, it's just training. Cross-validation, as the name suggests, aims to validate your model by using different train/validation combinations. –  Mar 25 '22 at 15:55

1 Answer


K-fold cross-validation is meant to check whether the model's performance is consistent and robust across different subsamples of the train and test data, and to tune hyperparameters in a less biased way.

If the model performs well, with low variance across numerous (usually 5 or 10) folds of train and test data, it means that its performance is not tied to some particular subsample of the data.

https://en.wikipedia.org/wiki/Cross-validation_(statistics)
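Here is a minimal sketch of this idea (the data, model, and training loop are hypothetical; KFold from scikit-learn is assumed for the splitting): train from scratch on each fold and look at the spread of the validation scores.

    import numpy as np
    import torch
    import torch.nn as nn
    from sklearn.model_selection import KFold

    # Hypothetical toy data: 200 samples, 20 features, binary labels.
    X = torch.randn(200, 20)
    y = torch.randint(0, 2, (200,))

    def reset_weights(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            m.reset_parameters()

    model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 2))
    scores = []

    for train_idx, val_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
        model.apply(reset_weights)  # fresh start, so the folds stay independent
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)
        for _ in range(20):  # short training loop per fold
            opt.zero_grad()
            loss = nn.functional.cross_entropy(model(X[train_idx]), y[train_idx])
            loss.backward()
            opt.step()
        with torch.no_grad():
            acc = (model(X[val_idx]).argmax(1) == y[val_idx]).float().mean().item()
        scores.append(acc)

    # A low standard deviation suggests performance is not tied to one subsample.
    print(f"mean accuracy: {np.mean(scores):.3f} +/- {np.std(scores):.3f}")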

After validating the model, you can train it on the whole dataset, without splitting it, to improve performance.
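Continuing the hypothetical sketch above, that final step is just one more reset followed by training on all of the data, with no held-out split:

    model.apply(reset_weights)  # start from scratch one last time
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(20):
        opt.zero_grad()
        loss = nn.functional.cross_entropy(model(X), y)  # full dataset, no split
        loss.backward()
        opt.step()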

But this approach alone can't tell you whether your model has overfitted, so also take note of CNN regularization and validation methods.

https://www.analyticsvidhya.com/blog/2020/09/overfitting-in-cnn-show-to-treat-overfitting-in-convolutional-neural-networks/

Guinther Kovalski