>     import tensorflow as tf
>     
>     class MyMetric(tf.keras.callbacks.Callback):
>         def on_epoch_end(self, epoch, logs=None):
>             # How can I access X_train and X_val here?
>             pass
> 
>     ...
>     model.fit(X_train, y_train, batch_size=32, epochs=10,
>               validation_data=(X_val, y_val), shuffle=True,
>               callbacks=[MyMetric()])

I am trying to implement a custom metric in TensorFlow 2.0 using a callback. Inside the on_epoch_end method I need to access the training and validation data (the full datasets, not individual batches) exactly as they were passed to the fit method. Is there any way to do this? Thanks!

tim
tudor.a
  • Not easily. Batch inputs and outputs are provided as parameters to the on_batch_begin and on_batch_end callbacks. Could you store them manually from there? – gerwin Dec 03 '19 at 22:41

2 Answers


Accept the training and test datasets as initialisation arguments to your custom callback class, then use them in your on_epoch_end method.

Something like this:

import tensorflow as tf

class MyMetric(tf.keras.callbacks.Callback):

  def __init__(self, X_test):
    super().__init__()
    self.X_test = X_test

Then, when calling fit, pass the test set to your custom callback:

model.fit(X_train, y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val), shuffle=True, callbacks=[MyMetric(X_test)])
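The answer shows only the constructor. Below is a fuller sketch of the same idea, assuming a small regression model and using mean absolute error over the whole arrays as a stand-in for "your custom metric". The class name `FullDataMetric`, the `scores` attribute and the `full_val_mae` log key are hypothetical names chosen for illustration. Note that `self.model.predict` runs over the full arrays every epoch, so the cost grows with dataset size.

```python
import numpy as np
import tensorflow as tf


class FullDataMetric(tf.keras.callbacks.Callback):
    """Evaluates a metric on whole datasets at the end of every epoch.

    Keras does not hand the arrays given to fit() to callbacks, so they
    are injected through the constructor instead.
    """

    def __init__(self, X_train, y_train, X_val, y_val):
        super().__init__()
        self.X_train, self.y_train = X_train, y_train
        self.X_val, self.y_val = X_val, y_val
        self.scores = []  # one (train_mae, val_mae) tuple per epoch

    def _mae(self, X, y):
        # self.model is attached by Keras before training starts.
        preds = self.model.predict(X, verbose=0).reshape(-1)
        return float(np.mean(np.abs(preds - y.reshape(-1))))

    def on_epoch_end(self, epoch, logs=None):
        train_mae = self._mae(self.X_train, self.y_train)
        val_mae = self._mae(self.X_val, self.y_val)
        self.scores.append((train_mae, val_mae))
        if logs is not None:
            # Callbacks that run after this one can read the value from logs.
            logs["full_val_mae"] = val_mae


# Usage with small synthetic data (shapes are illustrative only).
rng = np.random.default_rng(0)
X_train = rng.normal(size=(64, 3)).astype("float32")
y_train = X_train.sum(axis=1, keepdims=True)
X_val = rng.normal(size=(16, 3)).astype("float32")
y_val = X_val.sum(axis=1, keepdims=True)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

metric_cb = FullDataMetric(X_train, y_train, X_val, y_val)
model.fit(X_train, y_train, batch_size=32, epochs=2,
          validation_data=(X_val, y_val), shuffle=True,
          callbacks=[metric_cb], verbose=0)
print(metric_cb.scores)
```

Because the callback holds its own references to the arrays, `shuffle=True` only affects the batches fit draws; the metric is always computed on the full, unshuffled datasets.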

More details at https://keras.io/guides/writing_your_own_callbacks/

Gaurav Desai

You could edit the .fit function to accept an extra list or queue, then pass that extra argument on to the callback function. A queue is probably the better choice; another thread or function can then process it.

I made a similar modification to the Paramiko library, and it worked well.
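The queue idea can be sketched without editing Keras itself. This swaps in a different mechanism from the one described: instead of modifying .fit, the queue is handed to the callback through its constructor, and a background thread consumes whatever the callback puts on it. `QueueCallback`, `consume` and the "mean prediction per epoch" processing are hypothetical placeholders, not part of any library.

```python
import queue
import threading

import numpy as np
import tensorflow as tf


class QueueCallback(tf.keras.callbacks.Callback):
    """Pushes per-epoch results onto a queue for a worker thread."""

    def __init__(self, out_queue, X_val):
        super().__init__()
        self.out_queue = out_queue
        self.X_val = X_val

    def on_epoch_end(self, epoch, logs=None):
        # Run the model on the training thread and hand only NumPy data
        # to the worker, so the worker never touches the model itself.
        preds = self.model.predict(self.X_val, verbose=0)
        self.out_queue.put((epoch, preds))

    def on_train_end(self, logs=None):
        self.out_queue.put(None)  # sentinel: training has finished


def consume(in_queue, results):
    """Worker loop: processes queued epochs until it sees the sentinel."""
    while True:
        item = in_queue.get()
        if item is None:
            break
        epoch, preds = item
        # Placeholder processing: record the mean prediction per epoch.
        results.append((epoch, float(np.mean(preds))))


# Usage with small synthetic data.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(64, 3)).astype("float32")
y_train = X_train.sum(axis=1, keepdims=True)
X_val = rng.normal(size=(16, 3)).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

work_queue = queue.Queue()
results = []
worker = threading.Thread(target=consume, args=(work_queue, results), daemon=True)
worker.start()
model.fit(X_train, y_train, batch_size=32, epochs=3,
          callbacks=[QueueCallback(work_queue, X_val)], verbose=0)
worker.join()  # returns once the sentinel has been processed
print(results)
```

Keeping all model calls on the training thread is a deliberate choice in this sketch: the worker only receives plain arrays, so the model is never shared across threads.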

ThatCampbellKid