
How to extract the features from a specific layer from a pre-trained PyTorch model (such as ResNet or VGG), without doing a forward pass again?

bryant1410

2 Answers


New answer

Edit: torchvision v0.11.0 added a feature-extraction utility, `create_feature_extractor`, which returns the outputs of intermediate nodes directly.

For example, to extract features from the node layer4.2.relu_2, you can do:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

x = torch.rand(1, 3, 224, 224)

model = resnet50()

return_nodes = {
    "layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)

Old answer

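You can register a forward hook on the specific layer you want. The hook is called during the forward pass you already run, so no second pass is needed. Something like:

def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

# The hook fires every time the layer's forward() runs.
model.some_specific_layer.register_forward_hook(some_specific_layer_hook)

model(some_input)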

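For example, to obtain the res5c output in ResNet, the hook can store the value in a variable from the enclosing scope. Use `nonlocal` when this code lives inside a function, or `global` when it runs at module level (`nonlocal` is a syntax error there, and it does not exist in Python 2):

res5c_output = None

def res5c_hook(module, input_, output):
    # Valid inside an enclosing function; at module level, write `global res5c_output` instead.
    nonlocal res5c_output
    res5c_output = output

handle = resnet.layer4.register_forward_hook(res5c_hook)

resnet(some_input)
handle.remove()  # unregister the hook once the output has been captured

# Then, use `res5c_output`.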
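
A common variation of the hook approach stores outputs in a dict, which avoids `nonlocal`/`global` and scales to several layers at once. This is only a sketch: the toy `nn.Sequential` model, the `make_hook` helper, and the `activations` dict are hypothetical names introduced for illustration, not part of the answer above.

```python
import torch
from torch import nn

# Hypothetical toy model, used only to keep the sketch small and fast.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)
model.eval()

activations = {}

def make_hook(name):
    """Build a forward hook that stores the layer output under `name`."""
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

# register_forward_hook returns a handle that can later remove the hook.
handles = [
    model[1].register_forward_hook(make_hook("relu")),
    model[2].register_forward_hook(make_hook("pool")),
]

with torch.no_grad():
    logits = model(torch.rand(2, 3, 16, 16))  # a single forward pass

for handle in handles:
    handle.remove()

print({name: tuple(t.shape) for name, t in activations.items()})
```

Both intermediate outputs are captured during the same forward pass that produces `logits`.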
bryant1410
How does the value of output get returned here?? hook functions are not allowed to have a return value, so I don't see how fc1000_output in your code will get the value of output assigned to it. How does the value res5c_output get passed to fc1000_output? – Kai Jan 27 '19 at 17:16
I added a missing `nonlocal` declaration. It's not that the value of `res5c_output` gets passed to `fc1000_output`, it's that the former variable is bound to the outer context. – bryant1410 Jan 28 '19 at 19:10
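
To illustrate the point about binding to the outer context, here is a plain-Python sketch with no PyTorch involved. The function `run_forward_with_hook` is a hypothetical stand-in for a forward pass that calls one hook; it only shows how a callback without a return value can still hand a value back through `nonlocal`:

```python
def run_forward_with_hook(value):
    """Hypothetical stand-in for a forward pass that calls one hook."""
    captured = None

    def hook(output):
        nonlocal captured  # rebind the variable of the enclosing function
        captured = output  # nothing is returned; the assignment is the side effect

    hook(value)  # a framework would make this call during the forward pass
    return captured


result = run_forward_with_hook("layer output")
print(result)
```

Without the `nonlocal` line, the assignment would create a new local variable inside `hook`, and `captured` in the outer function would stay `None`.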

The accepted answer is very helpful! I'm posting a complete example here (using a registered hook as described by @bryant1410) for the lazy ones looking for a working solution:

import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

def get_feat_vector(path_img, model):
    '''
    Input:
        path_img: string, /path/to/image
        model: a pretrained torch model with an `avgpool` layer (e.g. a torchvision ResNet)
    Output:
        my_output: torch.Tensor, output of the avgpool layer
    '''
    # Convert to RGB so that grayscale or RGBA images also yield 3 channels.
    input_image = Image.open(path_img).convert('RGB')

    # Standard ImageNet preprocessing.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # add a batch dimension

    my_output = None

    def my_hook(module_, input_, output_):
        nonlocal my_output
        my_output = output_

    a_hook = model.avgpool.register_forward_hook(my_hook)
    try:
        with torch.no_grad():
            model(input_batch)
    finally:
        a_hook.remove()  # always unregister the hook, even if the forward pass fails
    return my_output

That is the feature-extraction function. Call it with the snippet below to obtain features from the avgpool layer of resnet18:

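# In torchvision >= 0.13, `pretrained=True` is deprecated in favour of
# `weights=models.ResNet18_Weights.DEFAULT`.
model = models.resnet18(pretrained=True)
model.eval()  # inference mode: fixes batch-norm statistics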
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)
Zahra