I want to prune the highest weight values in a TensorFlow layer. I'm thinking about using tf.nn.top_k,
but I'm not sure how to go about doing this.
Documentation: https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude
Code:
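import tensorflow_model_optimization as tfmot
from tensorflow import keras
from tensorflow.keras import layers

PolynomialDecay = tfmot.sparsity.keras.PolynomialDecay
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Sparsity ramps from 20% to 80% between training steps 1000 and 2000.
pruning_params = {
    'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
        final_sparsity=0.8, begin_step=1000, end_step=2000),
    'block_size': (2, 3),
    'block_pooling_type': 'MAX'
}
model = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
])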
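
For reference, here is a rough sketch of the kind of thing I have in mind with top_k. It is only an illustration under some assumptions: "highest" is taken to mean largest absolute value, the mask is applied once to a plain tensor (nothing keeps the weights at zero during later training), and prune_highest_magnitude is a hypothetical helper name, not part of tfmot.

```python
import tensorflow as tf

def prune_highest_magnitude(weights, fraction):
    """Hypothetical helper: zero out the `fraction` of entries with the largest |w|."""
    flat = tf.reshape(weights, [-1])
    k = int(round(fraction * int(flat.shape[0])))
    if k == 0:
        return tf.identity(weights)
    # top_k on the absolute values returns the indices of the k largest magnitudes.
    _, idx = tf.math.top_k(tf.abs(flat), k=k)
    # Build a 0/1 mask that is 0 at those indices and 1 everywhere else.
    mask = tf.tensor_scatter_nd_update(
        tf.ones_like(flat),
        tf.expand_dims(idx, axis=1),
        tf.zeros([k], dtype=flat.dtype),
    )
    return tf.reshape(flat * mask, tf.shape(weights))

# Toy 2x3 weight matrix; with fraction=0.5 the three largest-magnitude entries are zeroed.
w = tf.constant([[0.1, -5.0, 0.3],
                 [2.0, -0.2, 4.0]])
pruned = prune_highest_magnitude(w, fraction=0.5)
```

On a real layer I assume the result would be written back with something like layer.kernel.assign(...), but that would have to be redone after every update to keep the pruned weights at zero, which is part of what I'm unsure about.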