Is it possible to add new variables/weights to a custom layer during training? I want this feature (or something similar) in order to implement an expandable embedding layer.
I have tried tf.py_function, but it fails to track the newly added weights and raises an exception complaining that there are no gradients for them. tf.keras.layers.Lambda has a similar issue, except that it shows a warning instead of an exception:
WARNING:tensorflow:
The following Variables were used a Lambda layer's call (lambda), but
are not present in its tracked objects:
<tf.Variable 'model/boundless_embedding/block_0:0' shape=(2, 3) dtype=float32, numpy=
array([[-0.4273603 , -0.21527308, -0.3599682 ],
[-0.04728699, 0.71099687, -0.04171979]], dtype=float32)>
<tf.Variable 'model/boundless_embedding/block_1:0' shape=(2, 3) dtype=float32, numpy=
array([[ 0.44539702, 0.3135407 , -0.94582325],
[ 0.42753923, -0.15626878, 0.5873704 ]], dtype=float32)>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.
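As far as I understand the warning, the core issue is that a Lambda layer never registers variables used inside its function, so the optimizer never sees them. Here is a minimal sketch of that behavior, reduced from my actual layer (the variable name is just illustrative):

import tensorflow as tf

v = tf.Variable(tf.zeros([2, 3]), name='untracked_block')
lam = tf.keras.layers.Lambda(lambda ids: tf.gather(v, ids))
_ = lam(tf.constant([0, 1]))
print(lam.trainable_weights)  # [] -- the variable is not tracked by the Lambda layer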
Here is my source code:
import tensorflow as tf
import numpy


class BoundlessEmbedding(tf.keras.layers.Layer):
    def __init__(self, dimension, block_size=2**20):
        super().__init__(dynamic=True)
        self._block_size = block_size
        self._dimension = dimension
        self.block_weights = []
        self.lookup = tf.keras.layers.Lambda(lambda x: self._lookup(x))

    def call(self, x, training=None, mask=None):
        if training:
            # Grow the embedding table if the batch contains unseen ids.
            with tf.init_scope():
                self._maybe_expand(x)

        def lookup(x_):
            return self._lookup(x_)

        # y = tf.py_function(lookup, [x], tf.float32)
        y = self.lookup(x)
        return tf.reduce_sum(y * y, axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape + [self._dimension]

    def _maybe_expand(self, x):
        # Add new weight blocks until the largest id in the batch fits.
        maximum = tf.math.reduce_max(x)
        while maximum >= len(self.block_weights) * self._block_size:
            id_ = len(self.block_weights)
            weight = self.add_weight(
                name=f'block_{id_}', dtype=tf.float32,
                shape=[self._block_size, self._dimension])
            self.block_weights.append(weight)

    def _lookup(self, x):
        # TODO Remove this reshape
        x = tf.reshape(x, [-1])
        valid = x < len(self.block_weights) * self._block_size
        x = tf.where(valid, x, 0)
        y = tf.zeros([x.shape[0], self._dimension], dtype=tf.float32)
        block_ids = x // self._block_size
        block_offsets = x % self._block_size
        # Gather the rows belonging to each block and scatter them into y.
        for i, weight in enumerate(self.block_weights):
            idx = tf.where(tf.math.equal(block_ids, i))
            offsets = tf.gather(block_offsets, tf.reshape(idx, [-1]))
            values = tf.gather(weight, offsets)
            y = tf.tensor_scatter_nd_update(y, idx, values)
        return tf.reshape(y, [-1, self._dimension])


def build_dataset():
    x = numpy.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=numpy.int32)
    y = numpy.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=numpy.float32)
    x = tf.data.Dataset.from_tensor_slices(x)
    y = tf.data.Dataset.from_tensor_slices(y)
    return tf.data.Dataset.zip((x, y))


def main():
    # tf.enable_eager_execution()
    embedding = BoundlessEmbedding(3, 2)
    x = tf.keras.Input(name="x", shape=[None], dtype=tf.int32)
    y = embedding(x)
    model = tf.keras.Model(inputs=x, outputs=y)
    model.compile(optimizer='sgd', loss='mse')
    dataset = build_dataset()
    model.fit(dataset)


if __name__ == '__main__':
    main()
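For what it's worth, the expansion logic on its own appears to work when exercised eagerly outside of model.fit; a quick sanity check (assuming TF 2.x eager execution, with BoundlessEmbedding defined as above):

emb = BoundlessEmbedding(3, 2)
emb._maybe_expand(tf.constant([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=tf.int32))
print(len(emb.block_weights))  # 4 blocks: ids up to 7 with block_size=2

So the problem seems to be specifically about getting Keras (and the optimizer) to track weights that are created during training.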