I have tried to implement a multi layer perceptron with sigmoid activations. Below is the code:
import numpy as np
def sigmoid(x):
return 1.0/(1.0 + np.exp(-x))
def sigmoid_derivative(x):
return sigmoid(x) * (1.0 - sigmoid(x))
class MLP:
def __init__(self, layers, x_train, y_train):
self.layers = layers
self.inputs = x_train
self.outputs = y_train
def forward(self, input):
output = input
for layer in self.layers:
layer.activations = output
output = layer.feedforward(output)
return output
def backward(self, output, predicted):
error = np.multiply(2 * np.subtract(output, predicted), sigmoid_derivative(predicted))
for layer in self.layers[::-1]:
#recursively backpropagate the error
error = layer.backpropagate(error)
def train(self):
for i in range(1,500):
predicted = self.forward(self.inputs)
self.backward(self.outputs,predicted)
def test(self, input):
return self.forward(input)
class Layer:
def __init__(self, prevNodes, selfNodes):
self.weights = np.random.rand(prevNodes,selfNodes)
self.biases = np.zeros(selfNodes)
self.activations = np.array([])
def feedforward(self, input):
return sigmoid(np.dot(input, self.weights) + self.biases)
def backpropagate(self, error):
delPropagate = np.dot(error, self.weights.transpose())
dw = np.dot(self.activations.transpose(), error)
db = error.mean(axis=0) * self.activations.shape[0]
self.weights = self.weights + 0.1 * dw
self.biases = self.biases + 0.1 * db
return np.multiply(delPropagate ,sigmoid_derivative(self.activations))
layer1 = Layer(3,4)
layer2 = Layer(4,1)
x_train = np.array([[0,0,1],[0,1,1],[1,0,1],[1,1,1]])
y_train = np.array([[0],[1],[1],[0]])
x_test = np.array([[0,0,1]])
mlp = MLP([layer1,layer2], x_train, y_train)
mlp.train()
mlp.test(x_test)
However the problem is the network saturates to output the average of all training outputs for any input. For eg, in the above case the y_train avarage is about 0.5 and no matter what 'test_x' value I feed to the network the output is always around the 0.5 mark.
Where could be the problem in code. Am I missing something in the algorithms. Help is appreciated