
I want to prune the highest weight values in a TensorFlow layer. I'm thinking about using tf.nn.top_k, but I'm not exactly sure how I would go about doing this.

Documentation: https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude

Code:

pruning_params = {
    'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
        final_sparsity=0.8, begin_step=1000, end_step=2000),
    'block_size': (2, 3),
    'block_pooling_type': 'MAX'
}

model = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
])
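
For context, here is a minimal end-to-end sketch of how the prune_low_magnitude wrapper above is typically trained. This is only illustrative: it assumes tensorflow_model_optimization is installed, the random x_train/y_train arrays are placeholders just to make the snippet runnable, and with so few steps the schedule's begin_step=1000 is never actually reached.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
PolynomialDecay = tfmot.sparsity.keras.PolynomialDecay

pruning_params = {
    'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
        final_sparsity=0.8, begin_step=1000, end_step=2000),
    'block_size': (2, 3),
    'block_pooling_type': 'MAX'
}

model = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
])

model.compile(optimizer='adam', loss='mse')

# Placeholder data, only so the snippet runs end to end.
x_train = np.random.rand(256, 100).astype('float32')
y_train = np.random.rand(256, 2).astype('float32')

# UpdatePruningStep is required so the wrapper updates its masks during fit.
model.fit(x_train, y_train, epochs=2,
          callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Strip the pruning wrappers once training is done.
final_model = tfmot.sparsity.keras.strip_pruning(model)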

iiooii

1 Answer


Assuming that w is the weight matrix of the layer you want to prune and k is the fraction of weights that should be pruned (a float between 0 and 1), this should do the trick for you:

# Convert k from a fraction to an integer number of weights to prune
k = tf.cast(tf.round(tf.size(w, out_type=tf.float32) * tf.constant(k)), dtype=tf.int32)
# Flatten the weight matrix
w_reshaped = tf.reshape(w, [-1])
# Select the indices of the largest k weights
_, indices = tf.nn.top_k(w_reshaped, k, sorted=True, name=None)
# Build a mask that is 0 at those indices and 1 everywhere else
mask = tf.scatter_nd_update(
    tf.Variable(tf.ones_like(w_reshaped, dtype=tf.float32), name="mask", trainable=False),
    tf.reshape(indices, [-1, 1]),
    tf.zeros([k], tf.float32))
# Zero out the selected weights in w
w.assign(tf.reshape(w_reshaped * mask, tf.shape(w)))

This is based on this GitHub repo. Please note that in that project, I am pruning the smallest k weights.
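
If you are on TF 2.x, where tf.scatter_nd_update is only available under tf.compat.v1, here is a minimal sketch of the same idea using tf.tensor_scatter_nd_update instead. The prune_top_fraction helper below is just illustrative; it assumes w is a tf.Variable (for example a Dense layer's kernel) and fraction is a float in [0, 1]:

import tensorflow as tf

def prune_top_fraction(w, fraction):
    """Zero out the largest `fraction` of the entries in the weight variable `w`."""
    # Number of weights to prune
    k = tf.cast(tf.round(tf.cast(tf.size(w), tf.float32) * fraction), tf.int32)
    w_flat = tf.reshape(w, [-1])
    # Indices of the k largest weights
    _, indices = tf.nn.top_k(w_flat, k, sorted=True)
    # Mask that is 0 at those indices and 1 everywhere else
    mask = tf.tensor_scatter_nd_update(
        tf.ones_like(w_flat),
        tf.reshape(indices, [-1, 1]),
        tf.zeros([k], dtype=w.dtype))
    # Zero out the selected weights
    w.assign(tf.reshape(w_flat * mask, tf.shape(w)))

# Example usage on a Dense layer's kernel:
# prune_top_fraction(dense_layer.kernel, 0.5)

Because tf.tensor_scatter_nd_update builds the mask as a plain tensor, you also avoid creating the extra mask variable.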

gorjan
    Thanks for the help - I appreciate it – iiooii Aug 07 '19 at 14:42
  • How do you suggest I modify mask to then prune the largest k weights? I've been thinking of using something like: if not in indices, then set to 0. – iiooii Aug 07 '19 at 21:29
  • Just use it as it is in the code snippet above. In the example, I modified the code from the GitHub repo to fit your purpose. – gorjan Aug 07 '19 at 21:46