
Is it possible to add new variables/weights to a custom layer during training? I want this feature (or something similar) to implement an expandable embedding layer.

I have tried tf.py_function, but it failed to track the newly added weights and raised an exception complaining that there are no gradients for them. tf.keras.layers.Lambda has a similar issue, except it shows a warning instead of an exception:

WARNING:tensorflow:
The following Variables were used a Lambda layer's call (lambda), but
are not present in its tracked objects:
  <tf.Variable 'model/boundless_embedding/block_0:0' shape=(2, 3) dtype=float32, numpy=
array([[-0.4273603 , -0.21527308, -0.3599682 ],
       [-0.04728699,  0.71099687, -0.04171979]], dtype=float32)>
  <tf.Variable 'model/boundless_embedding/block_1:0' shape=(2, 3) dtype=float32, numpy=
array([[ 0.44539702,  0.3135407 , -0.94582325],
       [ 0.42753923, -0.15626878,  0.5873704 ]], dtype=float32)>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.

Here is my source code:

import tensorflow as tf
import numpy


class BoundlessEmbedding(tf.keras.layers.Layer):
    def __init__(self, dimension, block_size=2**20):
        super().__init__(dynamic=True)

        self._block_size = block_size
        self._dimension = dimension
        self.block_weights = []
        self.lookup = tf.keras.layers.Lambda(lambda x: self._lookup(x))

    def call(self, x, training=None, mask=None):
        if training:
            with tf.init_scope():
                self._maybe_expand(x)

        def lookup(x_):
            return self._lookup(x_)

        # y = tf.py_function(lookup, [x], tf.float32)
        y = self.lookup(x)

        return tf.reduce_sum(y * y, axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape + [self._dimension]

    def _maybe_expand(self, x):
        maximum = tf.math.reduce_max(x)
        while maximum >= len(self.block_weights) * self._block_size:
            id_ = len(self.block_weights)
            weight = self.add_weight(
                    name=f'block_{id_}', dtype=tf.float32,
                    shape=[self._block_size, self._dimension])
            self.block_weights.append(weight)

    def _lookup(self, x):
        # TODO Remove this reshape
        x = tf.reshape(x, [-1])

        valid = x < len(self.block_weights) * self._block_size
        x = tf.where(valid, x, 0)

        y = tf.zeros([x.shape[0], self._dimension], dtype=tf.float32)
        block_ids = x // self._block_size
        block_offsets = x % self._block_size

        for i, weight in enumerate(self.block_weights):
            idx = tf.where(tf.math.equal(block_ids, i))
            offsets = tf.gather(block_offsets, tf.reshape(idx, [-1]))
            values = tf.gather(weight, offsets)
            y = tf.tensor_scatter_nd_update(y, idx, values)
        return tf.reshape(y, [-1, self._dimension])


def build_dataset():
    x = numpy.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=numpy.int32)
    y = numpy.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=numpy.float32)
    x = tf.data.Dataset.from_tensor_slices(x)
    y = tf.data.Dataset.from_tensor_slices(y)
    return tf.data.Dataset.zip((x, y))


def main():
    #tf.enable_eager_execution()

    embedding = BoundlessEmbedding(3, 2)
    x = tf.keras.Input(name="x", shape=[None], dtype=tf.int32)
    y = embedding(x)

    model = tf.keras.Model(inputs=x, outputs=y)
    model.compile(optimizer='sgd', loss='mse')

    dataset = build_dataset()
    model.fit(dataset)


if __name__ == '__main__':
    main()
user416983

1 Answer


Remove the Lambda layer and perform the lookup directly in the _lookup method, and handle the weight expansion directly in the _maybe_expand method. I tried to give this a go, hope it helps:

class BoundlessEmbedding(tf.keras.layers.Layer):
    def __init__(self, dimension, block_size=2**20):
        super().__init__(dynamic=True)

        self._block_size = block_size
        self._dimension = dimension
        self.block_weights = []
        
    def build(self, input_shape):
        # Start with a single block; _maybe_expand adds further blocks on demand.
        initial_weight = self.add_weight(
            name='block_0', dtype=tf.float32,
            shape=[self._block_size, self._dimension])
        self.block_weights.append(initial_weight)

        super().build(input_shape)

    def call(self, x, training=None, mask=None):
        if training:
            self._maybe_expand(x)

        return self._lookup(x)

    def _maybe_expand(self, x):
        maximum = tf.math.reduce_max(x)
        while maximum >= len(self.block_weights) * self._block_size:
            id_ = len(self.block_weights)
            weight = self.add_weight(
                    name=f'block_{id_}', dtype=tf.float32,
                    shape=[self._block_size, self._dimension])
            self.block_weights.append(weight)

    def _lookup(self, x):
        x = tf.reshape(x, [-1])
        y = tf.zeros([x.shape[0], self._dimension], dtype=tf.float32)

        for i, weight in enumerate(self.block_weights):
            # Entries belonging to block i keep their gathered values; every
            # other entry is zeroed out by the mask, so the per-block results
            # can simply be summed into y.
            mask = tf.math.equal(x // self._block_size, i)
            offsets = x % self._block_size
            values = tf.gather(weight, offsets)
            updates = values * tf.expand_dims(tf.cast(mask, tf.float32), axis=1)
            y += updates
        
        return tf.reshape(y, [-1, self._dimension])
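
As a quick sanity check (a minimal sketch, not part of the original answer, assuming eager execution in TF 2.x), you can drive the layer directly under a tf.GradientTape and confirm that blocks created during the forward pass show up in trainable_weights and receive gradients:

import tensorflow as tf

# block_size=2 is deliberately tiny so the ids below force the layer to expand.
embedding = BoundlessEmbedding(dimension=3, block_size=2)
optimizer = tf.keras.optimizers.SGD()

ids = tf.constant([0, 1, 2, 3, 4, 5, 6, 7], dtype=tf.int32)

with tf.GradientTape() as tape:
    y = embedding(ids, training=True)   # _maybe_expand creates blocks 1..3 here
    loss = tf.reduce_mean(tf.reduce_sum(y * y, axis=1))

# The freshly created blocks are tracked by the layer and receive gradients.
grads = tape.gradient(loss, embedding.trainable_weights)
optimizer.apply_gradients(zip(grads, embedding.trainable_weights))

print([w.name for w in embedding.trainable_weights])

Note that this version of call returns the raw embeddings, so to reproduce the question's model.fit example the tf.reduce_sum over the embedding dimension has to move from the layer into the model.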

MaikeruDev