資源簡介
復旦大學顧曉東老師課程作業代碼,python實現:用BP、RBF、SVM實現三個函數擬合;代碼包括數據的產生,數據的輸入,訓練等
代碼片段和文件信息
import?torch
import??numpy?as?nn
import?matplotlib.pyplot?as?plt
import?torch.nn.functional?as?F
import??numpy?as?np
epoch?=?1000
def?train():
????print(‘------??????構建數據集??????------‘)
????x?=?torch.unsqueeze(torch.cat([torch.linspace(-10?-1?50)torch.linspace(-1?-0.01?50)
??????????????????????????????????torch.linspace(0.01?1?50)torch.linspace(1?10?50)])?dim=1)
????#y?=?np.sin(x)/np.abs(x)
????y?=?x
????#y1?=?y**(-1)
???#?y2?=?y1**(1/3)
????plt.scatter(x?y)
????x?=?x.cuda()
????y?=?y.cuda()
???#?plt.show()
????print(‘------??????搭建網絡??????------‘)
????#?使用固定的方式繼承并重寫?init和forword兩個類
????class?Net(torch.nn.Module):
????????def?__init__(self?n_feature?n_hidden1n_hidden2n_output):
????????????#?初始網絡的內部結構
????????????super(Net?self).__init__()
????????????self.hidden1?=?torch.nn.Linear(n_feature?n_hidden1)
????????????self.hidden2?=?torch.nn.Linear(n_hidden1?n_hidden2)
????????????self.predict?=?torch.nn.Linear(n_hidden2?n_output)
????????def?forward(self?x):
????????????#?一次正向行走過程
????????????x?=?torch.tanh(self.hidden1(x))
????????????x?=?torch.tanh(self.hidden2(x))
????????????x?=?self.predict(x)
????????????return?x
????net?=?Net(n_feature=1?n_hidden1=500n_hidden2=500?n_output=1)
????net?=?net.cuda()
????print(‘網絡結構為:‘?net)
????print(‘------??????啟動訓練??????------‘)
????loss_func?=?F.mse_loss
????optimizer?=?torch.optim.SGD(net.parameters()?lr=0.001)
????#?使用數據?進行正向訓練,并對Variable變量進行反向梯度傳播??啟動100次訓練
????for?i?in?range(epoch):
????????#?使用全量數據?進行正向行走
????????prediction?=?net(x)
????????loss?=?loss_func(prediction?y)
????????optimizer.zero_grad()??#?清除上一梯度
????????loss.backward()??#?反向傳播計算梯度
????????optimizer.step()??#?應用梯度
????????if?i?%?100==0:
????????????print(‘epoch:[%d/%d]loss:??%.6f‘?%?(i?epoch?loss))
????????#?間隔一段,對訓練過程進行可視化展示
????????if?i?%?5?==?0:
????????????plt.cla()
???????????#?x?=?x.cpu()
???????????#?y?=?y.cpu()
????????????plt.scatter(x.cpu()?y.cpu())
????????????plt.plot(x.detach().cpu().numpy()?prediction.detach().cpu().numpy()?‘r-‘?lw=5)
????????????plt.text(0.5?0?‘Loss=‘?+?str(loss.item())?fontdict={‘size‘:?20?‘color‘:?‘red‘})
????????????plt.pause(0.1)
????plt.ioff()
????plt.show()
if?__name__?==?‘__main__‘:
????train()
評論
共有 條評論