I have a class A that defines all my networks, and I am wrapping it with torch.nn.DataParallel. When I call the forward function as a(inputs), it works fine. However, I also want to call some other methods of A while still retaining the DataParallel functionality. Is this possible, or do I have to go through the forward function only?
Minimum non-working example (just to convey the context better):
class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # blah blah blah

    def forward(self, some_arguments):
        # blah blah blah
        ...

    def func1(self, some_arguments):
        # blah blah blah
        ...

a = A()
a = torch.nn.DataParallel(a, device_ids=[0, 1])

# calling the forward function
outputs = a(inputs)  # works fine

# calling func1
outputs1 = a.func1(inputs)         # does not work (AttributeError: 'DataParallel' object has no attribute 'func1')
outputs1 = a.module.func1(inputs)  # runs, but without parallelizing the data; I am not sure if this is the right thing to do
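
To make the last question concrete, here is a minimal sketch of what I mean by going through the forward function only: dispatching on an extra argument inside forward, so that DataParallel still scatters the batch across the GPUs. The mode argument, the Linear layer, and the two-GPU setup are just placeholders for illustration, not my actual code.

import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)  # placeholder network

    def forward(self, x, mode="forward"):
        # Dispatching inside forward keeps DataParallel's scatter/gather behaviour
        if mode == "func1":
            return self.func1(x)
        return self.layer(x)

    def func1(self, x):
        return torch.relu(self.layer(x))

a = torch.nn.DataParallel(A(), device_ids=[0, 1]).cuda()
inputs = torch.randn(16, 8).cuda()

outputs = a(inputs)                 # forward path, batch split across both GPUs
outputs1 = a(inputs, mode="func1")  # func1 path, batch still split across both GPUs

Is something like this the intended way to do it, or is there a cleaner way to call func1 directly on the DataParallel wrapper?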