Suppose in PyTorch I have `model1` and `model2`, which have the same architecture. They were further trained on the same data, or one model is an earlier version of the other, but that is not technically relevant to the question. Now I want to set the weights of `model` to be the average of the weights of `model1` and `model2`. How would I do that in PyTorch?
- Why would you do that? The mean of weights doesn't mean anything at all. – Dr. Snoopy Feb 01 '18 at 10:20
- For example, I might want to do Polyak averaging. – patapouf_ai Feb 01 '18 at 10:25
- Whatever transformation you want to do on the weights won't produce any meaningful value that will have high accuracy or low loss. – Dr. Snoopy Feb 01 '18 at 10:27
- As mentioned, I doubt its worth, but see if [this](https://discuss.pytorch.org/t/copy-weights-only-from-a-networks-parameters/5841/2) is of any help. You could grab the parameters, transform them, and load them back, but make sure the dimensions match. – Littleone Feb 01 '18 at 13:22
- @Littleone thank you! I will try that :) – patapouf_ai Feb 01 '18 at 14:56
- @MatiasValdenegro you should check out Polyak averaging; it is a well-established technique in deep learning. Furthermore, interpolating between minima is a well-established way of visualizing loss landscapes, so there are plenty of uses for this. But in any case, that is irrelevant to the question. – patapouf_ai Feb 01 '18 at 14:57
- @patapouf_ai Could you provide a source? Googling doesn't show many relations with Deep Learning. – Dr. Snoopy Feb 01 '18 at 15:52
- @Littleone Thanks! It works :) If you would like to provide an answer, I will accept it. – patapouf_ai Feb 01 '18 at 15:56
- @MatiasValdenegro see for example https://arxiv.org/abs/1412.6651 or, for visualization purposes, https://openreview.net/forum?id=HkmaTz-0W&noteId=HkmaTz-0W – patapouf_ai Feb 01 '18 at 15:58
- What **kind** of model do you have? For instance, if you have a typical CNN, then averaging the weights is highly counterproductive: the filters are almost guaranteed to train to different purposes, such that averaging would give you a *bad* result. – Prune Feb 01 '18 at 20:01
- @Dr.Snoopy `"The mean of weights doesn't mean anything at all."` This is **blatantly wrong**. Federated learning uses averaging, and so do soft updates of target networks in virtually all SOTA deep RL algorithms. Stochastic Weight Averaging is also a technique that may improve convergence, and in addition it provides another route to Bayesian deep learning. Not to mention countless other cases where averages of weights resemble familiar deep learning structures (e.g. ensembles, dropout, regularization). – kyriakosSt Dec 04 '22 at 19:09
- @kyriakosSt It was true when I wrote that comment, and I think you are confusing things: the averaging in many distributed algorithms is gradient averaging, not weight averaging (A3C for example), while ensembles or dropout do not average weights, they average activations; there is a huge difference there. – Dr. Snoopy Dec 04 '22 at 20:17
- @Dr.Snoopy I am not referring to gradient averaging (like A3C). Every family I mentioned above uses weight averages. For the last two cases, I mentioned specifically that weight averages "resemble" those operations. In fact, through arguments like the ones presented in the Bayesian MC Dropout paper, it is easy to show that "activation averaging" and weight averaging are very closely related. Regardless of that, weight averaging is still a reasonable operation to investigate, as it was in 2018, when a number of the above domains had already been discovered. – kyriakosSt Dec 04 '22 at 21:42
- @kyriakosSt I do research on BDL, and I know all of these techniques. SWA is a Bayesian NN, not a standard NN. Target networks in DQN do not use weight averaging, so I am not sure which exact technique you are talking about. I am very doubtful since you did not point clearly to a specific algorithm. – Dr. Snoopy Dec 04 '22 at 21:49
- @Dr.Snoopy The initial vanilla DQN did not include soft target updates, but [DDPG did](https://arxiv.org/pdf/1509.02971.pdf), and numerous implementations of DQN have ever since [(1)](https://unnatsingh.medium.com/deep-q-network-with-pytorch-d1ca6f40bfda), [(2)](https://greentec.github.io/reinforcement-learning-third-en/#fn:2), as have follow-up papers [(3)](https://deepai.org/publication/t-soft-update-of-target-network-for-deep-reinforcement-learning). Since you work on BDL, I don't see why we are still arguing about whether weight averaging is meaningful or not. – kyriakosSt Dec 04 '22 at 22:20
- @kyriakosSt Because you make implicit claims. DDPG uses tau=0.001, and you call this the same as weight averaging? It is not; the small value of tau should tell you something. Weight averaging would be tau=0.5 (a value 500 times larger). – Dr. Snoopy Dec 04 '22 at 22:27
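For context on the soft-update discussion above: it is the same weight interpolation as in the answer below, just with a very small coefficient. A minimal sketch, assuming two hypothetical networks `online_net` and `target_net` of identical architecture (names chosen here purely for illustration):

```python
import torch

tau = 0.001  # DDPG-style soft-update coefficient; tau = 0.5 would be a plain average

# Move each target parameter a small step towards the corresponding online parameter
with torch.no_grad():
    for target_param, online_param in zip(target_net.parameters(), online_net.parameters()):
        target_param.mul_(1.0 - tau).add_(tau * online_param)
```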
1 Answer
```python
beta = 0.5  # interpolation parameter; 0.5 gives a plain average

params1 = model1.named_parameters()
params2 = model2.named_parameters()

dict_params2 = dict(params2)

# Interpolate model1's parameters into model2's parameter dict, in place
for name1, param1 in params1:
    if name1 in dict_params2:
        dict_params2[name1].data.copy_(beta * param1.data + (1 - beta) * dict_params2[name1].data)

model.load_state_dict(dict_params2)
```
Taken from the PyTorch forums. You can grab the parameters, transform them, and load them back, but make sure the dimensions match.
I would also be really interested in hearing about your findings with this.
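If you prefer not to modify `model2`'s parameters in place, a minimal sketch of the same idea using `state_dict()` (which also covers buffers such as BatchNorm running statistics, unlike `named_parameters()`) might look like the following, assuming `model`, `model1`, and `model2` share exactly the same architecture:

```python
beta = 0.5  # interpolation parameter; 0.5 gives a plain average

sd1 = model1.state_dict()
sd2 = model2.state_dict()

# Element-wise interpolation of every tensor in the state dicts
averaged = {key: beta * sd1[key] + (1 - beta) * sd2[key] for key in sd1}

model.load_state_dict(averaged)
```

Both models must have identical state-dict keys and tensor shapes for this to work.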

Littleone
- Thank you :)! Typically on Stack Overflow, when you link to an outside source, you also want to copy the relevant information into your answer, because the link might eventually become a dead link or the information there might change. I upvoted, but if you could provide a complete answer by copying in the relevant parts of the page you link to, I will be able to accept the answer. – patapouf_ai Feb 02 '18 at 07:25