diff --git a/神经网络/main.py b/神经网络/main.py index e994bba..57cb471 100644 --- a/神经网络/main.py +++ b/神经网络/main.py @@ -6,6 +6,7 @@ # 导入模块 from typing import List, Literal, Optional, Dict, Tuple import numpy +import pickle class NeuralNetwork: @@ -144,6 +145,7 @@ class NeuralNetwork: raise RuntimeError( "输入和真实输出应为数组,其中输入维度应为[输入神经元数, 样本数],真实输出维度应为[输出神经元数, 样本数],样本数应需相同" ) + self.training = True # 归一化输入 self.parameters[0].update({"activation": self._normalize(input=X)}) @@ -172,6 +174,13 @@ class NeuralNetwork: epoch += 1 + try: + with open("neural_network.pkl", "wb") as file: + pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL) + print(f"模型保存成功") + except Exception as exception: + raise RuntimeError(f"模型保存失败:{str(exception)}") from exception + def _normalize( self, input: numpy.ndarray, @@ -376,6 +385,18 @@ class NeuralNetwork: } ) + def _reason(self, input: numpy.ndarray) -> numpy.ndarray: + """ + 推理 + :param input: 输入,维度为[输入神经元数, 样本数] + :return: 输出,维度为[输出神经元数, 样本数] + """ + self.training = False + # 归一化输入 + self.parameters[0].update({"activation": self._normalize(input=input)}) + + return self._forward_propagate() + # 测试代码 if __name__ == "__main__": @@ -391,5 +412,5 @@ if __name__ == "__main__": # 训练 neural_network.train( - X=X, y_true=y_true, target_loss=0.01, epochs=50_000, learning_rate=0.05 + X=X, y_true=y_true, target_loss=0.1, epochs=1_000, learning_rate=0.05 )