Recently I've been brushing up on machine learning, so I decided to implement a basic neural network in Java using the backpropagation algorithm. I've gone over the maths and checked it against various tutorials, but I'm still having problems. Apologies for the length of this post.
I'll first describe the problems I have been testing on, before going into more detail about the algorithm.
Test problem 1:
A single output neuron with linear activation, learning a regression of the function x/2 + 2. This works pretty well, but doesn't really exercise backpropagation yet, since there is no hidden layer to propagate errors back to.
The algorithm works and converges to near-zero error.
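For comparison, here is a minimal sketch of test problem 1 outside the node framework, assuming plain doubles instead of Node objects: a single neuron with identity activation, trained by per-example gradient descent (the delta rule) on E = 0.5 * (target - output)^2. The class name, the sample points, the learning rate and the epoch count are all hypothetical choices for illustration.

```java
/**
 * Hypothetical sketch: one neuron with linear (identity) activation,
 * trained with the delta rule to fit y = x/2 + 2.
 */
final class LinearNeuron {
    double weight = 0.0; // any starting value works: the problem is convex
    double bias = 0.0;

    double predict(double x) {
        return weight * x + bias; // linear activation: output equals the weighted sum
    }

    /** One gradient step on E = 0.5 * (target - output)^2. */
    void trainExample(double x, double target, double learningRate) {
        // Identity activation has derivative 1, so delta is just the error.
        double delta = target - predict(x);
        weight += learningRate * delta * x;
        bias += learningRate * delta; // bias acts like a weight on a constant input of 1
    }

    /** Trains a fresh neuron on a few illustrative points of y = x/2 + 2. */
    static LinearNeuron fitHalfXPlusTwo(int epochs, double learningRate) {
        double[] xs = {-2.0, -1.0, 0.0, 1.0, 2.0};
        LinearNeuron neuron = new LinearNeuron();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (double x : xs) {
                neuron.trainExample(x, x / 2.0 + 2.0, learningRate);
            }
        }
        return neuron;
    }

    public static void main(String[] args) {
        LinearNeuron n = fitHalfXPlusTwo(2000, 0.05);
        System.out.println("weight=" + n.weight + " bias=" + n.bias);
    }
}
```

Because the target is exactly linear, the weight and bias should approach 0.5 and 2.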
Test problem 2:
My next test was to learn the XOR problem. For this, I used a simple network with 2 input nodes, 2 hidden nodes and 2 output nodes (the input nodes only supply the inputs and are not trained).
The algorithm always gets stuck at an average error of about 0.5. No matter how many epochs I run it for, the error converges to this value and the network performs poorly.
Implementation
To implement the algorithm, I represent nodes as objects, and activation functions as objects implementing an ActivationFunction interface. The logistic (sigmoid) activation looks like this:
public class LogisticActivationFunction implements ActivationFunction {
    @Override
    public double apply(double in) {
        return 1.0 / (1.0 + Math.exp(-in));
    }

    @Override
    public double applyDerivative(double in) {
        double sig = apply(in);
        return sig * (1.0 - sig);
    }
}
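A quick way to rule out the activation function is a finite-difference check of applyDerivative. The sketch below assumes the ActivationFunction interface has just the two methods used above (the interface itself isn't shown in the post), repeats the logistic class so the block compiles on its own, and compares the analytic derivative against a central difference. DerivativeCheck is a hypothetical helper name.

```java
/** Assumed shape of the ActivationFunction interface (not shown in the post). */
interface ActivationFunction {
    double apply(double in);

    double applyDerivative(double in);
}

/** Same logic as the logistic activation above. */
final class LogisticActivationFunction implements ActivationFunction {
    @Override
    public double apply(double in) {
        return 1.0 / (1.0 + Math.exp(-in));
    }

    @Override
    public double applyDerivative(double in) {
        double sig = apply(in);
        return sig * (1.0 - sig);
    }
}

final class DerivativeCheck {
    /** Largest gap between applyDerivative and a central finite difference on [-6, 6]. */
    static double maxDerivativeError(ActivationFunction f) {
        double h = 1e-5;
        double worst = 0.0;
        for (double x = -6.0; x <= 6.0; x += 0.25) {
            double numeric = (f.apply(x + h) - f.apply(x - h)) / (2.0 * h);
            worst = Math.max(worst, Math.abs(numeric - f.applyDerivative(x)));
        }
        return worst;
    }

    public static void main(String[] args) {
        System.out.println(maxDerivativeError(new LogisticActivationFunction()));
    }
}
```

A result around 1e-10 or smaller means applyDerivative agrees with the slope of apply; a large value would point at the activation code rather than the network.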
First, the feed-forward pass is run like so:
public List<Double> evaluate(List<Double> inputs, boolean training) {
    // Set the weights in the first layer.
    setInputWeights(inputs);
    // Iterate through non-input layers one by one and evaluate.
    NodeLayer previousLayer = layers.get(0);
    for (int layerIndex = 1; layerIndex < layers.size(); layerIndex++) {
        NodeLayer layer = layers.get(layerIndex);
        for (int nodeIndex = 0; nodeIndex < layer.size(); nodeIndex++) {
            Node node = layer.get(nodeIndex);
            evaluateNode(node, previousLayer, training);
        }
        previousLayer = layer;
    }
    return getOutputWeights();
}

private void evaluateNode(Node node, NodeLayer previousLayer, boolean training) {
    double sum = node.getBias();
    // Create sum from all connected nodes.
    for (int link : node.links()) {
        if (training) {
            previousLayer.get(link).registerDownstreamNode(node.getId());
        }
        sum += node.getUpstreamLinkStrength(link) * previousLayer.get(link).getOutput();
    }
    // apply the activation function.
    double activation = node.getActivation().apply(sum);
    node.setHiddenNode(sum, activation);
}
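The same forward computation (z = bias + sum of weight * upstream output, then the activation) can be sketched with plain arrays. The weights below are hand-picked for illustration: hidden node 0 behaves roughly like OR, hidden node 1 like AND, and the single output like "OR and not AND", so this hypothetical 2-2-1 network computes XOR. They are not weights produced by the training code above.

```java
/** Hypothetical array-based forward pass: out_j = sigmoid(b_j + sum_i w[j][i] * in_i). */
final class DenseForward {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** weights[j][i] connects upstream node i to node j of this layer. */
    static double[] layer(double[] upstream, double[][] weights, double[] biases) {
        double[] out = new double[biases.length];
        for (int j = 0; j < biases.length; j++) {
            double sum = biases[j];
            for (int i = 0; i < upstream.length; i++) {
                sum += weights[j][i] * upstream[i];
            }
            out[j] = sigmoid(sum);
        }
        return out;
    }

    // Illustrative hand-picked parameters: hidden 0 ~ OR, hidden 1 ~ AND.
    static final double[][] HIDDEN_W = {{20, 20}, {20, 20}};
    static final double[] HIDDEN_B = {-10, -30};
    // Output ~ (hidden 0) AND NOT (hidden 1).
    static final double[][] OUTPUT_W = {{20, -20}};
    static final double[] OUTPUT_B = {-10};

    static double xor(double x1, double x2) {
        double[] hidden = layer(new double[] {x1, x2}, HIDDEN_W, HIDDEN_B);
        return layer(hidden, OUTPUT_W, OUTPUT_B)[0];
    }

    public static void main(String[] args) {
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                System.out.println(a + " xor " + b + " -> " + xor(a, b));
            }
        }
    }
}
```

Since a solution exists for this topology, a network that never gets close to outputs like these is a sign that training, rather than the network's capacity, is the issue.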
Next, the error terms (deltas) are propagated backwards through the network:
protected void backPropogate(List<Double> correct) {
    //float error = norm(correct, getOutputWeights());
    // Final layer error.
    NodeLayer outputLayer = getOutputLayer();
    List<Double> output = getOutputWeights();
    for (int i = 0; i < outputLayer.size(); i++) {
        // Calculate error on the ith output.
        double error = correct.get(i) - output.get(i);
        System.out.println("error " + i + " = " + error + " = " + correct.get(i) + " - " + output.get(i));
        // Set the delta to the error in dimension i multiplied by the activation derivative of the input.
        Node node = outputLayer.get(i);
        node.setDelta(error * node.getActivation().applyDerivative(node.getInput()));
    }
    NodeLayer layer = outputLayer.getUpstream(this);
    while (layer != getInputLayer()) {
        for (Node node : layer) {
            double sum = 0;
            for (Node downstream : node.downstreamNodes(this, layer)) {
                sum += downstream.getDelta() * downstream.getUpstreamLinkStrength(node.getId());
            }
            node.setDelta(sum * node.getActivation().applyDerivative(node.getInput()));
        }
        layer = layer.getUpstream(this);
    }
}
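One way to test a hand-written backward pass is a gradient check. With E = 0.5 * sum of (correct - output)^2 and the sign convention used above (delta = (correct - output) * f'(z)), the derivative of E with respect to a weight should equal -delta * upstream output, which can be compared against a central finite difference. This sketch uses a hypothetical array-based 2-2-2 sigmoid network with arbitrary fixed weights, not the NodeNetwork classes. Note that the hidden-layer sum visits each downstream node exactly once; if a downstream node were counted twice, its contribution to the hidden delta would be doubled and a check like this would report a mismatch.

```java
/**
 * Hypothetical gradient check for a 2-2-2 sigmoid network with fixed weights.
 * w1[j][i] connects input i to hidden j; w2[k][j] connects hidden j to output k.
 */
final class BackpropGradientCheck {
    final double[][] w1 = {{0.3, -0.7}, {0.9, 0.4}};
    final double[] b1 = {0.1, -0.2};
    final double[][] w2 = {{-0.5, 0.8}, {0.6, -0.3}};
    final double[] b2 = {0.05, -0.1};

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double[] layer(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double z = b[j];
            for (int i = 0; i < in.length; i++) z += w[j][i] * in[i];
            out[j] = sigmoid(z);
        }
        return out;
    }

    /** E = 0.5 * sum_k (t_k - y_k)^2 */
    double loss(double[] x, double[] t) {
        double[] y = layer(layer(x, w1, b1), w2, b2);
        double e = 0.0;
        for (int k = 0; k < y.length; k++) e += 0.5 * (t[k] - y[k]) * (t[k] - y[k]);
        return e;
    }

    /** Max |numeric - analytic| over every weight and bias, for one example. */
    double maxGradientError(double[] x, double[] t) {
        double[] h = layer(x, w1, b1);
        double[] y = layer(h, w2, b2);
        // Output deltas: (t - y) * sigmoid'(z), where sigmoid'(z) = y * (1 - y).
        double[] dOut = new double[y.length];
        for (int k = 0; k < y.length; k++) dOut[k] = (t[k] - y[k]) * y[k] * (1.0 - y[k]);
        // Hidden deltas: each downstream node contributes exactly once.
        double[] dHid = new double[h.length];
        for (int j = 0; j < h.length; j++) {
            double sum = 0.0;
            for (int k = 0; k < y.length; k++) sum += dOut[k] * w2[k][j];
            dHid[j] = sum * h[j] * (1.0 - h[j]);
        }
        // Analytic dE/dparam = -delta * upstream output (the upstream output is 1 for a bias).
        double worst = 0.0;
        for (int k = 0; k < y.length; k++) {
            for (int j = 0; j < h.length; j++)
                worst = Math.max(worst, Math.abs(numeric(w2[k], j, x, t) + dOut[k] * h[j]));
            worst = Math.max(worst, Math.abs(numeric(b2, k, x, t) + dOut[k]));
        }
        for (int j = 0; j < h.length; j++) {
            for (int i = 0; i < x.length; i++)
                worst = Math.max(worst, Math.abs(numeric(w1[j], i, x, t) + dHid[j] * x[i]));
            worst = Math.max(worst, Math.abs(numeric(b1, j, x, t) + dHid[j]));
        }
        return worst;
    }

    /** Central difference of E with respect to params[idx]; the parameter is restored afterwards. */
    private double numeric(double[] params, int idx, double[] x, double[] t) {
        double eps = 1e-5;
        double saved = params[idx];
        params[idx] = saved + eps;
        double plus = loss(x, t);
        params[idx] = saved - eps;
        double minus = loss(x, t);
        params[idx] = saved;
        return (plus - minus) / (2.0 * eps);
    }

    public static void main(String[] args) {
        BackpropGradientCheck net = new BackpropGradientCheck();
        System.out.println(net.maxGradientError(new double[] {0, 1}, new double[] {0, 1}));
    }
}
```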
Finally, the weights and biases are updated using gradient descent. Note that I compute the error as correct - output (the negative of output - correct), so each update adds learningRate * delta * upstream output instead of subtracting it.
private void updateParameters(double learningRate) {
    for (NodeLayer layer : this) {
        if (layer == getInputLayer()) {
            continue;
        }
        for (Node node : layer) {
            double oldBias = node.getBias();
            node.offsetBias(node.getDelta() * learningRate);
            for (Node upstream : node.upstreamNodes(this, layer)) {
                double oldW = node.getUpstreamLinkStrength(upstream);
                node.offsetWeight(upstream.getId(), learningRate * node.getDelta() * upstream.getOutput());
            }
        }
    }
}
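The sign convention in the update step can be checked on a single sigmoid neuron. Since dE/dw = -delta * upstream output, descending the gradient means adding learningRate * delta * upstream output, and the bias is updated as if it were a weight on a constant input of 1. This is an illustrative sketch with hypothetical names, not the Node API from the post.

```java
/** Hypothetical single sigmoid neuron showing the "add rate * delta * input" update. */
final class UpdateRuleSketch {
    double[] weights;
    double bias;

    UpdateRuleSketch(double[] weights, double bias) {
        this.weights = weights.clone();
        this.bias = bias;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    double output(double[] upstream) {
        double z = bias;
        for (int i = 0; i < weights.length; i++) z += weights[i] * upstream[i];
        return sigmoid(z);
    }

    /** E = 0.5 * (target - output)^2 */
    double loss(double[] upstream, double target) {
        double diff = target - output(upstream);
        return 0.5 * diff * diff;
    }

    /** delta = (target - y) * y * (1 - y); dE/dw_i = -delta * a_i, so w_i += rate * delta * a_i. */
    void update(double[] upstream, double target, double learningRate) {
        double y = output(upstream);
        double delta = (target - y) * y * (1.0 - y);
        bias += learningRate * delta; // constant input of 1
        for (int i = 0; i < weights.length; i++) {
            weights[i] += learningRate * delta * upstream[i];
        }
    }

    public static void main(String[] args) {
        UpdateRuleSketch n = new UpdateRuleSketch(new double[] {0.5, -0.25}, 0.0);
        double[] a = {1.0, 0.5};
        System.out.println("before: " + n.loss(a, 1.0));
        n.update(a, 1.0, 0.1);
        System.out.println("after:  " + n.loss(a, 1.0));
    }
}
```

If the sign were flipped, the same step would increase the loss instead of decreasing it.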
To tie all of this together, I use the trainExample method:
public void trainExample(List<Double> inputs, List<Double> correct, double learningRate) {
    System.out.println("training example... " + Data.toString(inputs) + " -> " + Data.toString(correct));
    evaluate(inputs, true);
    backPropogate(correct);
    updateParameters(learningRate);
}
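For a side-by-side comparison, here is a sketch of the same evaluate, backpropagate, update sequence for a 2-H-2 sigmoid network stored in plain arrays. Everything here is hypothetical (the class name, the array layout, uniform [-1, 1) initialisation, learning rate 0.5, 20000 epochs, and the one-hot XOR targets in main); it is meant as a reference point for comparing intermediate values, not a drop-in replacement. With only 2 hidden nodes, some random seeds can still stall in a poor local minimum, so trying a few seeds or one or two extra hidden nodes is a reasonable experiment.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical array-based 2-H-2 sigmoid network trained one example at a time. */
final class TinyNetwork {
    final double[][] w1; // w1[j][i]: input i -> hidden j
    final double[][] w2; // w2[k][j]: hidden j -> output k
    final double[] b1;
    final double[] b2;
    private double[] h; // hidden activations from the latest forward pass
    private double[] y; // output activations from the latest forward pass

    TinyNetwork(int inputs, int hidden, int outputs, Random rng) {
        w1 = randomMatrix(hidden, inputs, rng);
        w2 = randomMatrix(outputs, hidden, rng);
        b1 = new double[hidden];
        b2 = new double[outputs];
    }

    // Distinct random starting weights break the symmetry between hidden nodes.
    private static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        for (double[] row : m)
            for (int c = 0; c < cols; c++) row[c] = 2.0 * rng.nextDouble() - 1.0; // uniform in [-1, 1)
        return m;
    }

    private static double[] layer(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double z = b[j];
            for (int i = 0; i < in.length; i++) z += w[j][i] * in[i];
            out[j] = 1.0 / (1.0 + Math.exp(-z));
        }
        return out;
    }

    double[] evaluate(double[] x) {
        h = layer(x, w1, b1);
        y = layer(h, w2, b2);
        return y.clone();
    }

    double loss(double[] x, double[] t) {
        double[] out = evaluate(x);
        double e = 0.0;
        for (int k = 0; k < t.length; k++) e += 0.5 * (t[k] - out[k]) * (t[k] - out[k]);
        return e;
    }

    void trainExample(double[] x, double[] t, double rate) {
        evaluate(x);
        // Backward pass: compute every delta before any weight changes.
        double[] dOut = new double[y.length];
        for (int k = 0; k < y.length; k++) dOut[k] = (t[k] - y[k]) * y[k] * (1.0 - y[k]);
        double[] dHid = new double[h.length];
        for (int j = 0; j < h.length; j++) {
            double sum = 0.0;
            for (int k = 0; k < y.length; k++) sum += dOut[k] * w2[k][j]; // w2 not yet updated
            dHid[j] = sum * h[j] * (1.0 - h[j]);
        }
        // Update: param += rate * delta * upstream output (delta carries the sign of t - y).
        for (int k = 0; k < y.length; k++) {
            b2[k] += rate * dOut[k];
            for (int j = 0; j < h.length; j++) w2[k][j] += rate * dOut[k] * h[j];
        }
        for (int j = 0; j < h.length; j++) {
            b1[j] += rate * dHid[j];
            for (int i = 0; i < x.length; i++) w1[j][i] += rate * dHid[j] * x[i];
        }
    }

    public static void main(String[] args) {
        double[][] xs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[][] ts = {{1, 0}, {0, 1}, {0, 1}, {1, 0}}; // assumed one-hot: {not XOR, XOR}
        TinyNetwork net = new TinyNetwork(2, 2, 2, new Random(42));
        for (int epoch = 0; epoch < 20000; epoch++)
            for (int n = 0; n < xs.length; n++) net.trainExample(xs[n], ts[n], 0.5);
        for (double[] x : xs)
            System.out.println(Arrays.toString(x) + " -> " + Arrays.toString(net.evaluate(x)));
    }
}
```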
To train on a whole training set, I use the following loop:
public List<Double> train(NodeNetwork network, List<List<Double>> trainingInput, List<List<Double>> trainingLabels,
                          double learningRate, int epochs, boolean verbose) {
    List<Double> errorLog = new ArrayList<>();
    for (int i = 0; i < epochs; i++) {
        for (int j = 0; j < trainingInput.size(); j++) {
            int example = random.nextInt(trainingInput.size());
            network.trainExample(trainingInput.get(example), trainingLabels.get(example), learningRate);
        }
        if (verbose) {
            double error = network.checkErrorSet(trainingInput, trainingLabels);
            errorLog.add(error);
            System.out.println(i + " " + error);
        }
    }
    return errorLog;
}
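The loop above draws random.nextInt(...) with replacement, so within one epoch some examples may be visited twice and others not at all. That is not necessarily the cause of the problem, but a common alternative is to shuffle the example indices once per epoch. A minimal sketch, where the IntConsumer stands in for a call to network.trainExample(...) and the class name is hypothetical:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.IntConsumer;

/** Hypothetical helper: visit every example exactly once per epoch, in a fresh random order. */
final class EpochOrder {
    /** Returns 0..size-1 in a random order drawn from rng. */
    static List<Integer> shuffledIndices(int size, Random rng) {
        List<Integer> order = new ArrayList<>(size);
        for (int i = 0; i < size; i++) order.add(i);
        Collections.shuffle(order, rng);
        return order;
    }

    /** Calls trainOne once per example index in each epoch; trainOne stands in for trainExample. */
    static void runEpochs(int size, int epochs, Random rng, IntConsumer trainOne) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int example : shuffledIndices(size, rng)) {
                trainOne.accept(example);
            }
        }
    }

    public static void main(String[] args) {
        int[] visits = new int[4];
        runEpochs(4, 3, new Random(1), i -> visits[i]++);
        System.out.println(Arrays.toString(visits)); // prints [3, 3, 3, 3]
    }
}
```

Each of the 4 indices is visited once in each of the 3 epochs, whatever order the shuffle produces.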
Does anyone have any ideas on how I might get this to work? I've spent the last day doing various checks and seem to be getting no closer to an answer.
The full code is on my GitHub (sami016), which I can't link to due to URL restrictions.
I'd really appreciate it if anyone could point me in the right direction. Thanks for your help!