This commit is contained in:
parent
57f7a0607f
commit
6d89ffe182
23
神经网络/main.py
23
神经网络/main.py
|
|
@ -6,6 +6,7 @@
|
|||
# 导入模块
|
||||
from typing import List, Literal, Optional, Dict, Tuple
|
||||
import numpy
|
||||
import pickle
|
||||
|
||||
|
||||
class NeuralNetwork:
|
||||
|
|
@ -144,6 +145,7 @@ class NeuralNetwork:
|
|||
raise RuntimeError(
|
||||
"输入和真实输出应为数组,其中输入维度应为[输入神经元数, 样本数],真实输出维度应为[输出神经元数, 样本数],样本数应需相同"
|
||||
)
|
||||
self.training = True
|
||||
# 归一化输入
|
||||
self.parameters[0].update({"activation": self._normalize(input=X)})
|
||||
|
||||
|
|
@ -172,6 +174,13 @@ class NeuralNetwork:
|
|||
|
||||
epoch += 1
|
||||
|
||||
try:
|
||||
with open("neural_network.pkl", "wb") as file:
|
||||
pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
print(f"模型保存成功")
|
||||
except Exception as exception:
|
||||
raise RuntimeError(f"模型保存失败:{str(exception)}") from exception
|
||||
|
||||
def _normalize(
|
||||
self,
|
||||
input: numpy.ndarray,
|
||||
|
|
@ -376,6 +385,18 @@ class NeuralNetwork:
|
|||
}
|
||||
)
|
||||
|
||||
def _reason(self, input: numpy.ndarray) -> numpy.ndarray:
|
||||
"""
|
||||
推理
|
||||
:param input: 输入,维度为[输入神经元数, 样本数]
|
||||
:return: 输出,维度为[输出神经元数, 样本数]
|
||||
"""
|
||||
self.training = False
|
||||
# 归一化输入
|
||||
self.parameters[0].update({"activation": self._normalize(input=input)})
|
||||
|
||||
return self._forward_propagate()
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -391,5 +412,5 @@ if __name__ == "__main__":
|
|||
|
||||
# 训练
|
||||
neural_network.train(
|
||||
X=X, y_true=y_true, target_loss=0.01, epochs=50_000, learning_rate=0.05
|
||||
X=X, y_true=y_true, target_loss=0.1, epochs=1_000, learning_rate=0.05
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue