
I have a class A that defines all my networks. I am wrapping this with torch.nn.DataParallel. When I call the forward function as a(), it works fine. However, I also want to call some other functions of A, while still retaining the DataParallel functionality. Is this possible? Or do I have to go through the forward function only?

Minimal non-working example (just to convey the context better):

import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...  # define the networks here

    def forward(self, inputs):
        ...  # main computation

    def func1(self, inputs):
        ...  # another computation

a = A().to("cuda:0")  # DataParallel expects the module on device_ids[0]
a = torch.nn.DataParallel(a, device_ids=[0, 1])
# calling the forward function
outputs = a(inputs)  # works fine
# calling func1
outputs1 = a.func1(inputs)  # does not work: AttributeError, DataParallel has no func1
outputs1 = a.module.func1(inputs)  # works, but without parallelizing data. I am not sure if this is the right thing to do
Nagabhushan S N

1 Answer


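Have you tried calling func1 from inside forward instead of externally? DataParallel only scatters the inputs, replicates the module and gathers the outputs when the wrapper itself is called, and each replica then runs forward. Other methods are not forwarded by the wrapper, which is why a.func1 fails and a.module.func1 runs on a single device. So essentially, you would call forward, which would in turn call func1. If you want to call func1 conditionally, you can pass the function name as an argument to forward. These suggestions are also present in this thread: https://discuss.pytorch.org/t/dataparallel-model-with-custom-functions/75053/10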
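A minimal sketch of the "pass the function name to forward" approach, assuming a toy `nn.Linear(4, 2)` network. The layer sizes, the `func_name` argument and the body of `func1` are hypothetical stand-ins for the question's code, not taken from it. On a machine without CUDA, DataParallel simply calls the wrapped module, so the sketch also runs on CPU (just without any parallelism).

```python
import torch
import torch.nn as nn


class A(nn.Module):
    def __init__(self):
        super().__init__()
        # Toy layer standing in for the real networks (sizes are arbitrary).
        self.linear = nn.Linear(4, 2)

    def forward(self, x, func_name=None):
        # DataParallel calls forward on every replica, so dispatching from here
        # keeps the scatter/replicate/gather behaviour for other methods too.
        # func_name is a hypothetical argument used only for this dispatch.
        if func_name is not None:
            return getattr(self, func_name)(x)
        return self.linear(x)

    def func1(self, x):
        # Hypothetical extra method: here it just doubles the main output.
        return self.linear(x) * 2


device = "cuda" if torch.cuda.is_available() else "cpu"
model = A().to(device)  # DataParallel expects parameters on device_ids[0]
a = nn.DataParallel(model)  # device_ids=None: use all visible GPUs
inputs = torch.randn(8, 4, device=device)

outputs = a(inputs)                      # forward on each replica
outputs1 = a(inputs, func_name="func1")  # func1 on each replica, via forward
```

Tensor arguments are split along dimension 0 across the GPUs, while non-tensor keyword arguments such as the string `func_name` are copied to every replica. Leaving `device_ids` at its default avoids hardcoding `[0, 1]` on machines with fewer GPUs.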