I am trying to implement the meta-gradient-based pruning-at-initialization method by Alizadeh et al. (2022) in TensorFlow. The method works roughly like this:
- Take some batches from the dataset.
- Mask all weights of the network with ones (e.g. `tf.ones`).
- Perform one update of the weights, including the mask.
- Unmask all weights and perform the remaining updates on the other batches.
- Compute the meta-gradient of the loss w.r.t. the mask, i.e. backpropagate through all batches and weight updates until the mask from the first iteration is "reached".
The authors implement this in PyTorch, which I typically do not use at work. I want to implement it in TensorFlow, yet I run into the following problem: TensorFlow is not designed to propagate gradients "through" `assign` operations. For example:
```python
w = tf.Variable([4.])
c = tf.Variable([2.])
with tf.GradientTape() as tape:
    tape.watch(c)
    w.assign(w * c)
    output = 2. * w
print(output)
# >> tf.Tensor([16.], shape=(1,), dtype=float32)
print(tape.gradient(output, c))
# >> None
```
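For contrast, if the update is written as a plain tensor expression instead of an in-place `assign`, the gradient does flow. A minimal sketch of the same toy example in that style:

```python
import tensorflow as tf

# Writing the update as a tensor expression keeps it on the tape,
# so the gradient w.r.t. c survives.
w = tf.Variable([4.])
c = tf.Variable([2.])
with tf.GradientTape() as tape:
    w_new = w * c        # a traced tensor, not an in-place assignment
    output = 2. * w_new
grad = tape.gradient(output, c)
print(grad)
# >> tf.Tensor([8.], shape=(1,), dtype=float32)
```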
That being said, my "pruning loop" looks somewhat like this:
```python
test_factor = tf.Variable(1., dtype=tf.float32)

with tf.GradientTape(persistent=True) as outer_tape:
    outer_tape.watch(masked_model.masks)
    outer_tape.watch(test_factor)

    ## First batch
    X_batch, y_batch = wrp.non_random_batch(X_train, y_train, 0, 256)
    with tf.GradientTape() as tape1:
        y_pred = masked_model(X_batch)
        loss = test_factor * loss_fn(y_batch, y_pred)
    gradients = tape1.gradient(loss, masked_model.proper_weights)

    ## Updating weights
    for w, g in zip(masked_model.proper_weights, gradients):
        w.assign(w - 0.05*g)

    ## Unmasking
    masked_model.unmask_forward_passes()

    ## Second batch (and more)
    X_batch, y_batch = wrp.non_random_batch(X_train, y_train, 1, 256)
    with tf.GradientTape() as tape2:
        y_pred = masked_model(X_batch)
        loss = loss_fn(y_batch, y_pred)
    gradients = tape2.gradient(loss, masked_model.proper_weights)

print(outer_tape.gradient(loss, masked_model.masks))
# >> ListWrapper([None, None, ..., None])
print(outer_tape.gradient(loss, test_factor))
# >> None
```
After the second batch, more batches would follow.
I inserted `test_factor` to show that this is not a problem with my masks, but with the general structure. Simply changing the line `w.assign(w - 0.05*g)` to `w = w - 0.05*g` makes the gradient available, but then the weights are not actually updated...
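To illustrate the functional variant on a self-contained toy version (my own setup and names, not the paper's code): a single linear layer whose weight is masked on the first batch, where the inner SGD step is kept as a tensor expression so the outer tape can reach the mask from the second-batch loss.

```python
import tensorflow as tf

# Toy sketch: one linear layer, weight masked on the first batch.
# The inner update is functional (no .assign), so w1 is a tensor
# that still depends on the mask.
tf.random.set_seed(0)
w0 = tf.Variable(tf.random.normal([3, 1]))
mask = tf.Variable(tf.ones([3, 1]))
x1, y1 = tf.random.normal([8, 3]), tf.random.normal([8, 1])
x2, y2 = tf.random.normal([8, 3]), tf.random.normal([8, 1])

with tf.GradientTape() as outer_tape:
    # First batch: masked forward pass.
    with tf.GradientTape() as inner_tape:
        loss1 = tf.reduce_mean((x1 @ (w0 * mask) - y1) ** 2)
    g = inner_tape.gradient(loss1, w0)   # recorded on outer_tape
    w1 = w0 - 0.05 * g                   # functional update, no .assign
    # Second batch: unmasked forward pass with the updated tensor.
    loss2 = tf.reduce_mean((x2 @ w1 - y2) ** 2)

meta_grad = outer_tape.gradient(loss2, mask)
print(meta_grad.shape)
# >> (3, 1)
```

The price is that the "updated weights" live as tensors (`w1`) rather than in the model's `tf.Variable`s, so the forward pass for later batches has to accept these tensors explicitly.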
For the authors of the paper mentioned, this does not seem to be a problem. Is PyTorch simply more powerful in such cases, or am I missing some kind of trick to get this to work in TensorFlow?