
Is there a way to load a pretrained model in TensorFlow and remove the top layers of the network? I am looking at TensorFlow release r1.10.

The only documentation I could find is for tf.keras.Sequential.pop: https://www.tensorflow.org/versions/r1.10/api_docs/python/tf/keras/Sequential#pop

I want to manually prune a pretrained network by removing a bunch of the top convolution layers and adding a custom fully convolutional layer.
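For a model that actually is a `tf.keras.Sequential`, the `pop` approach looks roughly like this. It is a minimal sketch written against the TF 2.x `tf.keras` API (not r1.10). A small made-up model stands in for the pretrained network, and the layer names, shapes, and the 1x1 `custom_head` convolution are all hypothetical. It does not apply directly to a frozen `.pb` graph, which is not a Keras model.

```python
import tensorflow as tf

# Hypothetical stand-in for a pretrained Sequential model (TF 2.x tf.keras API).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu", name="conv1"),
    tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu", name="conv2"),
    tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu", name="conv3"),
])

# Sequential.pop() removes the last layer each time it is called.
model.pop()  # drops conv3
model.pop()  # drops conv2

# Append a hypothetical 1x1 convolution as the new fully convolutional head.
model.add(tf.keras.layers.Conv2D(4, 1, name="custom_head"))

# conv1 keeps its weights; only the new head is freshly initialized.
out = model(tf.zeros((1, 32, 32, 3)))
print([layer.name for layer in model.layers], tuple(out.shape))
```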

EDIT:

The model is ssd_mobilenet_v1_coco, downloaded from the TensorFlow Model Zoo. I have access to both the frozen_inference_graph.pb model file and the checkpoint file.
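With only `frozen_inference_graph.pb` to work from, one graph-level option (a different technique from the answer below) is to keep just the part of the `GraphDef` needed to compute a chosen intermediate node, using `tf.compat.v1.graph_util.extract_sub_graph`. This is a minimal sketch assuming TF 2.x with the `tf.compat.v1` API. A tiny two-convolution graph stands in for the real file, and all node names are hypothetical. For the actual model, load the `.pb` with `load_graph_def` (a helper defined here, not a TensorFlow API) and choose the cut node from its node names. Pruning only removes nodes: the kept weights stay frozen constants, and a new trainable head still has to be built on top of the imported subgraph.

```python
import numpy as np
import tensorflow as tf


def build_toy_frozen_graph():
    """Hypothetical stand-in for frozen_inference_graph.pb: input -> conv1 -> conv2."""
    graph = tf.Graph()
    with graph.as_default():
        images = tf.compat.v1.placeholder(tf.float32, [None, 8, 8, 3], name="image_tensor")
        # Frozen graphs store weights as constants, so constants are used here too.
        w1 = tf.constant(np.ones((3, 3, 3, 4), np.float32), name="w1")
        w2 = tf.constant(np.ones((3, 3, 4, 4), np.float32), name="w2")
        conv1 = tf.nn.conv2d(images, w1, strides=1, padding="SAME", name="conv1")
        conv2 = tf.nn.conv2d(conv1, w2, strides=1, padding="SAME", name="conv2")
    return graph.as_graph_def(), conv1.op.name, conv2.op.name


def load_graph_def(path):
    """Parse a serialized GraphDef such as frozen_inference_graph.pb."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


graph_def, cut_node, top_node = build_toy_frozen_graph()
# Keep only the nodes needed to compute cut_node; everything above it is dropped.
pruned = tf.compat.v1.graph_util.extract_sub_graph(graph_def, [cut_node])
kept = {node.name for node in pruned.node}
print(sorted(kept))
```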

I do not have access to the Python code that was used to construct the model.
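Without the model-building code, the checkpoint's variable names are the most direct record of which layers the network has, and therefore of where a cut could go. They can be listed with `tf.train.list_variables`. The sketch below assumes TF 2.x with the `tf.compat.v1` API and first writes a tiny stand-in checkpoint so that it runs on its own. `write_toy_checkpoint` is a hypothetical helper, and the `FeatureExtractor/MobilenetV1/...` and `BoxPredictor_0/...` names only illustrate the naming style; verify them against the real ssd_mobilenet_v1_coco checkpoint by passing its path prefix instead.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical variable names and (small) shapes, only to mimic the naming style.
TOY_VARIABLES = {
    "FeatureExtractor/MobilenetV1/Conv2d_0/weights": [3, 3, 3, 8],
    "FeatureExtractor/MobilenetV1/Conv2d_13_pointwise/weights": [1, 1, 8, 8],
    "BoxPredictor_0/BoxEncodingPredictor/weights": [1, 1, 8, 12],
}


def write_toy_checkpoint(directory):
    """Write a TF1-style (name-based) checkpoint and return its path prefix."""
    graph = tf.Graph()
    with graph.as_default():
        for name, shape in TOY_VARIABLES.items():
            tf.compat.v1.get_variable(name, shape=shape)
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            return saver.save(sess, os.path.join(directory, "model.ckpt"))


with tempfile.TemporaryDirectory() as tmp:
    ckpt_prefix = write_toy_checkpoint(tmp)
    # (name, shape) pairs read straight from the checkpoint, no model code needed.
    variables = tf.train.list_variables(ckpt_prefix)

# Keep only the backbone variables, i.e. the layers a cut would be chosen from.
backbone = [name for name, _ in variables
            if name.startswith("FeatureExtractor/MobilenetV1/")]
for name in backbone:
    print(name)
```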

Thanks.

Anil Maddala
  • You probably want to add some more detail to your question here. What do you mean by "pretrained model"? Is it a serialized binary (e.g., in the `SavedModel` format), or is it a checkpoint of model parameters? Do you have access to the Python code that constructs the model (and thus you can modify that)? Once you describe your problem more concretely, people might be able to chime in. – ash Aug 28 '18 at 00:18
  • @ash I added the details you asked for. Let me know if you need further details. – Anil Maddala Aug 28 '18 at 00:37
  • Same question answered here: https://stackoverflow.com/questions/50646426/how-to-use-pre-trained-models-without-classes-in-tensorflow – Tbertin Aug 28 '18 at 14:45
  • @Tbertin the approach you mentioned only removes the classification layer. I want to remove other top conv2d layers as well. – Anil Maddala Aug 28 '18 at 20:08

1 Answer


From inspecting the code, `SSDMobileNetV1FeatureExtractor.extract_features` delegates to `mobilenet_v1` from `research/slim/nets`:

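# Excerpt from SSDMobileNetV1FeatureExtractor.extract_features
# (object_detection/models/ssd_mobilenet_v1_feature_extractor.py), TF 1.x.
import tensorflow as tf

from nets import mobilenet_v1  # research/slim will have to be on your PYTHONPATH
from object_detection.utils import context_manager
from object_detection.utils import ops

slim = tf.contrib.slim

# Inside extract_features(self, preprocessed_inputs):
with tf.variable_scope('MobilenetV1',
                       reuse=self._reuse_weights) as scope:
  with slim.arg_scope(
      mobilenet_v1.mobilenet_v1_arg_scope(
          is_training=None, regularize_depthwise=True)):
    # Optionally override the base network's conv hyperparameters.
    with (slim.arg_scope(self._conv_hyperparams_fn())
          if self._override_base_feature_extractor_hyperparams
          else context_manager.IdentityContextManager()):
      # Build MobileNet v1 only up to final_endpoint; later layers are never created.
      _, image_features = mobilenet_v1.mobilenet_v1_base(
          ops.pad_to_multiple(preprocessed_inputs, self._pad_to_multiple),
          final_endpoint='Conv2d_13_pointwise',
          min_depth=self._min_depth,
          depth_multiplier=self._depth_multiplier,
          use_explicit_padding=self._use_explicit_padding,
          scope=scope)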

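The `mobilenet_v1_base` function takes a `final_endpoint` argument (set to `'Conv2d_13_pointwise'` above). Rather than pruning the constructed graph, construct the graph only up to the endpoint you want, e.g. `final_endpoint='Conv2d_11_pointwise'`; the layers after it are then never created.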
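The same "build only up to an endpoint" idea can be sketched with the Keras functional API. This is an analog, not the slim code above: it assumes TF 2.x and `tf.keras.applications.MobileNet`, whose pointwise activations are named `conv_pw_<N>_relu`. `weights=None` is used so the sketch runs offline, and the cut at block 11 and the 1x1 `custom_head` layer are hypothetical choices. Keras layer names differ from the Model Zoo checkpoint's variable names, so loading the ssd_mobilenet_v1_coco weights into these layers is a separate step not shown here.

```python
import tensorflow as tf

# MobileNet v1 backbone without the classifier (weights=None: random init, no download).
backbone = tf.keras.applications.MobileNet(
    input_shape=(128, 128, 3), include_top=False, weights=None)

# Analog of final_endpoint: take the output of an earlier pointwise block.
endpoint = backbone.get_layer("conv_pw_11_relu").output

# Hypothetical custom fully convolutional head on the truncated backbone.
head = tf.keras.layers.Conv2D(8, 1, name="custom_head")(endpoint)
model = tf.keras.Model(inputs=backbone.input, outputs=head)

# Blocks 12 and 13 are not part of the new model; stride so far is 16.
out = model(tf.zeros((1, 128, 128, 3)))
print(tuple(out.shape))
```

Building a new `Model` from an intermediate output, rather than deleting layers, mirrors the answer's point: the unwanted top layers are simply never part of the graph that is used.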

DomJack