
I am trying to implement the meta-gradient-based pruning-at-initialization method by Alizadeh et al. (2022) in TensorFlow. The method works roughly like this:

  1. Take some batches from the dataset.
  2. Mask all weights of the network with ones (e.g. tf.ones).
  3. Perform one update of the weights, including the mask.
  4. UNMASK all weights and perform the rest of the updates through the other batches.
  5. Compute the meta-gradient of the loss w.r.t. the mask, i.e. backpropagate through all batches and weight updates until the mask from the first iteration is "reached".

The authors implement this in PyTorch, which I typically do not use at work. I want to implement it in TensorFlow, but I run into the following problem: TensorFlow is not designed to propagate gradients "through" assign operations. For example:

import tensorflow as tf

w = tf.Variable([4.])
c = tf.Variable([2.])

with tf.GradientTape() as tape:
    tape.watch(c)
    w.assign(w * c)  # in-place update: the tape does not record this op
    output = 2. * w

print(output)
# >> tf.Tensor([16.], shape=(1,), dtype=float32)

print(tape.gradient(output, c))
# >> None
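For contrast, here is a minimal sketch of the same computation done functionally: the update produces a new Tensor instead of assigning into the variable, so the tape stays connected and the gradient comes through.

```python
import tensorflow as tf

w = tf.Variable([4.])
c = tf.Variable([2.])

with tf.GradientTape() as tape:
    w_new = w * c        # functional update: a Tensor, recorded on the tape
    output = 2. * w_new

print(output)
# >> tf.Tensor([16.], shape=(1,), dtype=float32)

print(tape.gradient(output, c))
# >> tf.Tensor([8.], shape=(1,), dtype=float32)
```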

That being said, my "pruning loop" is looking somewhat like this:

test_factor = tf.Variable(1., dtype=tf.float32)
with tf.GradientTape(persistent=True) as outer_tape:
    outer_tape.watch(masked_model.masks)
    outer_tape.watch(test_factor)

    ## First batch
    X_batch, y_batch = wrp.non_random_batch(X_train, y_train, 0, 256)
    with tf.GradientTape() as tape1:
        y_pred = masked_model(X_batch)
        loss = test_factor*loss_fn(y_batch, y_pred)
    gradients = tape1.gradient(loss, masked_model.proper_weights)

    ## Updating weights
    for w, g in zip(masked_model.proper_weights, gradients):
        w.assign(w - 0.05*g)

    ## Unmasking
    masked_model.unmask_forward_passes()

    ## Second batch (and more)
    X_batch, y_batch = wrp.non_random_batch(X_train, y_train, 1, 256)
    with tf.GradientTape() as tape2:
        y_pred = masked_model(X_batch)
        loss = loss_fn(y_batch, y_pred)
    gradients = tape2.gradient(loss, masked_model.proper_weights)

print(outer_tape.gradient(loss, masked_model.masks)) 
# >> ListWrapper([None, None, ..., None])

print(outer_tape.gradient(loss, test_factor)) 
# >> None

After the second batch, more batches would follow. I inserted test_factor to show that this is not a problem with my masks, but with the general structure. Simply changing the line w.assign(w - 0.05*g) to w = w - 0.05*g makes the gradient available, but then the weights are not actually updated...
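To illustrate that last point on a toy scale: if every step produces a new Tensor instead of assigning into the variable, the tape can backpropagate through an entire chain of updates. This is a hypothetical sketch with a made-up scalar loss (w_t**2), not the actual model.

```python
import tensorflow as tf

w = tf.Variable([4.])
c = tf.Variable([2.])  # stands in for the mask / test_factor

with tf.GradientTape() as tape:
    w_t = w * c                   # "masked" first step, kept as a Tensor
    for _ in range(3):            # several purely functional weight updates
        g = 2. * w_t              # gradient of the toy loss w_t**2
        w_t = w_t - 0.05 * g      # update without assign
    loss = tf.reduce_sum(w_t ** 2)

print(tape.gradient(loss, c))  # a real gradient, not None
```

The catch is exactly the one described above: w_t is a Tensor, so the variable w itself never changes.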

For the authors of the paper mentioned above, this does not seem to be a problem. Is PyTorch simply more powerful in such cases, or am I missing some kind of trick to get this to work in TensorFlow?

Jannis
