I hit this too. I added a comment to the TensorFlow GitHub issue, then tried it in TF 2.7 and tf-nightly and still saw the bug, so I opened a new Keras issue: https://github.com/keras-team/keras/issues/15818.
The workaround won't address all use cases: it constrains the embedding activity rather than the entire embedding matrix. I do have another workaround using Keras callbacks, but I found that it slowed down training; I believe the callback is working with NumPy arrays that don't get to live on the GPU. In case it is useful, here is the approach:
import numpy as np
import tensorflow as tf
class ConstrainEmbeddings(tf.keras.callbacks.Callback):
    """Keras callback that enforces a minimum L2 norm on the rows of a layer's first weight.

    Every row of ``emb.get_weights()[0]`` whose L2 norm is below ``min_norm``
    is rescaled so that its norm equals ``min_norm``; all other rows are left
    unchanged. The projection is applied before every training batch and once
    more when training ends.

    Args:
        min_norm: Smallest L2 norm a row is allowed to have.
        emb: Layer whose first weight (``get_weights()[0]``) is constrained.
        eps: Lower bound applied to the row norms before dividing, so that
            an all-zero row does not cause a division by zero.

    NOTE(review): ``get_weights()``/``set_weights()`` exchange NumPy arrays, so
    each projection likely incurs a host round trip; if this is a bottleneck,
    consider assigning to the layer's variable directly.
    """

    def __init__(self, min_norm, emb, eps=1e-9):
        super().__init__()
        self.min_norm = min_norm
        self.emb = emb
        self.eps = eps

    def _project(self):
        """Rescale every too-short row of the first weight up to ``min_norm``."""
        weights = self.emb.get_weights()
        # Convert explicitly so all arithmetic below is done by TensorFlow in
        # the weight's own dtype (no hard-coded float32).
        W = tf.convert_to_tensor(weights[0])
        # Clamp the norms from below to keep the division finite for zero rows.
        norms = tf.maximum(tf.norm(W, axis=1, keepdims=True), self.eps)
        too_short = norms < self.min_norm
        # tf.where broadcasts the (rows, 1) condition across each row and
        # leaves rows that already satisfy the constraint bit-for-bit intact.
        constrained = tf.where(too_short, W * (self.min_norm / norms), W)
        # Replace only the first weight so layers with extra weights still work.
        weights[0] = constrained
        self.emb.set_weights(weights)

    def on_batch_begin(self, *args, **kwargs):
        """Project the weights before each training batch."""
        self._project()

    def on_train_end(self, logs=None):
        """Project once more so the update from the final batch is constrained too."""
        self._project()
def constrain_embeddings(use_keras):
    """Train a tiny embedding model on random data with constrained embeddings.

    Builds 10 random examples (integer feature ``'X'`` in [0, 10), binary
    label ``'Y'``), batches them by 5, and fits an Embedding -> Dense(1) model
    for 10 epochs.

    Args:
        use_keras: If truthy, the embedding layer is given Keras's built-in
            ``MaxNorm(max_value=0.1)`` constraint and no callback is used.
            Otherwise no built-in constraint is set and a
            ``ConstrainEmbeddings(.3, emb)`` callback is passed to ``fit``.
    """
    num_examples = 10
    examples_per_batch = 5

    def split_off_labels(features):
        # Separate the 'Y' column from the feature dict to get (features, labels).
        labels = features.pop('Y')
        return features, labels

    raw_columns = {
        'X': np.random.randint(0, 10, num_examples),
        'Y': np.random.randint(0, 2, num_examples),
    }
    dataset = tf.data.Dataset.from_tensor_slices(raw_columns)
    dataset = dataset.map(split_off_labels)
    dataset = dataset.batch(examples_per_batch)

    model_input = tf.keras.Input(shape=(1,), name='X', dtype='int64')

    # Built-in constraint only on the Keras path; the callback path gets None.
    if use_keras:
        builtin_constraint = tf.keras.constraints.MaxNorm(max_value=0.1)
    else:
        builtin_constraint = None

    uniform_init = tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0)
    embedding_layer = tf.keras.layers.Embedding(
        10,
        3,
        input_length=1,
        embeddings_initializer=uniform_init,
        embeddings_constraint=builtin_constraint,
    )
    embedded = embedding_layer(model_input)
    prediction = tf.keras.layers.Dense(1)(embedded)

    net = tf.keras.Model(inputs=model_input, outputs=prediction)
    net.compile(optimizer='adam', loss=tf.keras.losses.binary_crossentropy)

    # The callback-based constraint is used only when the built-in one is not.
    if use_keras:
        fit_callbacks = []
    else:
        fit_callbacks = [ConstrainEmbeddings(.3, embedding_layer)]

    net.fit(dataset, epochs=10, callbacks=fit_callbacks)
if __name__ == '__main__':
    # Report the TensorFlow build being exercised, then run the
    # callback-based variant of the experiment.
    tf_git_version = tf.version.GIT_VERSION
    print("tensorflow git version: ", tf_git_version)
    constrain_embeddings(use_keras=False)