I am experimenting with the TensorFlow Model Optimization library to reduce the size of a SavedModel that runs in a production cluster. The goal is to cut operating costs while keeping as much performance as possible.

A few things I've read suggested pruning the model's weights. I tried it, and so far the results have been very mixed. Here is the code for the model I am trying to prune.

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input

n = 300000  # input vector dimension, it's not exactly 300k but it's close
code_dimension = 512
inputs = Input(shape=(n,))
outputs = Dense(code_dimension, activation='relu')(inputs)
outputs = Dense(code_dimension, activation='relu')(outputs)
outputs = Dense(n, activation='softmax')(outputs)

model = Model(inputs, outputs)
model.compile("adam", "cosine_similarity")
# training_data_generator and validation_data_generator are defined elsewhere (not shown)
model.fit(training_data_generator, epochs=10, validation_data=validation_data_generator)
model.save("base_model.pb")

# model pruning starts here
import tensorflow_model_optimization as tfmot

# hold 95% sparsity from step 0 onward (end_step=-1 means no end), updating every 100 steps
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
    target_sparsity=0.95, begin_step=0, end_step=-1, frequency=100
)

# UpdatePruningStep is required when fitting a model wrapped by prune_low_magnitude
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)
model_for_pruning.compile(optimizer="adam", loss="cosine_similarity")

model_for_pruning.fit(training_data_generator, validation_data=validation_data_generator, epochs=2, callbacks=callbacks)
# mean_model_sparsity is a helper defined elsewhere (not shown)
print(f"Mean model sparsity post-pruning: {mean_model_sparsity(model_for_pruning): .4f}")

# strip pruning not to carry around those extra parameters
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

model_for_export.save("pruned_model.pb")
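
The `mean_model_sparsity` helper called above is not shown. As a rough sketch of what such a helper might compute, the function below averages the fraction of exactly-zero entries over a list of weight arrays. The name `mean_sparsity`, the unweighted per-array average, and the NumPy-array input are all assumptions for illustration, not the helper actually used in the script.

```python
import numpy as np


def mean_sparsity(weight_arrays):
    """Unweighted mean, over the given arrays, of the fraction of entries equal to zero.

    Hypothetical stand-in for the unshown mean_model_sparsity helper.
    """
    fractions = [float(np.mean(np.asarray(w) == 0)) for w in weight_arrays]
    if not fractions:
        raise ValueError("no weight arrays given")
    return sum(fractions) / len(fractions)


# Toy check: one fully dense array and one array that is half zeros.
dense = np.ones((4, 4))
half_zero = np.array([[0.0, 1.0], [0.0, 2.0]])
print(mean_sparsity([dense, half_zero]))  # 0.25
```

For a stripped Keras model, the kernels could presumably be collected with something like `[layer.kernel.numpy() for layer in model_for_export.layers if hasattr(layer, "kernel")]`, which skips biases and the input layer.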

Here's the problem: when I set code_dimension to 32 or 64, the saved pruned_model.pb is about 2 to 3 times smaller than base_model.pb. However, when I use a code dimension of 256 or 512, the pruned model is actually bigger than the base model.
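
To make the size comparison reproducible, here is a minimal sketch of how the two saved models could be measured. It assumes each `save` call produced either a single file or a SavedModel directory, and it reports both the raw size on disk and the size after DEFLATE compression, since zeroed weights are still written out as ordinary floats and only shrink once compressed. The function names `path_size_bytes` and `zipped_size_bytes` are made up for this example.

```python
import os
import tempfile
import zipfile


def path_size_bytes(path):
    """Size of a file, or total size of all files under a directory."""
    if os.path.isfile(path):
        return os.path.getsize(path)
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total


def zipped_size_bytes(path):
    """Size of `path` after compressing it into a temporary zip archive."""
    fd, archive = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            if os.path.isfile(path):
                zf.write(path, arcname=os.path.basename(path))
            else:
                for root, _dirs, files in os.walk(path):
                    for name in files:
                        full = os.path.join(root, name)
                        zf.write(full, arcname=os.path.relpath(full, path))
        return os.path.getsize(archive)
    finally:
        os.remove(archive)


if __name__ == "__main__":
    # Paths taken from the script above; skipped if they do not exist.
    for saved in ("base_model.pb", "pruned_model.pb"):
        if os.path.exists(saved):
            print(saved, path_size_bytes(saved), zipped_size_bytes(saved))
```

Comparing the compressed numbers as well as the raw ones helps separate "pruning did not zero the weights" from "the zeros are there but stored densely".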

I run all of this from a script, and I fully reset my environment before each run.

Has anyone who has used the TensorFlow Model Optimization library run into this?

djvaroli