I define an Embedding layer with an embeddings_constraint:

from tensorflow.keras.layers import Embedding
from tensorflow.keras.constraints import UnitNorm
. . .
emb = Embedding(input_dim, output_dim, name='embedding_name', embeddings_constraint=UnitNorm(axis=1))
. . .

Later in the code, when I train the model that contains emb, model.fit raises an exception:

RuntimeError: Cannot use a constraint function on a sparse variable.

When I don't impose the embeddings constraint on emb, however, the code runs without error. I also tried this with TF 1, where it worked fine both with and without embeddings_constraint. According to a GitHub discussion, this appears to be a TF 2 bug, though no working solution has been proposed.
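For context, here is a minimal self-contained reproduction (a sketch with hypothetical shapes and names, assuming TF 2.x). The error seems to come from TF 2's optimizers, which update embedding tables with sparse (IndexedSlices) gradients and refuse to apply a constraint after a sparse update:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dense, Flatten, Input
from tensorflow.keras.constraints import UnitNorm

inp = Input(shape=(1,), dtype='int64')
emb = Embedding(100, 8, embeddings_constraint=UnitNorm(axis=1))
out = Dense(1, activation='sigmoid')(Flatten()(emb(inp)))
model = tf.keras.Model(inp, out)
model.compile(optimizer='adam', loss='binary_crossentropy')

X = np.random.randint(0, 100, size=(32, 1))
y = np.random.randint(0, 2, size=(32, 1)).astype('float32')
model.fit(X, y, epochs=1)  # RuntimeError with the constraint; runs fine without it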

Any ideas how to solve this?

Belphegor

2 Answers

A workaround for this issue is to call the constraint directly on the layer's output, like this:

from tensorflow.keras.layers import Embedding
from tensorflow.keras.constraints import UnitNorm
. . .
emb = Embedding(input_dim, output_dim, name='embedding_name')  # no embeddings_constraint here
norm_layer = UnitNorm(axis=1)
norm_embedding = norm_layer(emb(embedding_id_input))  # normalize the output instead of the weights
. . .
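This works because Keras constraints are plain callables (UnitNorm divides its argument by its L2 norm along the given axis), so calling one on the embedding output unit-normalizes the looked-up vectors rather than constraining the stored weight matrix. A hypothetical end-to-end sketch of the idea; the input name and shapes are placeholders, and if your TF version rejects raw ops on Keras tensors, wrap the call in a Lambda layer:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, Input
from tensorflow.keras.constraints import UnitNorm

embedding_id_input = Input(shape=(), dtype='int64')   # one id per example
emb = Embedding(100, 8, name='embedding_name')        # no embeddings_constraint
norm_layer = UnitNorm(axis=1)
norm_embedding = norm_layer(emb(embedding_id_input))  # (batch, 8), each row unit-norm
out = Dense(1, activation='sigmoid')(norm_embedding)
model = tf.keras.Model(embedding_id_input, out)
model.compile(optimizer='adam', loss='binary_crossentropy')  # no RuntimeError now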

I hit this as well, commented on the TensorFlow GitHub issue, and tried TF 2.7 and tf-nightly; the bug was still there, so I opened a new Keras issue: https://github.com/keras-team/keras/issues/15818.

The workaround above won't address all use cases: it constrains the embedding activity (the layer's output) rather than the entire embedding matrix. I do have a workaround using Keras callbacks, but I found it slowed down training; I believe the callback works with NumPy arrays that never get to live on the GPU. If it's useful, here is the approach:

import numpy as np
import tensorflow as tf


class ConstrainEmbeddings(tf.keras.callbacks.Callback):
    """Before each batch, scale up any embedding row whose L2 norm
    is below min_norm so that it has norm exactly min_norm."""

    def __init__(self, min_norm, emb, eps=1e-9):
        super().__init__()
        self.min_norm = min_norm
        self.emb = emb
        self.eps = eps

    def on_batch_begin(self, *args, **kwargs):
        W = self.emb.get_weights()[0]                     # (vocab, dim) as a NumPy array
        norms = tf.maximum(self.eps, tf.norm(W, axis=1))  # per-row L2 norms, clipped away from 0
        delta = tf.expand_dims(tf.math.divide(self.min_norm, norms) - 1.0, 1)
        deltaW = tf.math.multiply(W, delta)               # correction that rescales a row to min_norm
        # only apply the correction to rows whose norm is below min_norm
        constrainedW = W + tf.expand_dims(tf.cast(norms < self.min_norm, dtype=tf.float32), 1) * deltaW
        self.emb.set_weights([constrainedW])


def constrain_embeddings(use_keras):
    N = 10
    batch_size = 5
    data = {
        'X': np.random.randint(0, 10, N),
        'Y': np.random.randint(0, 2, N)
    }

    def get_labels(features):
        labels = features.pop('Y')
        return features, labels

    dset = tf.data.Dataset.from_tensor_slices(data).map(get_labels).batch(batch_size)

    inp = tf.keras.Input(shape=(1,), name='X', dtype='int64')
    constraint = tf.keras.constraints.MaxNorm(max_value=0.1) if use_keras else None  # use_keras=True reproduces the RuntimeError

    emb = tf.keras.layers.Embedding(
        10, 3, input_length=1,
        embeddings_initializer=tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0),
        embeddings_constraint=constraint
    )
    emb_out = emb(inp)
    out = tf.keras.layers.Dense(1)(emb_out)
    model = tf.keras.Model(inputs=inp, outputs=out)
    model.compile(optimizer='adam', loss=tf.keras.losses.binary_crossentropy)
    callbacks = [ConstrainEmbeddings(.3, emb)] if not use_keras else []  # callback-based constraint instead
    model.fit(dset, epochs=10, callbacks=callbacks)


if __name__ == '__main__':
    print("tensorflow git version: ", tf.version.GIT_VERSION)
    constrain_embeddings(use_keras=False)
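To address the slowdown, a possible variant (a sketch, not something from the issue thread) assigns to the layer's weight variable directly via emb.embeddings, which is a tf.Variable in TF 2.x Keras, so the projection stays in TensorFlow ops instead of round-tripping through NumPy:

import tensorflow as tf


class ConstrainEmbeddingsOnDevice(tf.keras.callbacks.Callback):
    """Same min-norm projection, but assigning to the tf.Variable directly."""

    def __init__(self, min_norm, emb, eps=1e-9):
        super().__init__()
        self.min_norm = min_norm
        self.emb = emb
        self.eps = eps

    def on_train_batch_begin(self, batch, logs=None):
        W = self.emb.embeddings                           # tf.Variable, shape (vocab, dim)
        norms = tf.maximum(self.eps, tf.norm(W, axis=1, keepdims=True))
        # scale rows with norm < min_norm up to exactly min_norm, leave the rest alone
        scale = tf.where(norms < self.min_norm, self.min_norm / norms, tf.ones_like(norms))
        self.emb.embeddings.assign(W * scale)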
MrCartoonology