資源簡介
使用 Python 實現 BP 神經網絡經典代碼例子
代碼片段和文件信息
import?numpy?as?np
import?matplotlib.pyplot?as?plt
def logsig(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, element-wise.

    Used as the hidden-layer activation of the network below.

    Args:
        x: A scalar or NumPy array of pre-activations.

    Returns:
        Values in [0, 1] with the same shape as ``x``. Very negative inputs make
        ``np.exp(-x)`` overflow to ``inf`` (NumPy emits a RuntimeWarning), which
        still yields the correct limit 0.0.
    """
    return 1 / (1 + np.exp(-x))
# Training data: 20 paired observations, one series per list.
# Population (unit: 10,000 people)
population = [20.55, 22.44, 25.37, 27.13, 29.45, 30.10, 30.96, 34.06, 36.42, 38.09,
              39.13, 39.99, 41.93, 44.59, 47.30, 52.89, 55.73, 56.76, 59.17, 60.63]
# Number of motor vehicles (unit: 10,000 vehicles)
vehicle = [0.6, 0.75, 0.85, 0.9, 1.05, 1.35, 1.45, 1.6, 1.7, 1.85,
           2.15, 2.2, 2.25, 2.35, 2.5, 2.6, 2.7, 2.85, 2.95, 3.1]
# Highway area (unit: 10,000 square kilometres)
roadarea = [0.09, 0.11, 0.11, 0.14, 0.20, 0.23, 0.23, 0.32, 0.32, 0.34,
            0.36, 0.36, 0.38, 0.49, 0.56, 0.59, 0.59, 0.67, 0.69, 0.79]
# Highway passenger traffic (unit: 10,000 people)
passengertraffic = [5126, 6217, 7730, 9145, 10460, 11387, 12353, 15750, 18304, 19836,
                    21024, 19490, 20433, 22598, 25107, 33442, 36836, 40548, 42927, 43462]
# Highway freight traffic (unit: 10,000 tonnes)
freighttraffic = [1237, 1379, 1385, 1399, 1663, 1714, 1834, 4322, 8132, 8936,
                  11099, 11203, 10524, 11115, 13320, 16762, 18673, 20724, 20803, 21804]


def _row_minmax(data):
    """Return an (n_rows, 2) array whose columns are each row's min and max."""
    arr = np.asarray(data)
    return np.column_stack((arr.min(axis=1), arr.max(axis=1)))


def _scale_rows(data, minmax):
    """Map each row of *data* linearly from its [min, max] onto [-1, 1].

    *minmax* is the (n_rows, 2) array produced by ``_row_minmax``. A row whose
    min equals its max maps to -1 instead of producing NaN from 0/0.
    """
    arr = np.asarray(data, dtype=float)
    lo = minmax[:, :1]
    span = minmax[:, 1:] - lo
    span = np.where(span == 0, 1, span)  # constant row: avoid division by zero
    return 2 * (arr - lo) / span - 1


# np.asmatrix replaces np.mat (removed in NumPy 2.0); the result is still an
# np.matrix, so later code that relies on matrix semantics keeps working.
samplein = np.asmatrix([population, vehicle, roadarea])  # 3 x 20 inputs
sampleinminmax = _row_minmax(samplein)  # 3 x 2: [min, max] of each input row
sampleout = np.asmatrix([passengertraffic, freighttraffic])  # 2 x 20 targets
sampleoutminmax = _row_minmax(sampleout)  # 2 x 2: [min, max] of each target row

sampleinnorm = _scale_rows(samplein, sampleinminmax)  # 3 x 20 ndarray in [-1, 1]
sampleoutnorm = _scale_rows(sampleout, sampleoutminmax)  # 2 x 20 ndarray in [-1, 1]

# Add noise to the normalized targets. np.random.rand is uniform on [0, 1), so
# every target is shifted up by a value in [0, 0.03): the noise is
# non-negative, not zero-mean.
noise = 0.03 * np.random.rand(*sampleoutnorm.shape)
sampleoutnorm += noise
# Training hyper-parameters.
maxepochs = 60000  # upper bound on training epochs
learnrate = 0.035  # gradient-descent step size
errorfinal = 0.65 * 10**(-3)  # sum-of-squared-errors threshold intended for early stopping

# Network dimensions, derived from the normalized data so they always match it.
samnum = sampleinnorm.shape[1]  # number of training samples (20 here)
indim = sampleinnorm.shape[0]  # number of input features (3 here)
outdim = sampleoutnorm.shape[0]  # number of outputs (2 here)
hiddenunitnum = 8  # hidden-layer width

# Weights and biases are drawn uniformly from [-0.1, 0.4) (0.5 * U[0, 1) - 0.1).
# Biases are column vectors so they can be added to every sample column.
w1 = 0.5 * np.random.rand(hiddenunitnum, indim) - 0.1
b1 = 0.5 * np.random.rand(hiddenunitnum, 1) - 0.1
w2 = 0.5 * np.random.rand(outdim, hiddenunitnum) - 0.1
b2 = 0.5 * np.random.rand(outdim, 1) - 0.1
errhistory = []  # sum of squared errors, appended once per epoch
# Batch gradient-descent training loop: each epoch runs one forward and one
# backward pass over all samples at once.
# NOTE(review): this copy is damaged. Spaces were replaced by '?', commas were
# stripped (e.g. `np.dot(w1sampleinnorm)` presumably means
# `np.dot(w1, sampleinnorm)`), and the loop is truncated mid-statement on its
# last line, so the remaining gradient and weight-update statements are
# missing. Restore from the original source before running.
for?i?in?range(maxepochs):
    # Hidden layer: logsig(w1 . X + b1). The transpose round-trip lets the
    # column vector b1 broadcast across every sample column.
????hiddenout?=?logsig((np.dot(w1sampleinnorm).transpose()+b1.transpose())).transpose()
    # Output layer is linear (no activation): w2 . H + b2.
????networkout?=?(np.dot(w2hiddenout).transpose()+b2.transpose()).transpose()
    # Residual (target - prediction) and its sum of squared errors; the
    # built-in sum is applied twice to reduce over both axes.
????err?=?sampleoutnorm?-?networkout
????sse?=?sum(sum(err**2))
????errhistory.append(sse)
    # NOTE(review): the stopping condition was lost in extraction. It is
    # presumably `if sse < errorfinal:` followed by an indented `break`;
    # confirm against the original source.
????if?sse?????????break
    # Backpropagation. The output layer is linear, so its error term is the
    # residual itself; the hidden error term is scaled by the sigmoid
    # derivative h * (1 - h).
????delta2?=?err
????delta1?=?np.dot(w2.transpose()delta2)*hiddenout*(1-hiddenout)
    # Output-layer gradients. The bias gradient multiplies delta2 by a column
    # of ones, i.e. sums it over samples (presumably `np.ones((samnum, 1))`).
????dw2?=?np.dot(delta2hiddenout.transpose())
????db2?=?np.dot(delta2np.ones((samnum1)))
    # NOTE(review): truncated here; presumably
    # `np.dot(delta1, sampleinnorm.transpose())`.
????dw1?=?np.dot(delta1sam
評論
共有 條評論