
Given a trained TensorFlow model (pruned with an approach that yields per-channel masks for Conv2D layers), how can I remove the masked channels from the model entirely?

I created a custom layer that wraps an existing Conv2D layer and adds a binary channel mask; during training, some mechanism sets entries of this mask to zero, which effectively disables the corresponding channels so that they have no effect on the output:

import tensorflow as tf

class MyConv2DWrapper(tf.keras.layers.Layer):
    """Wraps a Conv2D layer and multiplies its output with a per-channel binary mask."""

    def __init__(self, conv2d: tf.keras.layers.Conv2D, **kwargs):
        super(MyConv2DWrapper, self).__init__(**kwargs)
        self.conv2d = conv2d
        self.channel_mask = None

    def build(self, input_shape):
        super(MyConv2DWrapper, self).build(input_shape)
        # One mask entry per output channel of the wrapped Conv2D; all start out active.
        self.channel_mask = self.add_weight(
            name="channel_mask",
            shape=(self.conv2d.filters,),
            initializer=tf.keras.initializers.Constant(1),
            trainable=False,
        )

    def call(self, inputs, **kwargs):
        # Masked channels are multiplied by zero and thus have no effect on the output.
        return self.conv2d(inputs, **kwargs) * self.channel_mask
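
For context, a minimal usage sketch (the toy model, layer name and mask values below are made up for illustration): every Conv2D to be pruned is wrapped, and after training some mask entries end up as zero:

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    MyConv2DWrapper(tf.keras.layers.Conv2D(16, 3, padding="same"), name="wrapped_conv"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

# Suppose the pruning mechanism has zeroed half the mask during training,
# so only the first 8 of the 16 channels still contribute to the output.
model.get_layer("wrapped_conv").channel_mask.assign([1.0] * 8 + [0.0] * 8)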

I have now trained the model and ended up with one that contains unused channels, which I'd like to remove. Essentially, I want the same model, just without the masked channels.

To that end, I've looked for existing tools, but everything I found was either dead or too old (like this tutorial or this library), or not quite clear to me (like this documentation page).

Hence, I've started playing around myself and tried to write a function I could pass to clone_model. This turns out to be non-trivial, because I need to update not only the layers I wrapped earlier, but also their follow-up layers, since the input dimensions of the latter will have changed.

What I have looks something like this:

import numpy as np


def remove_unused_channels(model: tf.keras.Model) -> tf.keras.Model:
    # Mask of the most recently pruned layer; None while no mask is pending.
    channel_mask = None
    # New weights per layer name, applied after clone_model has run.
    weights: dict[str, list[tf.Variable]] = {}

    def drop_unused_channels(layer: tf.keras.layers.Layer) -> tf.keras.layers.Layer:
        nonlocal channel_mask
        nonlocal weights

        if channel_mask is not None and not isinstance(layer, MyConv2DWrapper):
            # Follow-up layer of a pruned conv: slice its weights down to the
            # surviving channels and rebuild it with the reduced input shape.
            config = layer.get_config()
            new_layer = type(layer).from_config(config)
            layer_weights = [w[..., np.where(channel_mask)[0]] for w in layer.get_weights()]
            new_input_shape = (*layer.input_shape[:-1], int(sum(channel_mask)))
            new_layer.build(new_input_shape)
            weights[layer.name] = layer_weights
            channel_mask = None
            return new_layer

        if isinstance(layer, MyConv2DWrapper):
            # Pruned layer: replace the wrapper with a plain Conv2D that keeps only
            # the surviving filters, and remember the mask for the following layer.
            channel_mask = layer.channel_mask
            layer_weights = layer.conv2d.get_weights()
            # change dimensions to those of used weights
            layer_weights = [w[..., np.where(layer.channel_mask)[0]] for w in layer_weights]
            config = layer.conv2d.get_config()
            n_filters = int(sum(channel_mask))
            config["filters"] = n_filters
            new_layer = tf.keras.layers.Conv2D.from_config(config)
            new_layer.build(layer.conv2d.input_shape)
            weights[layer.conv2d.name] = layer_weights
            layer = new_layer
        else:
            weights[layer.name] = layer.get_weights()
        return layer

    model = tf.keras.models.clone_model(model, clone_function=drop_unused_channels)
    for _layer in model.layers:
        _layer.set_weights(weights[_layer.name])
    return model

That seems to work fine for the pruned layers and their immediate followers (at least for simple layers). However, if a follow-up layer's output channel count is tied to its input channel count (e.g. BatchNormalization), the reduced dimension needs to be propagated further, to the layer after that one as well. It's also not obvious to me what exactly should happen for a non-sequential architecture, for example one with skip connections.
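
For the BatchNormalization case specifically, here is a minimal sketch of how such a layer could be shrunk while keeping the mask alive for the next layer. The helper name shrink_batchnorm is mine, and it assumes that every weight of the layer is a per-channel vector; inside the clone_function above one would stash the sliced weights in the weights dict instead of calling set_weights directly:

import numpy as np
import tensorflow as tf

def shrink_batchnorm(layer: tf.keras.layers.BatchNormalization,
                     channel_mask) -> tf.keras.layers.BatchNormalization:
    # gamma, beta, moving_mean and moving_variance all have shape (channels,),
    # so slicing their only axis with the surviving indices is enough.
    keep = np.where(channel_mask)[0]
    new_layer = tf.keras.layers.BatchNormalization.from_config(layer.get_config())
    new_layer.build((*layer.input_shape[:-1], len(keep)))
    new_layer.set_weights([w[keep] for w in layer.get_weights()])
    return new_layer

The crucial difference to the Conv2D case is that channel_mask would not be reset to None afterwards: the BatchNormalization output still has the reduced channel count, so the next layer with a kernel needs its input-channel axis sliced as well. In other words, the mask would stay active across every layer whose output channels equal its input channels, and only be consumed once a layer that actually mixes channels has been rebuilt.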

Has anyone had a similar use case and/or any ideas on how to approach this further?
