I have a pruned TF model which I need to retrain with streaming data. The model includes an embedding layer (10 nodes), a GRU layer (16 nodes), and a classification layer (85 nodes). I want to retrain only the non-zero weights of the 80%-pruned model. I'd like to avoid creating a mask and performing additional calculations, since I want to minimize retraining time. Here's the training function that I'm currently using.
@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value
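For reference, the model can be sketched like this (the layer sizes are from my description above, but the vocabulary size, input length, loss, and optimizer are placeholders, not my actual setup):

```python
import tensorflow as tf

# Sketch of the architecture described above. VOCAB_SIZE is a placeholder,
# not a value from my actual pipeline.
VOCAB_SIZE = 1000
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(VOCAB_SIZE, 10),        # embedding layer, 10 units
    tf.keras.layers.GRU(16),                          # GRU layer, 16 units
    tf.keras.layers.Dense(85, activation="softmax"),  # classification, 85 classes
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
```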
Is there any way to compute the grads only for the non-zero weights, or to pass only the non-zero weights from model.trainable_weights?
If not, is there any way to use tf.IndexedSlices to update non-zero weights efficiently?
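For reference, my understanding of how tf.IndexedSlices behaves (a minimal toy example, not from my actual model):

```python
import tensorflow as tf

# tf.IndexedSlices represents a row-sparse gradient, as tape.gradient
# returns for embedding lookups: `values` holds the gradient rows and
# `indices` says which rows of the variable they belong to.
v = tf.Variable(tf.zeros((5, 3)))
g = tf.IndexedSlices(values=tf.ones((2, 3)),
                     indices=tf.constant([0, 3]),
                     dense_shape=tf.constant([5, 3], dtype=tf.int64))
opt = tf.keras.optimizers.SGD(learning_rate=1.0)
opt.apply_gradients([(g, v)])
# with lr=1.0 and momentum=0, v is decremented by 1.0 on rows 0 and 3 only
```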
I also tried the code below (in place of the plain grads in the code above), but it does not work.
grads_and_vars = [(tf.IndexedSlices(g.values * tf.cast(tf.math.greater(v, 0), tf.float32),
                                    g.indices), v)
                  for g, v in zip(grads, model_retrained.trainable_weights)]
optimizer.apply_gradients(grads_and_vars)
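In case it helps, here is the mask-based fallback I'm trying to avoid, as a self-contained sketch (layer sizes are from my description; the vocabulary size, pruning simulation, loss, optimizer, and data are placeholders). It uses tf.math.not_equal rather than tf.math.greater, so negative surviving weights are updated too, and it handles the row-sparse IndexedSlices gradients that the embedding layer produces:

```python
import numpy as np
import tensorflow as tf

# Placeholder model matching the described architecture; VOCAB_SIZE and
# the input length (5) are assumptions, not my real values.
VOCAB_SIZE = 20
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(VOCAB_SIZE, 10),
    tf.keras.layers.GRU(16),
    tf.keras.layers.Dense(85, activation="softmax"),
])
model.build((None, 5))

# Simulate 80% magnitude pruning by zeroing the smallest weights per tensor.
for v in model.trainable_weights:
    w = v.numpy()
    w[np.abs(w) < np.quantile(np.abs(w), 0.8)] = 0.0
    v.assign(w)

# One fixed 0/1 mask per variable, computed once from the pruned weights
# (not_equal, so negative surviving weights keep their gradients).
masks = [tf.cast(tf.math.not_equal(v, 0.0), v.dtype)
         for v in model.trainable_weights]

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def masked_train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    masked = []
    for g, m in zip(grads, masks):
        if isinstance(g, tf.IndexedSlices):
            # Embedding gradients arrive row-sparse: mask only the rows
            # this batch actually touched, keeping the update sparse.
            g = tf.IndexedSlices(g.values * tf.gather(m, g.indices),
                                 g.indices, g.dense_shape)
        else:
            g = g * m
        masked.append(g)
    optimizer.apply_gradients(zip(masked, model.trainable_weights))
    return loss_value

x = tf.constant(np.random.randint(0, VOCAB_SIZE, size=(4, 5)))
y = tf.constant(np.random.randint(0, 85, size=(4,)))
loss = masked_train_step(x, y)
```

With plain SGD the masked positions receive exactly zero updates, so the pruned weights stay zero after each step; this is the extra per-step multiplication I was hoping to avoid.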
I'd highly appreciate any support on this.