
I am trying to define a method in Python that I want to use as a metric, in particular for EarlyStopping (with restore_best_weights=True). The problem is that this method needs to make a prediction (using the current parameters of the model), and that doesn't seem to work. (I need the prediction for a specific recursive problem.)

Please see the following simplified example:

import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

x_train = np.zeros((100, 7))
y_train = np.zeros(100)

model = Sequential()
model.add(Dense(units=7, input_shape=(x_train.shape[1], )))
model.add(Dense(units=1))
model.add(Activation('sigmoid'))

input1 = np.zeros((5, 7), dtype=np.float32)
y_hat = model.predict(input1)
print(y_hat)

def testMetric(y_true, y_pred):
    # Uncommenting the next two lines triggers the RuntimeError shown below.
    #input1 = np.zeros((5, 7), dtype=np.float32)
    #y_hat = model.predict(input1)
    return 5

model.compile(
    loss="binary_crossentropy",
    optimizer=Adam(0.05),
    metrics=['binary_accuracy', testMetric]
)

reduce_lr = ReduceLROnPlateau(monitor='testMetric', min_delta=0, factor=0.7, patience=1, verbose=1, mode='max')
early = EarlyStopping(monitor='testMetric', min_delta=0, patience=7, verbose=1, mode='max', baseline=None, restore_best_weights=True)
model.fit(
    x=x_train,
    y=y_train,
    epochs=3,
    callbacks=[early, reduce_lr]
)

Everything works fine when I don't use the prediction in my method "testMetric". But when I use the prediction (i.e. uncomment the two lines), I get the following error:

RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()

Is it possible to use the prediction in my method?

I'm using TensorFlow 2.2.0.

Any help would be appreciated :)
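On the question of whether the prediction can be used inside the metric: as far as I can tell, the error comes from `model.predict()` starting its own prediction loop (with its own distribution-strategy handling) while the metric is already being evaluated inside the replica context of the training step. A minimal sketch of one workaround, assuming the needed value can be computed by calling the model directly as `model(x, training=False)` instead of through `model.predict()`. `probe_input` and `probe_metric` are hypothetical names, and the mean predicted probability is only a placeholder for the actual recursive computation; this has not been checked specifically against TensorFlow 2.2.0.

```python
import numpy as np
import tensorflow as tf

# Toy data and model mirroring the question's example.
x_train = np.zeros((100, 7), dtype=np.float32)
y_train = np.zeros((100, 1), dtype=np.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(7,)),
    tf.keras.layers.Dense(7),
    tf.keras.layers.Dense(1),
    tf.keras.layers.Activation("sigmoid"),
])

# Hypothetical fixed input that is fed through the model inside the metric.
probe_input = tf.zeros((5, 7), dtype=tf.float32)

def probe_metric(y_true, y_pred):
    # Call the model directly instead of model.predict(): the forward pass
    # becomes part of the training step and uses the current weights.
    y_hat = model(probe_input, training=False)
    # Placeholder reduction to a scalar (mean predicted probability);
    # the real recursive computation would go here, written with TF ops.
    return tf.reduce_mean(y_hat)

model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(0.05),
    metrics=["binary_accuracy", probe_metric],
)

early = tf.keras.callbacks.EarlyStopping(
    monitor="probe_metric", mode="max", patience=7, restore_best_weights=True
)
history = model.fit(x_train, y_train, epochs=3, verbose=0, callbacks=[early])
print(history.history["probe_metric"])
```

The catch with this approach is that everything inside the metric must be expressed with TensorFlow operations (no NumPy calls on the result), because the metric is traced into the compiled training function.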

SajanGohil
1994
  • Can you also post the error message? I tried this on Colab and got an error related to replica context (which I think is related to distribution strategy). – SajanGohil Jul 11 '20 at 12:30
  • When I use the prediction in my method, I get the following error: "RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()". I don't know what to do about that. – 1994 Jul 11 '20 at 13:18
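Following up on the replica-context point: if `model.predict()` itself is needed (for example because the recursive computation is easier to write in NumPy), another sketch is to move the computation out of the metric and into a callback. This assumes the value only has to be evaluated once per epoch, which is all `EarlyStopping` and `ReduceLROnPlateau` look at. In `on_epoch_end`, `predict()` runs as ordinary Python outside the compiled training step, and the result is written into `logs`. `ProbeMetricCallback` and the key `"probe_metric"` are hypothetical; the callback must be listed before `EarlyStopping` and `ReduceLROnPlateau` so they see the value. This relies on Keras passing the same `logs` dict to each callback in list order, which I believe holds across TF 2.x but is worth verifying on 2.2.0.

```python
import numpy as np
import tensorflow as tf

x_train = np.zeros((100, 7), dtype=np.float32)
y_train = np.zeros((100, 1), dtype=np.float32)
probe_input = np.zeros((5, 7), dtype=np.float32)  # hypothetical probe input

model = tf.keras.Sequential([
    tf.keras.Input(shape=(7,)),
    tf.keras.layers.Dense(7),
    tf.keras.layers.Dense(1),
    tf.keras.layers.Activation("sigmoid"),
])
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(0.05),
    metrics=["binary_accuracy"],
)

class ProbeMetricCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: stores a prediction-based scalar in the epoch logs."""

    def __init__(self, inputs, name="probe_metric"):
        super().__init__()
        self.inputs = inputs
        self.name = name

    def on_epoch_end(self, epoch, logs=None):
        # predict() runs here in plain Python, outside the training step.
        y_hat = self.model.predict(self.inputs, verbose=0)
        # Placeholder: the recursive computation on y_hat would go here.
        if logs is not None:
            logs[self.name] = float(np.mean(y_hat))

probe = ProbeMetricCallback(probe_input)
early = tf.keras.callbacks.EarlyStopping(
    monitor="probe_metric", mode="max", patience=7, restore_best_weights=True
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="probe_metric", mode="max", factor=0.7, patience=1
)
# The probe callback comes first so the later callbacks find "probe_metric".
history = model.fit(
    x_train, y_train, epochs=3, verbose=0, callbacks=[probe, early, reduce_lr]
)
print(history.history["probe_metric"])
```

Compared with computing the value inside a metric, this only evaluates it once per epoch rather than per batch, which is usually what early stopping needs anyway.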
