
I need to avoid downloading the model from the web (due to restrictions on the machine it is installed on).

This works, but it downloads the model from the Internet

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)

I have placed the .pth file and the hubconf.py file in the /tmp/ folder and changed my code to

model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')

but to my surprise, it still downloads the model from the Internet. What am I doing wrong? How can I load the model locally?

Just to give you a bit more detail: I'm doing all this in a Docker container that has a read-only volume at runtime, which is why downloading new files fails.
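A related detail (this is an assumption about torch's cache layout, so worth double-checking on your version): with pretrained=True, torch.hub looks for cached weights under $TORCH_HOME/hub/checkpoints/, not next to hubconf.py, so pre-seeding that cache is one way to avoid network access entirely:

```python
import os

# Sketch: point torch's weight cache at a pre-populated directory.
# Assumed layout (the filename is the one torchvision uses for this model):
#   /tmp/torch_cache/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth
# This must be set before the first torch.hub call in the process.
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
```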

Peter Mortensen
coding-dude.com
  • It seems the option to load locally was not present in some earlier versions of PyTorch. Which version are you using? – GoodDeeds Apr 28 '21 at 16:17
  • 1
    Collecting torch==1.8.1 Downloading torch-1.8.1-cp38-cp38-manylinux1_x86_64.whl (804.1 MB) Collecting torchsummary==1.5.1 Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB) Collecting torchvision==0.9.1 Downloading torchvision-0.9.1-cp38-cp38-manylinux1_x86_64.whl (17.4 MB) – coding-dude.com Apr 28 '21 at 17:40
  • 2
    The code near `pretrained=True,s ource` does not seem to be syntactically correct. Is it in the original? – Peter Mortensen Feb 27 '22 at 00:45

3 Answers


There are two approaches you can take to get a shippable model on a machine without an Internet connection.

  1. Load DeepLab with pretrained weights on a machine with Internet access, use the JIT compiler to export it as a graph, and copy the resulting file to the target machine. The script is easy to follow:

     # To export
     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
     # H and W are the input height and width; DeepLabV3 needs at least 224x224.
     H, W = 224, 224
     # strict=False is needed because the model returns an OrderedDict
     # ('out', 'aux'); the traced graph will then output only 'out'.
     traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W), strict=False)
     traced_graph.save('DeepLab.pth')

     # To load
     model = torch.jit.load('DeepLab.pth').eval().to(device)


    In this case, the weights and the network structure are saved as a computational graph, so you won't need any extra files.

  2. Take a look at torchvision's GitHub repository.

    There's a download URL for DeepLabV3 with Resnet101 backbone weights.

    You can download those weights once, and then use DeepLab from torchvision with the pretrained=False flag and load the weights manually.

     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
     model.load_state_dict(torch.load('downloaded weights path'))
    

    Bear in mind that the state dict may be nested under a ['state_dict'] or similar parent key, in which case you would use:

     model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
    
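A quick way to handle both layouts is to unwrap conditionally. Here is a sketch of that logic with plain dicts (torch.load returns a dict-like object, so the same pattern applies to a real checkpoint):

```python
def unwrap_state_dict(checkpoint):
    # Some training scripts save {'state_dict': ..., 'epoch': ...};
    # others save the state dict directly.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint

nested = {'state_dict': {'conv.weight': [0.1, 0.2]}, 'epoch': 3}
flat = {'conv.weight': [0.1, 0.2]}
print(unwrap_state_dict(nested) == unwrap_state_dict(flat))  # True
```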
Peter Mortensen
deepconsc
  • 1
    what are H and W? – coding-dude.com Apr 29 '21 at 04:54
  • I assume it's the minimum 224 specified by the deeplab specs. When I try this the trace the error: `module._c._create_method_from_trace( RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.` – coding-dude.com Apr 29 '21 at 05:04
  • tried it with strict=False, seems to work so far. I'll spin up the docker environment and see if that works – coding-dude.com Apr 29 '21 at 05:14
  • H and W are Height and Width of the image (tensor) and you're right - minimum size required by DeepLabV3 is 224x224. load_state_dict(chkpt, strict=False) is right way to load only required weights. Final advice would be - test the model outputs by visualization, don't deploy it right away. – deepconsc Apr 29 '21 at 12:50
  • The jit failed because of DeepLab's output format: OrderedDict -> ['out', 'aux']. But, by passing strict=False to JIT tracer, it'll compile as a graph that outputs only ['out'] with size of [1, 21, H, W]. – deepconsc Apr 29 '21 at 12:59
  • Thank you so much! It worked so I marked this as the correct answer – coding-dude.com Apr 30 '21 at 12:23
import os
import torch

model_name = 'best.pt'
# source='local' makes torch.hub read hubconf.py from the given directory
# instead of GitHub; 'custom' must be an entry point defined there.
model = torch.hub.load(os.getcwd(), 'custom', source='local', path=model_name, force_reload=True)

This worked for me. The default source is 'github'.
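For source='local' to work at all, the target directory must contain a hubconf.py that defines the entry point being requested. A minimal, hypothetical hubconf.py (all names and paths here are illustrative, not torchvision's real one):

```python
# /tmp/hubconf.py -- hypothetical, minimal hub entry point.
# torch.hub.load('/tmp/', 'my_model', source='local') imports this file
# and calls my_model(**kwargs).
dependencies = ['torch']  # pip packages the entry points require

import torch

def my_model(pretrained=False, **kwargs):
    # Stand-in architecture; a real hubconf would build the actual model.
    model = torch.nn.Linear(4, 2)
    if pretrained:
        # Load weights from a local file instead of a URL, so no
        # network access is ever needed.
        model.load_state_dict(torch.load('/tmp/my_model_weights.pth'))
    return model
```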

This works for me:

    import os
    import torch
    from torchvision.models.segmentation import deeplabv3_resnet101

    # model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    # torch.load does not expand '~', so expand it explicitly.
    model_path = os.path.expanduser('~/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth')
    # pretrained=False avoids the download; aux_loss=True because the COCO
    # checkpoint includes the auxiliary classifier's weights.
    model = deeplabv3_resnet101(pretrained=False, aux_loss=True)
    model.load_state_dict(torch.load(model_path))
    model.eval()
user941581