This commit is contained in:
parent
57f7a0607f
commit
6d89ffe182
23
神经网络/main.py
23
神经网络/main.py
|
|
@ -6,6 +6,7 @@
|
||||||
# 导入模块
|
# 导入模块
|
||||||
from typing import List, Literal, Optional, Dict, Tuple
|
from typing import List, Literal, Optional, Dict, Tuple
|
||||||
import numpy
|
import numpy
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
class NeuralNetwork:
|
class NeuralNetwork:
|
||||||
|
|
@ -144,6 +145,7 @@ class NeuralNetwork:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"输入和真实输出应为数组,其中输入维度应为[输入神经元数, 样本数],真实输出维度应为[输出神经元数, 样本数],样本数应需相同"
|
"输入和真实输出应为数组,其中输入维度应为[输入神经元数, 样本数],真实输出维度应为[输出神经元数, 样本数],样本数应需相同"
|
||||||
)
|
)
|
||||||
|
self.training = True
|
||||||
# 归一化输入
|
# 归一化输入
|
||||||
self.parameters[0].update({"activation": self._normalize(input=X)})
|
self.parameters[0].update({"activation": self._normalize(input=X)})
|
||||||
|
|
||||||
|
|
@ -172,6 +174,13 @@ class NeuralNetwork:
|
||||||
|
|
||||||
epoch += 1
|
epoch += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open("neural_network.pkl", "wb") as file:
|
||||||
|
pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
print(f"模型保存成功")
|
||||||
|
except Exception as exception:
|
||||||
|
raise RuntimeError(f"模型保存失败:{str(exception)}") from exception
|
||||||
|
|
||||||
def _normalize(
|
def _normalize(
|
||||||
self,
|
self,
|
||||||
input: numpy.ndarray,
|
input: numpy.ndarray,
|
||||||
|
|
@ -376,6 +385,18 @@ class NeuralNetwork:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _reason(self, input: numpy.ndarray) -> numpy.ndarray:
|
||||||
|
"""
|
||||||
|
推理
|
||||||
|
:param input: 输入,维度为[输入神经元数, 样本数]
|
||||||
|
:return: 输出,维度为[输出神经元数, 样本数]
|
||||||
|
"""
|
||||||
|
self.training = False
|
||||||
|
# 归一化输入
|
||||||
|
self.parameters[0].update({"activation": self._normalize(input=input)})
|
||||||
|
|
||||||
|
return self._forward_propagate()
|
||||||
|
|
||||||
|
|
||||||
# 测试代码
|
# 测试代码
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -391,5 +412,5 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# 训练
|
# 训练
|
||||||
neural_network.train(
|
neural_network.train(
|
||||||
X=X, y_true=y_true, target_loss=0.01, epochs=50_000, learning_rate=0.05
|
X=X, y_true=y_true, target_loss=0.1, epochs=1_000, learning_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue