I am trying to create a neural network in PyBrain that uses shared-weight connections, but am having trouble doing so. I haven't found many examples that use this type of connection, and I thought I had worked out how to use them from the few I found and from the source code. Apparently I was not so lucky.
As a simple example, I'm trying to create the following shared weight neural network:
I would like the connections shown in rectangles to be shared, meaning that the weights along each path are the same, so that swapping the input vector [x,y] with [y,x] yields the same output.
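To make the intended behavior concrete, here is a minimal NumPy sketch of the computation I have in mind. It is not PyBrain code: the helper `shared_forward` and the weight values `w1` and `w2` are made up for illustration, and it assumes each input feeds only its own hidden unit, with one shared weight per rectangle.

```python
import numpy as np


def sigmoid(z):
    """Logistic activation, standing in for the sigmoid hidden layer."""
    return 1.0 / (1.0 + np.exp(-z))


def shared_forward(inputs, w1, w2):
    """Forward pass with one shared weight per layer (hypothetical helper)."""
    inputs = np.asarray(inputs, dtype=float)
    # The same weight w1 is applied on both input -> hidden paths.
    hidden = sigmoid(w1 * inputs)
    # The same weight w2 is applied on both hidden -> output paths.
    return w2 * hidden.sum()


# Arbitrary illustrative weights; swapping the inputs should not matter.
out_xy = shared_forward([2, 1], w1=0.5, w2=-1.5)
out_yx = shared_forward([1, 2], w1=0.5, w2=-1.5)
print(np.isclose(out_xy, out_yx))
```

Because both paths use the same weights and the output simply sums the two hidden activations, the result does not depend on the order of the inputs, which is the property I want from the PyBrain network.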
I attempted to build this architecture with the following code:
from pybrain.structure.modules.linearlayer import LinearLayer
from pybrain.structure.modules.sigmoidlayer import SigmoidLayer
from pybrain.structure.moduleslice import ModuleSlice
from pybrain.structure.networks.feedforward import FeedForwardNetwork
from pybrain.structure.connections.shared import MotherConnection, SharedFullConnection

net = FeedForwardNetwork()

# make modules
inp = LinearLayer(2, name='input')
h1 = SigmoidLayer(2, name='hidden')
outp = LinearLayer(1, name='output')

# now add modules
net.addOutputModule(outp)
net.addInputModule(inp)
net.addModule(h1)

# now we need to create the connections
mc = MotherConnection(2, name='mother')
mc2 = MotherConnection(2, name='mother2')
topInput = ModuleSlice(inp, outSliceFrom=0, outSliceTo=1)
bottomInput = ModuleSlice(inp, outSliceFrom=1, outSliceTo=2)
topHidden = ModuleSlice(h1, inSliceFrom=0, inSliceTo=1)
bottomHidden = ModuleSlice(h1, inSliceFrom=1, inSliceTo=2)
net.addConnection(SharedFullConnection(mc, topInput, topHidden))
net.addConnection(SharedFullConnection(mc, bottomInput, bottomHidden))
net.addConnection(SharedFullConnection(mc2, topHidden, outp))
net.addConnection(SharedFullConnection(mc2, bottomHidden, outp))

# finish up
net.sortModules()
#print net.activate([2,1])
In the above code I created two MotherConnections, mc and mc2, the idea being that these two objects would hold the shared weights in the first and second rectangles of the figure, respectively. I then used ModuleSlice to split the input module and the hidden module into two groups each. Finally, I called addConnection with SharedFullConnection objects built on mc and mc2 to connect these paths.
Running the above code as-is produces no error. But if I test the network by uncommenting the net.activate statement at the end, I get the following error:
Traceback (most recent call last):
  File "test.py", line 38, in <module>
    print net.activate([2,1])
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/networks/feedforward.py", line 19, in activate
    return super(FeedForwardNetworkComponent, self).activate(inpt)
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/modules/module.py", line 123, in activate
    self.forward()
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/modules/module.py", line 75, in forward
    self.outputbuffer[self.offset])
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/networks/feedforward.py", line 32, in _forwardImplementation
    c.forward()
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/connections/connection.py", line 77, in forward
    self.outmod.inputbuffer[outmodOffset, self.outSliceFrom:self.outSliceTo])
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/connections/shared.py", line 64, in _forwardImplementation
    FullConnection._forwardImplementation(self, inbuf, outbuf)
  File "/usr/local/lib/python2.7/dist-packages/PyBrain-0.3.1-py2.7.egg/pybrain/structure/connections/full.py", line 19, in _forwardImplementation
    outbuf += dot(reshape(self.params, (self.outdim, self.indim)), inbuf)
  File "/usr/lib/python2.7/dist-packages/numpy/core/fromnumeric.py", line 171, in reshape
    return reshape(newshape, order=order)
ValueError: total size of new array must be unchanged
So I imagine I must be misunderstanding something about how this is set up. Many thanks to anyone who can point out where my understanding of these commands goes astray!
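In case it helps narrow things down, the final ValueError can be reproduced with NumPy alone whenever the number of parameters differs from outdim * indim. The sketch below assumes, without my having confirmed it, that the 2 parameters of a MotherConnection are being reshaped into a 1x1 weight matrix for a connection between two one-unit slices. The variable names are mine, not PyBrain's.

```python
import numpy as np

# Assumption: the 2 parameters held by MotherConnection(2).
params = np.zeros(2)
# Assumption: a one-unit slice on each side of the shared connection.
indim, outdim = 1, 1

try:
    # Same kind of reshape call as the one shown in the traceback (full.py).
    weights = np.reshape(params, (outdim, indim))
    reshape_failed = False
except ValueError as err:
    # 2 values cannot fill a 1x1 matrix, so NumPy refuses the reshape.
    reshape_failed = True
    print("reshape failed: %s" % err)
```

If that assumption is right, the parameter count given to each MotherConnection would have to match the product of the sliced dimensions, but I am not sure that is the whole story.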