I figured out a way myself! The only problem I have with this solution is how I calculate the size of a minibatch; it looks pretty ugly and I don't know if there is a better way to do it (see the note after the code for one possible alternative).
import tensorflow as tf
from tensorflow import keras

class CorrelationMetric(keras.metrics.Metric):
    def __init__(self, name="correlation", **kwargs):
        super(CorrelationMetric, self).__init__(name=name, **kwargs)
        # Running sums needed for the Pearson correlation coefficient
        self.correlation = self.add_weight(name="correlation", initializer="zeros")
        self.n = self.add_weight(name="n", initializer="zeros")
        self.x = self.add_weight(name="x", initializer="zeros")
        self.x_squared = self.add_weight(name="x_squared", initializer="zeros")
        self.y = self.add_weight(name="y", initializer="zeros")
        self.y_squared = self.add_weight(name="y_squared", initializer="zeros")
        self.xy = self.add_weight(name="xy", initializer="zeros")
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Count the elements in the minibatch: every element is either equal
        # or not equal to its target, so the two sums together give the batch
        # size (this is the part that feels ugly)
        self.n.assign_add(tf.reduce_sum(tf.cast((y_pred == y_true), "float32")))
        self.n.assign_add(tf.reduce_sum(tf.cast((y_pred != y_true), "float32")))
        # Accumulate the running sums of x*y, x, y, x^2 and y^2
        self.xy.assign_add(tf.reduce_sum(tf.multiply(y_pred, y_true)))
        self.x.assign_add(tf.reduce_sum(y_pred))
        self.y.assign_add(tf.reduce_sum(y_true))
        self.x_squared.assign_add(tf.reduce_sum(tf.math.square(y_pred)))
        self.y_squared.assign_add(tf.reduce_sum(tf.math.square(y_true)))
    def result(self):
        # Pearson correlation from the running sums:
        # r = (n*sum(xy) - sum(x)*sum(y)) / sqrt((n*sum(x^2) - sum(x)^2) * (n*sum(y^2) - sum(y)^2))
        return (self.n * self.xy - self.x * self.y) / tf.math.sqrt(
            (self.n * self.x_squared - tf.math.square(self.x))
            * (self.n * self.y_squared - tf.math.square(self.y)))
    def reset_state(self):
        self.n.assign(0.0)
        self.x.assign(0.0)
        self.x_squared.assign(0.0)
        self.y.assign(0.0)
        self.y_squared.assign(0.0)
        self.xy.assign(0.0)
        self.correlation.assign(0.0)
# later, pass the metric when compiling a model
model.compile(..args.., metrics=[CorrelationMetric()])
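If the two equal/not-equal sums for n feel too hacky, one cleaner-looking alternative (just a sketch, I haven't tested it beyond the idea) might be to take the batch size straight from tf.size, which returns the total number of elements in the tensor:

def update_state(self, y_true, y_pred, sample_weight=None):
    # tf.size gives the number of elements in the batch directly,
    # replacing the two equal/not-equal sums above
    self.n.assign_add(tf.cast(tf.size(y_true), "float32"))
    self.xy.assign_add(tf.reduce_sum(tf.multiply(y_pred, y_true)))
    self.x.assign_add(tf.reduce_sum(y_pred))
    self.y.assign_add(tf.reduce_sum(y_true))
    self.x_squared.assign_add(tf.reduce_sum(tf.math.square(y_pred)))
    self.y_squared.assign_add(tf.reduce_sum(tf.math.square(y_true)))

The rest of the class stays the same; only the way n is accumulated changes.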