This commit is contained in:
parent
d8737aa2bd
commit
91b272a658
18
神经网络/main.py
18
神经网络/main.py
|
|
@ -237,7 +237,7 @@ class NeuralNetwork:
|
|||
def _calculate_loss(
|
||||
self,
|
||||
y_true: numpy.ndarray,
|
||||
) -> numpy.floating:
|
||||
) -> numpy.floating: # pyright: ignore[reportReturnType]
|
||||
"""
|
||||
计算损失
|
||||
:param y_true: 真实输出,维度为[输出神经元数, 样本数]
|
||||
|
|
@ -280,10 +280,6 @@ class NeuralNetwork:
|
|||
axis=0,
|
||||
)
|
||||
) # 若输出层的激活函数为softmax则损失函数使用交叉熵
|
||||
case _:
|
||||
raise RuntimeError(
|
||||
f"该激活函数 {self.parameters[self.layer_counts]["activate"]} 暂不支持"
|
||||
)
|
||||
|
||||
def _backward_propagate(
|
||||
self,
|
||||
|
|
@ -300,9 +296,11 @@ class NeuralNetwork:
|
|||
{
|
||||
"delta_activation": (
|
||||
delta_activation := (
|
||||
(self.parameters[self.layer_counts]["activation"] - y_true)
|
||||
(self.parameters[layer_index]["activation"] - y_true)
|
||||
/ sample_counts
|
||||
# 若为输出层且激活函数为linear则直接计算输出的梯度,若为softmax则简化计算输出的梯度
|
||||
if layer_index == self.layer_counts
|
||||
# 若为隐含层则基于下一层的权重转置和加权输入的梯度计算当前层的输出梯度
|
||||
else numpy.dot(
|
||||
self.parameters[layer_index + 1]["weight"].T,
|
||||
self.parameters[layer_index + 1][
|
||||
|
|
@ -310,7 +308,7 @@ class NeuralNetwork:
|
|||
],
|
||||
)
|
||||
)
|
||||
), # 若为输出层则直接计算输出的梯度,否则基于下一层的权重转置和加权输入的梯度计算当前层的输出梯度
|
||||
),
|
||||
"delta_weighted_input": (
|
||||
delta_weighted_input := delta_activation
|
||||
* self._activate_derivative(
|
||||
|
|
@ -322,7 +320,7 @@ class NeuralNetwork:
|
|||
delta_weighted_input,
|
||||
(self.parameters[layer_index - 1]["activation"]).T,
|
||||
), # 权重的梯度
|
||||
"delta_bias": numpy.sum(
|
||||
"delta_bias": numpy.mean(
|
||||
delta_weighted_input,
|
||||
axis=1,
|
||||
keepdims=True,
|
||||
|
|
@ -381,10 +379,10 @@ if __name__ == "__main__":
|
|||
|
||||
# 创建并训练神经网络
|
||||
neural_network = NeuralNetwork(
|
||||
structure=[2, 256, 128, 1], # 2输入,10隐藏神经元,1输出
|
||||
structure=[2, 64, 32, 1], # 2输入,10隐藏神经元,1输出
|
||||
)
|
||||
|
||||
# 训练
|
||||
neural_network.train(
|
||||
X=X, y_true=y_true, target_loss=0.05, epochs=1_000, learning_rate=0.05
|
||||
X=X, y_true=y_true, target_loss=0.01, epochs=50_000, learning_rate=0.05
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue