This commit is contained in:
parent
e42b71ca5e
commit
33f8edcd86
BIN
parameters.pkl
BIN
parameters.pkl
Binary file not shown.
35
神经网络/main.py
35
神经网络/main.py
|
|
@ -4,7 +4,7 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 导入模块
|
# 导入模块
|
||||||
from typing import List, Literal, Optional, Dict, Tuple
|
from typing import List, Literal
|
||||||
import numpy
|
import numpy
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
@ -43,6 +43,8 @@ class NeuralNetwork:
|
||||||
raise RuntimeError("神经网络结构应为列表,长度大于等于3且元素均为正整数")
|
raise RuntimeError("神经网络结构应为列表,长度大于等于3且元素均为正整数")
|
||||||
# 初始化神经网络结构
|
# 初始化神经网络结构
|
||||||
self.structure = structure
|
self.structure = structure
|
||||||
|
# 神经网络层数(定义,第0层为输入层,第l层为隐含层(l=1,2,...,L-1),第L层为输出层(L为神经网络层数),深度为L+1)
|
||||||
|
self.layer_counts = len(structure) - 1
|
||||||
|
|
||||||
if hidden_activate not in self.HIDDEN_ACTIVATES:
|
if hidden_activate not in self.HIDDEN_ACTIVATES:
|
||||||
raise RuntimeError(f"该隐含层激活函数 {hidden_activate} 暂不支持")
|
raise RuntimeError(f"该隐含层激活函数 {hidden_activate} 暂不支持")
|
||||||
|
|
@ -51,9 +53,6 @@ class NeuralNetwork:
|
||||||
raise RuntimeError(f"该输出层激活函数 {output_activate} 暂不支持")
|
raise RuntimeError(f"该输出层激活函数 {output_activate} 暂不支持")
|
||||||
self.output_activate = output_activate
|
self.output_activate = output_activate
|
||||||
|
|
||||||
# 神经网络层数(定义,第0层为输入层,第l层为隐含层(l=1,2,...,L-1),第L层为输出层(L为神经网络层数),深度为L+1)
|
|
||||||
self.layer_counts = len(structure) - 1
|
|
||||||
|
|
||||||
# 初始化神经网络参数
|
# 初始化神经网络参数
|
||||||
self.parameters = {}
|
self.parameters = {}
|
||||||
|
|
||||||
|
|
@ -414,7 +413,17 @@ class NeuralNetwork:
|
||||||
"""
|
"""
|
||||||
with open("parameters.pkl", "wb") as file:
|
with open("parameters.pkl", "wb") as file:
|
||||||
pickle.dump(
|
pickle.dump(
|
||||||
obj=self.parameters,
|
obj={
|
||||||
|
layer_index: {
|
||||||
|
key: value
|
||||||
|
for key, value in layer_parameters.items()
|
||||||
|
if layer_index == 0
|
||||||
|
and key in ["mean", "variance"]
|
||||||
|
or layer_index != 0
|
||||||
|
and key in ["weight", "bias", "activate"]
|
||||||
|
}
|
||||||
|
for layer_index, layer_parameters in self.parameters.items()
|
||||||
|
},
|
||||||
file=file,
|
file=file,
|
||||||
protocol=pickle.HIGHEST_PROTOCOL,
|
protocol=pickle.HIGHEST_PROTOCOL,
|
||||||
)
|
)
|
||||||
|
|
@ -456,17 +465,25 @@ class NeuralNetwork:
|
||||||
== (self.structure[layer_index], self.structure[layer_index - 1])
|
== (self.structure[layer_index], self.structure[layer_index - 1])
|
||||||
and self.parameters[layer_index]["bias"].shape
|
and self.parameters[layer_index]["bias"].shape
|
||||||
== (self.structure[layer_index], 1)
|
== (self.structure[layer_index], 1)
|
||||||
|
and (
|
||||||
|
self.parameters[layer_index]["activate"] in self.output_activate
|
||||||
|
if layer_index == self.layer_counts
|
||||||
|
else self.parameters[layer_index]["activate"]
|
||||||
|
in self.hidden_activate
|
||||||
|
)
|
||||||
if isinstance(self.parameters[layer_index]["weight"], numpy.ndarray)
|
if isinstance(self.parameters[layer_index]["weight"], numpy.ndarray)
|
||||||
and isinstance(self.parameters[layer_index]["bias"], numpy.ndarray)
|
and isinstance(self.parameters[layer_index]["bias"], numpy.ndarray)
|
||||||
else False
|
else False
|
||||||
):
|
):
|
||||||
raise RuntimeError("神经网络参数中权重和偏置的维度与神经网络结构不匹配")
|
raise RuntimeError(
|
||||||
|
"神经网络参数中权重和偏置的维度与神经网络结构不匹配、或激活函数不匹配"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 测试代码
|
# 测试代码
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
X = numpy.random.randn(2, 5000)
|
X = numpy.random.randn(2, 1000)
|
||||||
# 真实函数:y = 2*x1 + 3*x2 + 1
|
# 真实函数:y = 2*x1 + 3*x2 + 1
|
||||||
y_true = 2 * X[0:1, :] ** 2 + 3 * X[1:2, :] + 1
|
y_true = 2 * X[0:1, :] ** 2 + 3 * X[1:2, :] + 1
|
||||||
|
|
||||||
|
|
@ -476,9 +493,7 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
|
|
||||||
# 训练
|
# 训练
|
||||||
neural_network.train(
|
#neural_network.train(X=X, y_true=y_true, target_loss=0.01, epochs=1000_000, learning_rate=0.05)
|
||||||
X=X, y_true=y_true, target_loss=0.01, epochs=1_000, learning_rate=0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"推理结果:{y_true[:, 0:5]}")
|
print(f"推理结果:{y_true[:, 0:5]}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue