
Essentially, I want to apply pruning to my transfer learning model.

I used EfficientNetB0 to classify microorganisms.

import numpy as np
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# `training_set`, `optim` and `efficientnetb0_transfer_model` are defined earlier.

# Compute the end step so pruning finishes after `epochs` epochs.
batch_size = 32
epochs = 25

end_step = np.ceil(len(training_set) / batch_size).astype(np.int32) * epochs

# Define the model for pruning: target sparsity ramps from 40% to 90%.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.40,
        final_sparsity=0.90,
        begin_step=0,
        end_step=end_step,
    )
}

efficientnetb0_transfer_model_for_pruning = prune_low_magnitude(
    efficientnetb0_transfer_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
efficientnetb0_transfer_model_for_pruning.compile(
    optimizer=optim,
    loss='categorical_crossentropy',
    metrics=['accuracy'])

efficientnetb0_transfer_model_for_pruning.summary()
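To see what the `PolynomialDecay` schedule above asks for, here is a plain-Python sketch of its sparsity curve. It reimplements the formula as I understand it from the TFMOT schedule (cubic decay by default, `power=3`) instead of calling the library, and it assumes a hypothetical dataset of 3,200 images. One caution on `end_step`: if `training_set` is a Keras `DirectoryIterator` (from `flow_from_directory`) or a batched `tf.data.Dataset`, `len(training_set)` already counts batches, not images, so dividing it by `batch_size` again would end pruning far too early. The real schedule also only updates the masks every `frequency` steps (100 by default), which this sketch ignores.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Optimizer steps per epoch when the last batch may be partial."""
    return math.ceil(num_examples / batch_size)


def polynomial_decay_sparsity(step: int, initial_sparsity: float,
                              final_sparsity: float, begin_step: int,
                              end_step: int, power: float = 3.0) -> float:
    """Target sparsity at `step`, following the polynomial-decay formula."""
    # Fraction of the pruning window completed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    # Starts at initial_sparsity and approaches final_sparsity as progress -> 1.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


# Hypothetical dataset size: 3,200 training images.
num_examples = 3200
batch_size = 32
epochs = 25
end_step = steps_per_epoch(num_examples, batch_size) * epochs  # 100 * 25 = 2500

for step in (0, end_step // 2, end_step):
    s = polynomial_decay_sparsity(step, 0.40, 0.90, 0, end_step)
    print(f"step {step:>4}: target sparsity {s:.4f}")
```

For this example the targets are 0.4000, 0.8375 and 0.9000 at steps 0, 1250 and 2500: the cubic curve reaches most of the final sparsity early and then flattens toward `final_sparsity`.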

However, I'm getting the following error:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling'>

What could I be doing wrong?

Kaveh
djbacs

1 Answer


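You're not doing anything wrong; you're hitting a limitation of the pruning API.

The pruning API is not flexible enough yet: it currently expects every layer in the model to be prunable (logic here), so it rejects any layer it doesn't recognize, such as the `Rescaling` layer named in your error. Ideally it would skip layers like image rescaling. Can you file a GitHub issue? We'll work on a fix. Thanks!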

Yunlu Li