
While trying to script a Stable Diffusion model with torch.jit.script(), I got the following error:

AttributeError: 'TimestepEmbedSequential' object has no attribute '__globals__'

I'm trying to export this model to ONNX. I found out that torch.onnx.export() runs torch.jit.trace on the model, which unrolls every loop, so I'm trying to script it first.
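For reference, here is a minimal sketch of the difference between tracing and scripting. The function sum_rows is a hypothetical toy example, not part of the Stable Diffusion code; it only illustrates that tracing records the loop iterations seen for one example input, while scripting keeps the loop as real control flow.

```python
import torch


def sum_rows(x: torch.Tensor) -> torch.Tensor:
    # Python-level loop whose trip count depends on the input's first dimension
    out = torch.zeros(x.shape[1:])
    for i in range(x.size(0)):
        out = out + x[i]
    return out


# Tracing runs the function once on the example input and records the ops it saw,
# so the loop is unrolled into exactly 3 iterations here.
traced = torch.jit.trace(sum_rows, torch.ones(3, 2))

# Scripting compiles the Python source, so the loop is kept as a loop.
scripted = torch.jit.script(sum_rows)

x = torch.ones(5, 2)
print(traced(x))    # only the first 3 rows are summed
print(scripted(x))  # all 5 rows are summed
```

PyTorch may emit tracer warnings for code like this depending on the version; the behavioral difference is what matters.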

Following the traceback, the error occurs in this function from torch's _jit_internal.py, when it reads fn.__globals__:

def get_closure(fn):
    """
    Get a dictionary of closed over variables from a function
    """
    captures = {}
    captures.update(fn.__globals__)

    for index, captured_name in enumerate(fn.__code__.co_freevars):
        captures[captured_name] = fn.__closure__[index].cell_contents

    return captures
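The failure can be reproduced outside PyTorch with a copy of the same logic. get_closure assumes a plain Python function (something with __globals__, __code__ and __closure__), so any other callable object reaching it raises this AttributeError. A plausible but unconfirmed reading of the traceback is therefore that a module instance is being handed to a code path that expects a function. CallableBlock and make_adder below are hypothetical stand-ins.

```python
def get_closure(fn):
    # Same logic as the torch helper quoted above
    captures = {}
    captures.update(fn.__globals__)
    for index, captured_name in enumerate(fn.__code__.co_freevars):
        captures[captured_name] = fn.__closure__[index].cell_contents
    return captures


def make_adder(offset):
    def add(x):
        return x + offset  # 'offset' is a free variable stored in a closure cell
    return add


class CallableBlock:
    """Hypothetical stand-in for an nn.Module: callable, but not a function."""

    def __call__(self, x):
        return x


closure = get_closure(make_adder(10))
print(closure["offset"])  # 10

try:
    get_closure(CallableBlock())
except AttributeError as err:
    message = str(err)
print(message)  # 'CallableBlock' object has no attribute '__globals__'
```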

The scripting code is as follows:

stablediffusion_wrapper = StableDiffusionWrapper(model, sampler, opt)
scripted_module = torch.jit.script(stablediffusion_wrapper, example_inputs=[(token_dummy, img_dummy)])

And the TimestepEmbedSequential module:

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

Any suggestions on how I can figure this out?

I tried adding the @torch.jit.export decorator to the parent class TimestepBlock, but it had no effect. Honestly, I have no idea where to look for this problem. I would appreciate any suggestions. Thank you.
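For context, @torch.jit.export is meant to decorate a method of an nn.Module: it asks TorchScript to compile that method in addition to forward() and keep it callable on the scripted module. It does not change how a class is resolved, which may be why applying it around TimestepBlock had no visible effect (an assumption based on the decorator's documented purpose). Block is a hypothetical example module.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export
    def double(self, x: torch.Tensor) -> torch.Tensor:
        # Not reachable from forward(); compiled only because of @torch.jit.export
        return x * 2


scripted = torch.jit.script(Block())
print(scripted(torch.ones(2)))         # forward() is compiled as usual
print(scripted.double(torch.ones(2)))  # the exported method is also available
```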

mactok
