This commit is contained in:
liubiren 2026-02-01 19:24:26 +08:00
parent 57f7a0607f
commit 6d89ffe182
1 changed files with 22 additions and 1 deletions

View File

@ -6,6 +6,7 @@
# 导入模块
from typing import List, Literal, Optional, Dict, Tuple
import numpy
import pickle
class NeuralNetwork:
@ -144,6 +145,7 @@ class NeuralNetwork:
raise RuntimeError(
"输入和真实输出应为数组,其中输入维度应为[输入神经元数, 样本数],真实输出维度应为[输出神经元数, 样本数],样本数应需相同"
)
self.training = True
# 归一化输入
self.parameters[0].update({"activation": self._normalize(input=X)})
@ -172,6 +174,13 @@ class NeuralNetwork:
epoch += 1
try:
with open("neural_network.pkl", "wb") as file:
pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL)
print(f"模型保存成功")
except Exception as exception:
raise RuntimeError(f"模型保存失败:{str(exception)}") from exception
def _normalize(
self,
input: numpy.ndarray,
@ -376,6 +385,18 @@ class NeuralNetwork:
}
)
def _reason(self, input: numpy.ndarray) -> numpy.ndarray:
"""
推理
:param input: 输入维度为[输入神经元数, 样本数]
:return: 输出维度为[输出神经元数, 样本数]
"""
self.training = False
# 归一化输入
self.parameters[0].update({"activation": self._normalize(input=input)})
return self._forward_propagate()
# 测试代码
if __name__ == "__main__":
@ -391,5 +412,5 @@ if __name__ == "__main__":
# 训练
neural_network.train(
X=X, y_true=y_true, target_loss=0.01, epochs=50_000, learning_rate=0.05
X=X, y_true=y_true, target_loss=0.1, epochs=1_000, learning_rate=0.05
)