How to extract the features from a specific layer from a pre-trained PyTorch model (such as ResNet or VGG), without doing a forward pass again?
New answer
Edit: torchvision v0.11.0 added a feature-extraction utility, `torchvision.models.feature_extraction`. For example, to extract features from the node `layer4.2.relu_2`, you can do:
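import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor
x = torch.rand(1, 3, 224, 224)
model = resnet50()  # random weights here; load pretrained weights in practice
# Keys are traced node names, values are the keys of the returned dict
return_nodes = {
    "layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)  # dict: {"layer4": tensor of shape (1, 2048, 7, 7)}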
Old answer
You can register a forward hook on the specific layer you want. Something like:
def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'
model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
model(some_input)
For example, to obtain the res5c output in ResNet (the output of `layer4` in torchvision's ResNet-50), you may want to use a `nonlocal` variable (or `global` in Python 2):
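def get_res5c(resnet, some_input):
    # `nonlocal` only works inside an enclosing function, hence this wrapper
    res5c_output = None
    def res5c_hook(module, input_, output):
        nonlocal res5c_output  # rebind the variable of get_res5c
        res5c_output = output
    handle = resnet.layer4.register_forward_hook(res5c_hook)
    resnet(some_input)
    handle.remove()
    # Then, use `res5c_output`.
    return res5c_output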

bryant1410
How does the value of output get returned here? Hook functions are not allowed to have a return value, so I don't see how fc1000_output in your code will get the value of output assigned to it. How does the value res5c_output get passed to fc1000_output? – Kai Jan 27 '19 at 17:16
I added a missing `nonlocal` declaration. It's not that the value of `res5c_output` gets passed to `fc1000_output`, it's that the former variable is bound to the outer context. – bryant1410 Jan 28 '19 at 19:10
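To illustrate the binding described in the comment above, here is a plain-Python sketch with no PyTorch involved (the function names are hypothetical). The inner function returns nothing, yet the enclosing function sees the new value, because `nonlocal` makes the assignment rebind the enclosing variable; without it, the assignment just creates a new local.

```python
def run_with_nonlocal():
    captured = None

    def callback(value):
        nonlocal captured  # assign to run_with_nonlocal's `captured`
        captured = value   # no return value needed

    callback("activation")  # stands in for the forward pass calling the hook
    return captured


def run_without_nonlocal():
    captured = None

    def callback(value):
        captured = value  # creates a new local; the outer variable is untouched

    callback("activation")
    return captured


print(run_with_nonlocal())     # activation
print(run_without_nonlocal())  # None
```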
The accepted answer is very helpful! I'm posting a complete example here (using a registered hook as described by @bryant1410) for the lazy ones looking for a working solution:
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
def get_feat_vector(path_img, model):
    '''
    Input:
        path_img: string, /path/to/image
        model: a pretrained torch model
    Output:
        my_output: torch.tensor, output of avgpool layer
    '''
    input_image = Image.open(path_img).convert("RGB")  # also handles grayscale/RGBA files
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
    with torch.no_grad():
        my_output = None
        def my_hook(module_, input_, output_):
            nonlocal my_output
            my_output = output_
        a_hook = model.avgpool.register_forward_hook(my_hook)
        model(input_batch)
        a_hook.remove()
        return my_output
There you have your feature-extraction function. Call it as in the snippet below to obtain features from the resnet18.avgpool layer:
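# torchvision >= 0.13; on older versions use models.resnet18(pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)  # shape (1, 512, 1, 1)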

Zahra