I am running some code which repeatedly (every training batch) calls layer.get_weights() and layer.set_weights(). The callback containing these calls takes 0.01s, compared to the 0.004s taken to run the batch itself, and as such more than doubles the training time required. I assumed that this operation simply moves tensors around (ideally staying on the GPU) and thus should not take time comparable to the large matrix multiplications occurring during the batch iteration. Does anyone have any idea why this happens, or any approaches to reduce the time taken by set_weights() and get_weights()?
Code is below:
### PRUNE WEIGHTS CALLBACK ###
from tensorflow.keras.callbacks import Callback

class pruneModelCallback(Callback):

    def __init__(self, init_weight_dict=None, mask_dict=None):
        self.n_batches = 0
        self.init_weight_dict = init_weight_dict
        self.mask_dict = mask_dict

    def on_train_batch_begin(self, batch, logs=None):
        # save weights at initialization
        if self.n_batches == 0:
            if self.init_weight_dict is not None:
                for layer_i in range(len(self.model.layers)):
                    w = self.init_weight_dict['w_'+str(layer_i+1)]
                    b = self.init_weight_dict['b_'+str(layer_i+1)]
                    self.model.layers[layer_i].set_weights([w, b])
            else:
                self.init_weight_dict = {}
                for layer_i in range(len(self.model.layers)):
                    w = self.model.layers[layer_i].get_weights()[0]
                    b = self.model.layers[layer_i].get_weights()[1]
                    self.init_weight_dict['w_'+str(layer_i+1)] = w
                    self.init_weight_dict['b_'+str(layer_i+1)] = b
        self.n_batches = self.n_batches + 1

    # This is the problematic function; it runs on every training batch
    def on_train_batch_end(self, batch, logs=None):
        # zero out pruned weights
        if self.mask_dict is not None:
            for layer_i in range(len(self.model.layers)):
                # removing these two calls slightly improves runtime
                w = self.model.layers[layer_i].get_weights()[0]
                b = self.model.layers[layer_i].get_weights()[1]
                w_mask = self.mask_dict['w_'+str(layer_i+1)]
                # this multiplication takes comparatively no time and removing it
                # does not influence the time taken
                w_pruned = w * w_mask
                # removing this call significantly speeds up the runtime
                self.model.layers[layer_i].set_weights([w_pruned, b])
The warning generated:
1/629 [..............................] - ETA: 0s - loss: 2.3211 - accuracy: 0.0781
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0040s vs `on_train_batch_end` time: 0.0100s). Check your callbacks.
A brief description of what's going on in the setup:
The setup is an experimental reproduction of the iterative magnitude pruning described in Frankle and Carbin's Lottery Ticket Hypothesis paper. I am implementing pruning in the simplest way I could think of: a binary pruning mask is element-wise multiplied with the weights after every training batch, undoing the weight update for pruned weights and keeping them at zero. Hence the need for get_weights() and set_weights() on every training iteration.
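For context, the masks in mask_dict are just fixed NumPy arrays with the same shapes as the corresponding weight matrices, built once per pruning round (outside the callback) by zeroing the smallest-magnitude weights. A simplified sketch of what I mean (the helper name and prune_fraction argument are just for illustration):

    import numpy as np

    def magnitude_mask(weights, prune_fraction):
        # 1.0 where a weight is kept, 0.0 where it is pruned
        flat = np.abs(weights).ravel()
        n_prune = int(prune_fraction * flat.size)
        if n_prune == 0:
            return np.ones_like(weights)
        threshold = np.partition(flat, n_prune - 1)[n_prune - 1]
        return (np.abs(weights) > threshold).astype(weights.dtype)

    # e.g. mask_dict['w_1'] = magnitude_mask(w_1_after_training, 0.2)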
The model in question is a standard DNN with fully connected layers throughout; there are no special layers such as batch norm and no dropout or other regularization:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    Dense(300, input_dim=input_dim[0], activation='relu'),
    Dense(100, activation='relu'),
    Dense(50, activation='relu'),
    Dense(output_dim[0], activation='softmax')
])

model.compile(
    optimizer=keras.optimizers.Adam(lr=1.2e-4),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy']
)
Iterative magnitude pruning by definition takes quite a long time as the model must be trained many times over, so any speedup would be awesome.
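One idea I have considered but not yet benchmarked is to skip the get_weights()/set_weights() round trip through NumPy entirely and instead assign to each layer's kernel variable in place, so the mask multiplication stays inside TensorFlow. A rough, untested sketch of what I mean (the class name is mine; it assumes every layer is a Dense layer with a kernel attribute, and converts the masks to tensors up front):

    import tensorflow as tf

    class PruneKernelCallback(tf.keras.callbacks.Callback):
        def __init__(self, mask_dict):
            super().__init__()
            self.mask_dict = mask_dict  # same structure as above
            self.masks = None

        def on_train_begin(self, logs=None):
            # convert the NumPy masks to constant tensors once, not every batch
            self.masks = [
                tf.constant(self.mask_dict['w_' + str(i + 1)], dtype=layer.kernel.dtype)
                for i, layer in enumerate(self.model.layers)
            ]

        def on_train_batch_end(self, batch, logs=None):
            for layer, mask in zip(self.model.layers, self.masks):
                # assign directly to the variable, no NumPy round trip
                layer.kernel.assign(layer.kernel * mask)

I don't know whether this is actually faster, or whether modifying the variables from a callback like this is safe, which is part of what I am asking.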