
I am using the Sequential model in Keras. I would like to check the model's weights after every epoch. Could you please guide me on how to do so?

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense, Activation

model = Sequential()
model.add(Embedding(max_features, 128, dropout=0.2))
model.add(LSTM(128, dropout_W=0.2, dropout_U=0.2))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=5, validation_data=(X_test, y_test))

Thanks in advance.

Kiran Baktha

1 Answer


What you are looking for is a callback. A callback is an object that Keras calls at specific points during training: at the beginning or end of each batch, of each epoch, or of the whole training run. See the Keras callbacks documentation for details and for the list of built-in callbacks.
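To make the mechanism concrete, here is a minimal sketch of a callback written as a subclass of keras.callbacks.Callback whose on_epoch_end hook stores a copy of all the model's weights after every epoch. It assumes a recent Keras (2.x or 3) rather than the Keras 1 arguments used in the question; the class name WeightHistory, the tiny Dense model and the random data are hypothetical and only there to exercise the callback.

```python
import numpy as np
import keras


class WeightHistory(keras.callbacks.Callback):
    """Hypothetical callback: stores a copy of all weights after each epoch."""

    def __init__(self):
        super().__init__()
        self.snapshots = []  # one entry per epoch: a list of NumPy arrays

    def on_epoch_end(self, epoch, logs=None):
        # Keras sets self.model before training starts.
        self.snapshots.append([w.copy() for w in self.model.get_weights()])


# Toy data and model, only to exercise the callback.
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = (x.sum(axis=1) > 2).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy", optimizer="adam")

weight_history = WeightHistory()
model.fit(x, y, epochs=3, batch_size=8, verbose=0, callbacks=[weight_history])
```

After training, weight_history.snapshots[e] holds the kernel and bias of the Dense layer as they were at the end of epoch e.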

What you want is a custom callback, which you can create with a LambdaCallback object.

from keras.callbacks import LambdaCallback
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense, Activation

model = Sequential()
model.add(Embedding(max_features, 128, dropout=0.2))
model.add(LSTM(128, dropout_W=0.2, dropout_U=0.2))
model.add(Dense(1))
model.add(Activation('sigmoid'))

# Keras calls on_epoch_end(epoch, logs) after every epoch;
# model.layers[0] is the Embedding layer.
print_weights = LambdaCallback(on_epoch_end=lambda epoch, logs: print(model.layers[0].get_weights()))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train,
          y_train,
          batch_size=batch_size,
          nb_epoch=5,
          validation_data=(X_test, y_test),
          callbacks=[print_weights])

The code above prints your embedding weights, model.layers[0].get_weights(), at the end of every epoch. It is up to you what to do with them: print them in a more readable form, dump them into a pickle file, and so on.
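For the pickle option, a minimal sketch, assuming a recent Keras (2.x or 3): a LambdaCallback writes model.get_weights() (a list of NumPy arrays) to one pickle file per epoch. The helper dump_weights, the temporary output directory and the weights_epoch_NNN.pkl file names are hypothetical choices for illustration, as are the toy model and data.

```python
import os
import pickle
import tempfile

import numpy as np
import keras

# Hypothetical output directory; replace with a path of your choice.
out_dir = tempfile.mkdtemp()


def dump_weights(epoch, logs):
    # Keras passes the 0-based epoch index and the metrics dict.
    path = os.path.join(out_dir, f"weights_epoch_{epoch:03d}.pkl")
    with open(path, "wb") as f:
        pickle.dump(model.get_weights(), f)  # list of NumPy arrays


# Toy data and model, only to exercise the callback.
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = (x.sum(axis=1) > 2).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy", optimizer="adam")
model.fit(x, y, epochs=2, batch_size=8, verbose=0,
          callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=dump_weights)])
```

Each file can later be read back with pickle.load and passed to model.set_weights to restore the model as it was after that epoch.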

Hope this helps

Nassim Ben
  • Thank you for your answer, but if I want to save all the weights into a list rather than print them out, how can I do that? I have tried logs["weights"].append(model.layers[0].get_weights()) but it doesn't work – jimmy15923 Jun 05 '17 at 05:10
  • @jimmy15923 model.layers[0].get_weights() only returns the first layer's weights, which would be nothing considering it's for the input. You need to iterate through all the layers. – Andy Wei May 22 '18 at 01:27
  • typo: you should print the epoch, not the batch: print_weights = LambdaCallback(on_epoch_end=lambda epoch, ... – MAltakrori Oct 04 '18 at 14:24
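Regarding the comments on saving the weights into a list and iterating through all the layers: the logs dict passed to the callback carries metric values such as the loss, so it is not a place to keep weights. A simpler option is to append to an ordinary Python list from the callback. The sketch below assumes a recent Keras (2.x or 3); the list name epoch_weights and the small Embedding model with random token ids are hypothetical. Layers without weights, such as the pooling layer here, contribute an empty list.

```python
import numpy as np
import keras

# Toy token ids and labels, only to exercise the callback.
rng = np.random.default_rng(0)
x = rng.integers(0, 50, size=(32, 10)).astype("int32")
y = rng.integers(0, 2, size=(32,)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(10,), dtype="int32"),
    keras.layers.Embedding(input_dim=50, output_dim=8),
    keras.layers.GlobalAveragePooling1D(),  # has no weights
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

# epoch_weights[e][i] holds copies of model.layers[i]'s weights after epoch e.
epoch_weights = []
collect_weights = keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: epoch_weights.append(
        [[w.copy() for w in layer.get_weights()] for layer in model.layers]
    )
)

model.fit(x, y, epochs=2, batch_size=8, verbose=0, callbacks=[collect_weights])
```

Copying each array keeps the stored snapshots independent of later training steps.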