According to the release highlights of PyTorch 1.1.0, the latest JIT compiler now supports the Dict type (source: https://jaxenter.com/pytorch-1-1-158332.html):
Dictionary and list support in TorchScript: Lists and dictionary types behave like Python lists and dictionaries.
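For reference, the highlight appears to describe the TorchScript language, i.e. code compiled with torch.jit.script. The traceback below instead shows add_graph going through the tracer (torch.jit.get_trace_graph), which is a different code path. A minimal sketch of a scripted function that accepts a Dict[str, Tensor], assuming a recent PyTorch release (the set of dict operations available in 1.1.0 may be narrower); sum_features is a hypothetical name used only for illustration:

```python
from typing import Dict

import torch


@torch.jit.script
def sum_features(feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Dict[str, Tensor] is a TorchScript type, so this compiles under scripting.
    total = torch.zeros([1])
    for key in feats.keys():
        total = total + feats[key].sum()
    return total


feats = {"feat0": torch.ones(2, 2), "feat2": torch.ones(3)}
print(sum_features(feats))  # tensor([7.])
```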
Unfortunately, I can't find a way to make this feature work properly. The following code is a simple example that exports a Feature Pyramid Network (FPN) to TensorBoard; SummaryWriter.add_graph uses the JIT compiler to trace the model:
from collections import OrderedDict
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
torchWriter = SummaryWriter(log_dir=".tensorboard/example1")
m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
# get some dummy data
x = OrderedDict()
x['feat0'] = torch.rand(1, 10, 64, 64)
x['feat2'] = torch.rand(1, 20, 16, 16)
x['feat3'] = torch.rand(1, 30, 8, 8)
# compute the FPN on top of x
output = m(x)
print([(k, v.shape) for k, v in output.items()])
torchWriter.add_graph(m, input_to_model=x)
When I run it, I get the following error:
Traceback (most recent call last):
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 276, in graph
trace, _ = torch.jit.get_trace_graph(model, args)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 231, in get_trace_graph
return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 284, in forward
in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got collections.OrderedDict
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/peng/git-drone/gate_detection/python/gate_detection/errorcase/tb.py", line 36, in <module>
torchWriter.add_graph(m, input_to_model=x)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 534, in add_graph
self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 279, in graph
_ = model(*args) # don't catch, just print the error message
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given
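The second exception appears to come from the fallback call model(*args) in _pytorch_graph.py, where args is the OrderedDict passed as input_to_model. Star-unpacking a dict yields its keys, so forward receives self plus the three key strings: four positional arguments instead of two. A minimal pure-Python sketch of that mechanism; FakeModule is a hypothetical stand-in for the FPN, and plain strings stand in for the tensors:

```python
from collections import OrderedDict


class FakeModule:
    """Hypothetical stand-in for a module whose forward takes a single dict."""

    def forward(self, x):
        return list(x.keys())


x = OrderedDict()
x['feat0'] = 'tensor0'  # strings stand in for the real tensors
x['feat2'] = 'tensor2'
x['feat3'] = 'tensor3'

# Unpacking a dict with * yields its keys, not its values.
print([*x])  # ['feat0', 'feat2', 'feat3']

message = ""
try:
    FakeModule().forward(*x)  # same shape of call as model(*args)
except TypeError as err:
    # self + 3 keys = 4 positional arguments, but forward accepts only 2
    message = str(err)
print(message)
```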
Judging from the error messages, support for dict inputs still appears to be pending. Can I trust the release highlights, or am I not using the API properly?
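One possible workaround, sketched under assumptions: since the first error says the tracer accepts tuples of tensors, the dict-consuming model can be wrapped in a module that takes the feature maps positionally, rebuilds the OrderedDict internally, and returns a tuple. DictInputWrapper and ToyDictModel are hypothetical names, ToyDictModel stands in for the FPN, and the sketch targets a recent PyTorch release rather than being verified on 1.1.0:

```python
from collections import OrderedDict
from typing import List

import torch
from torch import nn


class DictInputWrapper(nn.Module):
    """Hypothetical adapter: positional tensors in, tuple of tensors out."""

    def __init__(self, module: nn.Module, key_order: List[str]):
        super().__init__()
        self.module = module
        self.key_order = list(key_order)

    def forward(self, *tensors):
        # Rebuild the dict the wrapped model expects, preserving key order.
        feats = OrderedDict(zip(self.key_order, tensors))
        out = self.module(feats)
        # Return a plain tuple of tensors, which the tracer can handle.
        return tuple(out.values())


class ToyDictModel(nn.Module):
    """Hypothetical stand-in for a dict-consuming model such as the FPN."""

    def forward(self, feats):
        return OrderedDict((k, torch.relu(v) + 1.0) for k, v in feats.items())


x = OrderedDict()
x['feat0'] = torch.rand(1, 10, 64, 64)
x['feat2'] = torch.rand(1, 20, 16, 16)
x['feat3'] = torch.rand(1, 30, 8, 8)

wrapped = DictInputWrapper(ToyDictModel(), list(x.keys()))
# The example inputs are a tuple of tensors, so no dict reaches the tracer.
traced = torch.jit.trace(wrapped, tuple(x.values()))
outputs = traced(*x.values())
print([o.shape for o in outputs])
```

With the FPN from the question, the analogous (untested) call would be torchWriter.add_graph(DictInputWrapper(m, list(x.keys())), input_to_model=tuple(x.values())). Returning a tuple also sidesteps dict outputs, which the tracer may not accept either.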