I found two main sources about it.

  1. A tutorial that does not follow the recommended approach (which I would rather avoid)
  2. Keras' documentation (which I prefer, to avoid surprises)

I prefer to follow Keras' documentation to avoid memory leaks, which some people run into when they try custom approaches with Keras.
But the example in Keras' documentation is about classification, which is not my case.
So I looked at the Keras source code, specifically the file /lib/python3.7/site-packages/tensorflow_core/python/keras/metrics.py. It did not help me, because most of the metrics (the exceptions are some classification metrics) are implemented through a wrapper, as in the following code:

@keras_export('keras.metrics.MeanSquaredError')
class MeanSquaredError(MeanMetricWrapper):
    """Computes the mean squared error between `y_true` and `y_pred`.

    For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.]
    the mean squared error is 3/4 (0.75).

    Usage:

    ```python
    m = tf.keras.metrics.MeanSquaredError()
    m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
    print('Final result: ', m.result().numpy())  # Final result: 0.75
    ```

    Usage with tf.keras API:

    ```python
    model = tf.keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[tf.keras.metrics.MeanSquaredError()])
    ```
    """

    def __init__(self, name='mean_squared_error', dtype=None):
        super(MeanSquaredError, self).__init__(
            mean_squared_error, name, dtype=dtype)

As you can see, there is only the constructor, so there is nothing to draw on for the update_state method that I need.
Where can I find it?


python 3.7.7
tensorflow 2.1.0
keras-applications 1.0.8
keras-preprocessing 1.1.0

AvyWam

1 Answer

You can use a loss function as a metric, so you can extend keras.losses.Loss instead. You only need to override call, as shown in the documentation:

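import tensorflow as tf

class MeanSquaredError(tf.keras.losses.Loss):

    def call(self, y_true, y_pred):
        # Cast the targets to the prediction dtype before subtracting.
        y_true = tf.cast(y_true, y_pred.dtype)
        # Per-sample mean over the last axis; Loss.__call__ applies the final reduction.
        return tf.math.reduce_mean(tf.math.square(y_pred - y_true), axis=-1)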
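
As for where update_state lives: as far as I can tell, MeanMetricWrapper inherits it from the Mean metric in the same metrics.py, where it applies the wrapped function and then adds the values to a running total and count. Below is a plain-Python sketch of that pattern only. It is not the Keras source, it uses no TensorFlow, the class name RunningMSE is hypothetical, and it averages per element rather than per sample.

```python
class RunningMSE:
    """Plain-Python illustration of the total/count pattern; not Keras code."""

    def __init__(self):
        self.total = 0.0  # sum of squared errors seen so far
        self.count = 0    # number of values seen so far

    def update_state(self, y_true, y_pred):
        # Accumulate one batch into the running state.
        for t, p in zip(y_true, y_pred):
            self.total += (p - t) ** 2
            self.count += 1

    def result(self):
        # Mean over everything seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        # Clear the state, e.g. at the start of a new epoch.
        self.total = 0.0
        self.count = 0


m = RunningMSE()
m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
print(m.result())  # 0.75
```

A subclass of tf.keras.metrics.Metric would presumably follow the same three-method shape, with the two counters stored as metric weights instead of Python attributes.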
Thomas