Given a trained TensorFlow model (pruned with an approach that yields binary channel masks for Conv2D layers), how can I remove whole channels from that model?
I created a custom layer that wraps an existing Conv2D layer and, through some mechanism during training, maintains a binary channel mask that zeroes out some of the channels so that they have no effect on the output:
import tensorflow as tf


class MyConv2DWrapper(tf.keras.layers.Layer):
    def __init__(self, conv2d: tf.keras.layers.Conv2D, **kwargs):
        super(MyConv2DWrapper, self).__init__(**kwargs)
        self.conv2d = conv2d
        self.channel_mask = None

    def build(self, input_shape):
        super(MyConv2DWrapper, self).build(input_shape)
        self.channel_mask = self.add_weight(
            name="channel_mask",
            shape=(self.conv2d.filters,),
            initializer=tf.keras.initializers.Constant(1),
            trainable=False,
        )

    def call(self, inputs, **kwargs):
        return self.conv2d(inputs, **kwargs) * self.channel_mask
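For context, I wrap plain Conv2D layers with this when building the model, roughly like the following (a simplified sketch; my actual architecture and the mechanism that sets the mask during training are omitted):

inputs = tf.keras.Input(shape=(32, 32, 3))
x = MyConv2DWrapper(tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu"))(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = MyConv2DWrapper(tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu"))(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs, outputs)

During training some entries of each channel_mask end up as 0, so the corresponding filters contribute nothing to the output.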
I've now trained the model, so it contains unused channels that I'd like to remove. Essentially I want the same model, just without the masked channels.
To that end, I've looked for existing tools, but everything I found was either dead / too old (like this tutorial or this library) or not quite clear to me (like this documentation page).
Hence, I've started to play around a bit myself and tried writing a function I can pass to clone_model as clone_function. However, this turns out to be non-trivial: I need to update not only the layers I wrapped, but also their follow-up layers, since the input dimensions of the latter will have changed.
What I have looks something like this:
import numpy as np


def remove_unused_channels(model: tf.keras.Model) -> tf.keras.Model:
    channel_mask = None
    weights: dict[str, list[np.ndarray]] = {}

    def drop_unused_channels(layer: tf.keras.layers.Layer) -> tf.keras.layers.Layer:
        nonlocal channel_mask
        nonlocal weights
        if channel_mask is not None and not isinstance(layer, MyConv2DWrapper):
            config = layer.get_config()
            new_layer = type(layer).from_config(config)
            layer_weights = [w[..., np.where(channel_mask)[0]] for w in layer.get_weights()]
            new_input_shape = (*layer.input_shape[:-1], int(sum(channel_mask)))
            new_layer.build(new_input_shape)
            weights[layer.name] = layer_weights
            channel_mask = None
            return new_layer
        if isinstance(layer, MyConv2DWrapper):
            channel_mask = layer.channel_mask
            layer_weights = layer.conv2d.get_weights()
            # change dimensions to those of used weights
            layer_weights = [w[..., np.where(layer.channel_mask)[0]] for w in layer_weights]
            config = layer.conv2d.get_config()
            n_filters = int(sum(channel_mask))
            config["filters"] = n_filters
            new_layer = tf.keras.layers.Conv2D.from_config(config)
            new_layer.build(layer.conv2d.input_shape)
            weights[layer.conv2d.name] = layer_weights
            layer = new_layer
        else:
            weights[layer.name] = layer.get_weights()
        return layer

    model = tf.keras.models.clone_model(model, clone_function=drop_unused_channels)
    for _layer in model.layers:
        _layer.set_weights(weights[_layer.name])
    return model
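I then call it on the trained model like this (trained_model stands in for my actual trained model):

pruned_model = remove_unused_channels(trained_model)
pruned_model.summary()  # the previously wrapped Conv2D layers now report the reduced filter counts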
That seems to work fine for the pruned layers and their immediate followers, at least for simple layers. However, if a follow-up layer's output shape is tied directly to its input shape (e.g. BatchNormalization), the reduced channel count also needs to be forwarded to the layer after that. It's also not obvious to me what exactly should happen in a non-sequential architecture, for example one with skip connections.
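To illustrate what I mean for the BatchNormalization case: I imagine the mask has to stay "alive" across such shape-preserving layers and only gets consumed by the next layer that actually mixes channels, with the weight slicing depending on the layer type. Roughly like this (just a sketch of the idea, not part of my working code; slice_follow_up_weights is a hypothetical helper):

def slice_follow_up_weights(layer: tf.keras.layers.Layer, channel_mask) -> list[np.ndarray]:
    keep = np.where(channel_mask)[0]
    if isinstance(layer, tf.keras.layers.Conv2D):
        # Conv2D kernels have shape (h, w, in_channels, out_channels), so the
        # previous layer's mask applies to axis -2; the bias belongs to this
        # layer's own filters and stays untouched.
        kernel, *rest = layer.get_weights()
        return [kernel[:, :, keep, :], *rest]
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        # gamma, beta, moving_mean and moving_variance all have shape (channels,),
        # so they can be sliced directly -- but the mask then still has to be
        # applied to whatever follows the BatchNormalization.
        return [w[keep] for w in layer.get_weights()]
    return layer.get_weights()

So instead of resetting channel_mask to None after the first follow-up layer, I would only reset it once a channel-mixing layer (Conv2D, Dense, ...) has consumed it. But with skip connections a single nonlocal channel_mask obviously breaks down, since different branches can carry different masks, which is where I'm stuck.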
Has anyone had a similar use case and / or any idea how to approach this further?