I have a pruned TF model which I need to retrain with streaming data. The model includes an embedding layer (10 nodes), a GRU layer (16 nodes), and a classification layer (85 nodes). I want to retrain only the non-zero weights of the 80%-pruned model, and I'd like to avoid creating a mask and performing extra calculations, since I want to minimize retraining time. Here's the training function I'm currently using.
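For reference, a minimal sketch of the setup described above. The layer sizes come from the question; the vocabulary size, loss, and optimizer are assumptions for illustration only:

```python
import tensorflow as tf

# Sketch of the described architecture: embedding (10 units),
# GRU (16 units), classification layer (85 units).
VOCAB_SIZE = 1000  # assumed; not given in the question

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(VOCAB_SIZE, 10),
    tf.keras.layers.GRU(16),
    tf.keras.layers.Dense(85, activation="softmax"),
])

# Assumed loss/optimizer, matching the names used in train_step below.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
```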

@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value

Is there a way to compute the grads only for the non-zero weights, or to use only the non-zero weights from model.trainable_weights? If not, is there a way to use tf.IndexedSlices to update the non-zero weights efficiently?

I also tried the code below (in place of the plain apply_gradients call above), but it does not work.

grads_and_vars = [
    (tf.IndexedSlices(g.values * tf.cast(tf.math.greater(v, 0), tf.float32),
                      g.indices), v)
    for g, v in zip(grads, model_retrained.trainable_weights)
]
optimizer.apply_gradients(grads_and_vars)

I'd highly appreciate any support on this.

Sampath Rajapaksha
