-
大小: 8KB文件類型: .py金幣: 1下載: 0 次發布日期: 2023-07-19
- 語言: Python
- 標簽: vgg16??pytorchvgg16??
資源簡介
pytorch1.5實現的vgg16分類。在真實數據集測試成功
pytorch1.5實現的vgg16分類。在真實數據集測試成功
pytorch1.5實現的vgg16分類。在真實數據集測試成功
代碼片段和文件信息
#?-*-?coding:?utf-8?-*-
import?matplotlib.pyplot?as?plt
import?torch
import?numpy?as?np
from?torch?import?nn
from?torch?import?optim
from?torch.autograd?import?Variable
from?torchvision?import?datasetstransformsmodels
import?seaborn?as?sb
from?collections?import?OrderedDict
class?classifier():
????def?__init__(self):
?????????#?設置數據目錄
????????self.train_dir?=?‘smoke_clas/train‘
????????self.valid_dir?=?‘smoke_clas/valid‘
????????self.test_dir?=?‘smoke_clas/test‘
????#使用dataloader數據,驗證模型的準確率
????def?accuracy_valid(self?modeldataloader):
????????correct?=?0
????????total?=?0
????????model.cuda()?#?將模型放入GPU計算,能極大加快運算速度
????????with?torch.no_grad():?#?使用驗證集時關閉梯度計算
????????????for?data?in?dataloader:
???????????????
????????????????imageslabels?=?data
????????????????imageslabels?=?images.to(‘cuda‘)labels.to(‘cuda‘)
????????????????outputs?=?model(images)
????????????????_?predicted?=?torch.max(outputs.data1)?
????????????????#?torch.max返回輸出結果中,按dim=1行排列的每一行最大數據及他的索引,丟棄數據,保留索引
????????????????total?+=?labels.size(0)
????????????????
????????????????correct?+=?(predicted?==?labels).sum().item()
????????????????#將預測及標簽兩相同大小張量逐一比較各相同元素的個數
????????print(‘the?accuracy?is?{:.4f}‘.format(correct/total))
????def?accuracy_test(self?modeldataloader):
????????correct?=?0
????????total?=?0
????????model.cuda()?#?將模型放入GPU計算,能極大加快運算速度
????????with?torch.no_grad():?#?使用驗證集時關閉梯度計算
????????????for?data?in?dataloader:
???????????????
????????????????imageslabels?=?data
????????????????imageslabels?=?images.to(‘cuda‘)labels.to(‘cuda‘)
????????????????print(‘label:‘labels)
????????????????outputs?=?model(images)
????????????????probs=?[]
????????????????classes?=?[]
????????????????#print(outputs.data)
????????????????a?=?outputs[0]?????#?返回TOPK函數截取的排名前列的結果列表a
????????????????b?=?outputs[1].tolist()?#返回TOPK函數截取的排名前列的概率索引列表b
????????????????print(‘----------------‘)
????????????????print(‘a:‘a[0])
????????????????print(‘b:‘b[0])
????????????????for?i?in?a:
???????????????????print(torch.exp(i).tolist())
???????????????????probs.append(torch.exp(i).tolist())??#將結果轉化為實際概率
????????????????for?n?in?b:
???????????????????classes.append(str(n+1))??????#?將索引轉化為實際編號
????????????????print(classes)
????????????????print(probs)
????????????????_?predicted?=?torch.max(outputs.data1)?
????????????????print(‘label:‘labels)
????????????????print(‘predicted:‘predicted)
????????????????#print(predicted)
????def?deep_learning(self?modeltrainloaderepochsprint_everycriterionoptimizerdevicevalidloader):
????????epochs?=?epochs?#設置學習次數
????????print_every?=?print_every
????????steps?=?0
????????model.to(device)
????????
????????for?e?in?range(epochs):
????????????running_loss?=?0
????????????for?ii??(inputslabels)?in?enumerate(trainloader):
????????????????steps?+=?1
????????????????#inputs表示輸入的數據,labels表示每個數據對應的真實分類標簽
????????????????inputslabels?=?inputs.to(device)lab
- 上一篇:計算機視覺素材.part1.rar
- 下一篇:PYTHON自然語言處理_超高清pdf
評論
共有 條評論