I have a question regarding Chainer's iterator interface and how it interacts with a Trainer, Updater, and Model.
My data are graphs, and hence have varying matrix shapes. I have concatenated the feature matrices into one large dense matrix, the adjacency matrices into one large sparse COO matrix, and the summation operators into one large sparse COO matrix. Because this is done with molecular data, I have an atom graph and a bond graph per sample. Hence, the input data is a six-tuple which, for the purposes of deep learning, I consider to be one giant data point for training. (Until I get this working with this giant matrix, I am not planning on doing train/test splits just yet, to keep my code simple.)
# Input six-tuple: three matrices for the atom graph followed by three for the
# bond graph (presumably features, adjacency, summation operator, in that order).
xs = (atom_Fs, atom_As, atom_Ss, bond_Fs, bond_As, bond_Ss)
# Targets taken from the 'target' column of `data` (presumably a pandas DataFrame).
ts = data['target'].values
# One-element dataset: the whole concatenated input is treated as a single example.
dataset = [(xs, ts)]
My model forward pass is as follows:
# model boilerplate above this comment
def forward(self, data):
    """Compute the network output for one six-tuple of graph matrices.

    ``data`` is unpacked as (atom features, atom adjacency, atom summation
    operator, bond features, bond adjacency, bond summation operator).
    Each graph is passed through ``atom_mp1`` and ``atom_mp2`` (presumably
    message-passing layers) and then ``atom_gather`` together with its
    summation operator. The two results are stacked horizontally and fed
    through ``dense1`` and ``dense2`` (each followed by tanh) and finally
    ``dense3``, whose raw output is returned.

    NOTE(review): the bond branch goes through the same ``atom_*`` layers
    as the atom branch -- confirm that this weight sharing is intended.
    """
    a_x, a_adj, a_sum, b_x, b_adj, b_sum = data

    # Atom graph: two layers over the adjacency, then gather with the
    # summation operator.
    a_hidden = self.atom_mp2(self.atom_mp1(a_x, a_adj), a_adj)
    a_pooled = self.atom_gather(a_hidden, a_sum)

    # Bond graph: identical pipeline, same layers.
    b_hidden = self.atom_mp2(self.atom_mp1(b_x, b_adj), b_adj)
    b_pooled = self.atom_gather(b_hidden, b_sum)

    # Join both representations column-wise and run the dense head.
    joint = F.hstack([a_pooled, b_pooled])
    hidden = F.tanh(self.dense1(joint))
    hidden = F.tanh(self.dense2(hidden))
    return self.dense3(hidden)
I then pass everything into a Trainer:
from chainer import iterators, training
from chainer.optimizers import SGD, Adam
# Iterate over the one-element dataset, one example per batch.
iterator = iterators.SerialIterator(dataset, batch_size=1)
optimizer = Adam()
# `mpnn` (the model instance, created outside this snippet) becomes the
# optimizer's target link.
optimizer.setup(mpnn)
# NOTE(review): per the traceback below, StandardUpdater.update_core calls
# `optimizer.update(loss_func, *in_arrays)`, which in turn calls
# `lossfun(*args, **kwds)`. The batch tuple is therefore unpacked into
# positional arguments, so `mpnn.forward` is invoked with (inputs, targets)
# rather than a single `data` argument. Presumably the optimizer target should
# be a link that accepts both and returns a loss -- confirm against the
# Chainer documentation.
updater = training.updaters.StandardUpdater(iterator, optimizer)
max_epoch = 50
# Stop trigger: run for `max_epoch` epochs.
trainer = training.Trainer(updater, (max_epoch, 'epoch'))
trainer.run()
However, when I run the trainer, I get the following error:
Exception in main training loop: forward() takes 2 positional arguments but 3 were given
Traceback (most recent call last):
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/trainer.py", line 315, in run
update()
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py", line 165, in update
self.update_core()
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py", line 177, in update_core
optimizer.update(loss_func, *in_arrays)
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/optimizer.py", line 680, in update
loss = lossfun(*args, **kwds)
File "/home/ericmjl/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/link.py", line 242, in __call__
out = forward(*args, **kwargs)
Will finalize trainer extensions and updater before reraising the exception.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-45-ea26cece43b3> in <module>
9 max_epoch = 50
10 trainer = training.Trainer(updater, (max_epoch, 'epoch'))
---> 11 trainer.run()
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
327 f.write('Will finalize trainer extensions and updater before '
328 'reraising the exception.\n')
--> 329 six.reraise(*sys.exc_info())
330 finally:
331 for _, entry in extensions:
~/anaconda/envs/mpnn/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
313 self.observation = {}
314 with reporter.scope(self.observation):
--> 315 update()
316 for name, entry in extensions:
317 if entry.trigger(self):
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py in update(self)
163
164 """
--> 165 self.update_core()
166 self.iteration += 1
167
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/training/updaters/standard_updater.py in update_core(self)
175
176 if isinstance(in_arrays, tuple):
--> 177 optimizer.update(loss_func, *in_arrays)
178 elif isinstance(in_arrays, dict):
179 optimizer.update(loss_func, **in_arrays)
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/optimizer.py in update(self, lossfun, *args, **kwds)
678 if lossfun is not None:
679 use_cleargrads = getattr(self, '_use_cleargrads', True)
--> 680 loss = lossfun(*args, **kwds)
681 if use_cleargrads:
682 self.target.cleargrads()
~/anaconda/envs/mpnn/lib/python3.7/site-packages/chainer/link.py in __call__(self, *args, **kwargs)
240 if forward is None:
241 forward = self.forward
--> 242 out = forward(*args, **kwargs)
243
244 # Call forward_postprocess hook
TypeError: forward() takes 2 positional arguments but 3 were given
This baffles me, because I have set up the dataset in the same way as the MNIST example, in which the input data are paired with the output data in a tuple. Because of the layers of abstraction in Chainer, I'm not quite sure how to debug this issue. Does anybody have any insight into this?