While trying to script a Stable Diffusion model using torch.jit.script(), I got the following error:
AttributeError: 'TimestepEmbedSequential' object has no attribute '__globals__'
I'm trying to export this model to ONNX and found out that torch.onnx.export() runs torch.jit.trace on the model, which unrolls every loop, so I'm trying to script it first.
When I follow the traceback, the error occurs in this function from torch's _jit_internal.py, at the point where it reads fn.__globals__:
def get_closure(fn):
    """
    Get a dictionary of closed over variables from a function
    """
    captures = {}
    captures.update(fn.__globals__)

    for index, captured_name in enumerate(fn.__code__.co_freevars):
        captures[captured_name] = fn.__closure__[index].cell_contents

    return captures
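To narrow it down, here is a minimal sketch in plain Python (no torch needed) that produces the same kind of message. The names make_adder and CallableBlock are hypothetical, made up only for this illustration. My assumption is that get_closure is somehow being handed a module instance instead of a plain function, but I have not confirmed that this is what actually happens inside torch.jit.script.

```python
def get_closure(fn):
    """Same logic as the function quoted above from _jit_internal.py."""
    captures = {}
    captures.update(fn.__globals__)
    for index, captured_name in enumerate(fn.__code__.co_freevars):
        captures[captured_name] = fn.__closure__[index].cell_contents
    return captures


def make_adder(offset):
    # A plain function with a closed-over variable: get_closure handles this.
    def add(x):
        return x + offset
    return add


class CallableBlock:
    """Hypothetical stand-in for a module instance: callable, but not a function."""

    def __call__(self, x):
        return x


# Works: functions have __globals__, __code__ and __closure__.
closure = get_closure(make_adder(3))
print(closure["offset"])

# Fails: an ordinary object instance has no __globals__ attribute.
try:
    get_closure(CallableBlock())
    error_message = None
except AttributeError as exc:
    error_message = str(exc)
print(error_message)
```

If that assumption is right, the real question becomes why a TimestepEmbedSequential instance ends up where a function is expected.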
The code for scripting is as follows:
stablediffusion_wrapper = StableDiffusionWrapper(model, sampler, opt)
scripted_module = torch.jit.script(stablediffusion_wrapper, example_inputs=[(token_dummy, img_dummy)])
And this is the TimestepEmbedSequential module:
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
Any suggestions on how I can figure this out?
I tried adding the @torch.jit.export decorator to the parent class TimestepBlock, but it had no effect. Honestly, I have no idea where to look for the cause of this problem. I would appreciate any suggestions. Thank you.