I implemented a small neural network following the formulas of the error backpropagation method. Is there any problem with my implementation?
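For reference, these are the formulas I followed, written in the same notation as the comments in my code (g is the sigmoid activation, * is an elementwise product, and · is a matrix product):

δ3 = (y - y_hat) * g'(z3)
δ2 = (W3.T · δ3) * g'(z2)
dW3 = δ3 · a2.T,  db3 = sum of δ3 over the samples
dW2 = δ2 · x.T,   db2 = sum of δ2 over the samples
W ← W - learning_rate * dW,  B ← B - learning_rate * db

My code: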
import numpy as np
# Network parameters
# 3 input units, 4 hidden units, 1 output unit
W2 = np.random.randn(4, 3)  # weight matrix of layer 2
B2 = np.random.randn(4, 1)  # bias vector of layer 2
W3 = np.random.randn(1, 4)  # weight matrix of layer 3
B3 = np.random.randn(1, 1)  # bias of layer 3
# Sigmoid activation function
def sigmoid(X):
    return 1 / (1 + np.exp(-X))

# Derivative of the sigmoid function
def sigmoid_derivative(X):
    return sigmoid(X) * (1 - sigmoid(X))
# Forward propagation
def forward(X):
    # layer 2
    Z2 = np.dot(W2, X) + np.tile(B2, (1, X.shape[1]))
    A2 = sigmoid(Z2)
    # layer 3
    Z3 = np.dot(W3, A2) + np.tile(B3, (1, X.shape[1]))
    A3 = sigmoid(Z3)
    return A3
# Mean squared error loss
def loss(y, y_hat):
    m = len(y)
    return np.sum((y - y_hat) ** 2) / (2 * m)

# Partial derivative of the loss with respect to the prediction
def loss_derivative(y, y_hat):
    return y - y_hat
# Backpropagation
def backward(X, y, y_hat):
    Z2 = np.dot(W2, X)   # shape (4, m)
    A2 = sigmoid(Z2)     # shape (4, m)
    Z3 = np.dot(W3, A2)  # shape (1, m)
    # layer 3: δ3 = J'(a3) * g'(z3)
    delta3 = (y - y_hat) * sigmoid_derivative(Z3)            # shape (1, m)
    # layer 2: δ2 = (W3.T · δ3) * g'(z2)
    delta2 = np.dot(W3.T, delta3) * sigmoid_derivative(Z2)   # shape (4, m)
    # layer 3: dW3 = δ3 · a2.T, db3 = sum of δ3
    dW3 = np.dot(delta3, A2.T)
    db3 = np.sum(delta3, axis=1, keepdims=True)
    # layer 2: dW2 = δ2 · a1.T (a1 = X), db2 = sum of δ2
    dW2 = np.dot(delta2, X.T)
    db2 = np.sum(delta2, axis=1, keepdims=True)
    return dW3, dW2, db3, db2
# Derivative of the sigmoid function (same as sigmoid_derivative above; not used)
def sigmoid_grad(x):
    return sigmoid(x) * (1 - sigmoid(x))
# Training data
X = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
y = np.array([[0], [0], [1], [1], [1]])
# Learning rate
learning_rate = 0.1
# Train the network
iteratorNum = 10 * 100  # 1000 iterations
J_history = np.zeros(iteratorNum)
for i in range(iteratorNum):
    # forward propagation
    y_hat = forward(X.T)
    # compute the loss
    J_history[i] = loss(y, y_hat)
    # backpropagation
    dW3, dW2, dB3, dB2 = backward(X.T, y.T, y_hat)
    # update the parameters
    W3 -= learning_rate * dW3
    B3 -= learning_rate * dB3
    W2 -= learning_rate * dW2
    B2 -= learning_rate * dB2
print(J_history)
case1 = np.array([[0, 0, 0], [1, 1, 0]])
print(forward(case1.T))
The final loss is 0.99999444, and the loss keeps increasing during training. What is the reason? Part of the printed loss history:
[0.99999401 0.99999403 0.99999404 0.99999406 0.99999408 0.99999409 0.99999411 0.99999413 0.99999414 0.99999416 0.99999417 0.99999419 0.99999421 0.99999422 0.99999424 0.99999425 0.99999427 0.99999428 0.9999943 0.99999432 0.99999433 0.99999435 0.99999436 0.99999438 0.99999439 0.99999441 0.99999442 0.99999444]
The prediction for case1 is [[0.9999975 0.99999866]].
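In case it helps with diagnosis, here is a minimal numerical gradient check I could run against the code above: it compares the analytic dW3 returned by backward with a central-difference estimate of the gradient of loss. This is only a sketch; numerical_grad_W3 and eps are names I made up and are not part of my program.

# Sketch of a helper for this question, not part of the training code above.
# Central-difference estimate of d(loss)/dW3, perturbing one weight at a time.
def numerical_grad_W3(X, y, eps=1e-5):
    grad = np.zeros_like(W3)
    for i in range(W3.shape[0]):
        for j in range(W3.shape[1]):
            old = W3[i, j]
            W3[i, j] = old + eps
            loss_plus = loss(y, forward(X))
            W3[i, j] = old - eps
            loss_minus = loss(y, forward(X))
            W3[i, j] = old
            grad[i, j] = (loss_plus - loss_minus) / (2 * eps)
    return grad

# Compare against the analytic gradient (columns of X and y are samples here)
dW3_analytic, _, _, _ = backward(X.T, y.T, forward(X.T))
dW3_numeric = numerical_grad_W3(X.T, y.T)
print(dW3_analytic)
print(dW3_numeric)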