-
大小: 文件類型: .gz 金幣: 1 下載: 0 次 發(fā)布日期: 2023-06-16
- 語(yǔ)言: 其他
- 標(biāo)簽: 深度學(xué)習(xí)
資源簡(jiǎn)介
《動(dòng)手學(xué)深度學(xué)習(xí)》(Dive into Deep Learning)原書中的MXNet實(shí)現(xiàn)改為PyTorch實(shí)現(xiàn)。
代碼片段和文件信息
import collections
import math
import os
import random
import sys
import tarfile
import time
import zipfile
from collections import namedtuple

from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchtext
import torchtext.vocab as Vocab
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
# Pascal VOC semantic-segmentation class names; index-aligned with VOC_COLORMAP.
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

# RGB colour used to paint each class above (same order as VOC_CLASSES).
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]
#?######################?3.2?############################
def set_figsize(figsize=(3.5, 2.5)):
    """Set the default matplotlib figure size (and switch display to SVG).

    Args:
        figsize: (width, height) in inches, stored in
            ``plt.rcParams['figure.figsize']``.
    """
    use_svg_display()
    # Set the default size for all subsequently created figures.
    plt.rcParams['figure.figsize'] = figsize
def use_svg_display():
    """Use svg format to display plots in Jupyter."""
    # NOTE(review): IPython.display.set_matplotlib_formats is deprecated in
    # newer IPython (moved to matplotlib_inline.backend_inline); kept as-is
    # to match the book's code — confirm against the installed IPython version.
    display.set_matplotlib_formats('svg')
def data_iter(batch_size, features, labels):
    """Yield ``(features, labels)`` minibatches in random order.

    Args:
        batch_size: maximum number of examples per batch.
        features: tensor whose first dimension indexes examples.
        labels: tensor with the same first-dimension length as ``features``.

    Yields:
        Pairs of tensors selected along dimension 0 by a shuffled index.
    """
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # examples are read in random order
    for i in range(0, num_examples, batch_size):
        # The last batch may hold fewer than batch_size examples.
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])
        yield features.index_select(0, j), labels.index_select(0, j)
def linreg(X, w, b):
    """Linear regression model: return ``X @ w + b`` (bias broadcasts)."""
    return torch.mm(X, w) + b
def squared_loss(y_hat, y):
    """Elementwise squared loss ``(y_hat - y) ** 2 / 2``.

    Note: returns a vector (no reduction); unlike PyTorch's MSELoss it
    keeps the division by 2. ``y`` is reshaped to match ``y_hat``.
    """
    return ((y_hat - y.view(y_hat.size())) ** 2) / 2
def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent update, in place.

    Divides by ``batch_size`` to match the original book; strictly this is
    redundant when the loss is already averaged over the batch.

    Args:
        params: iterable of tensors with populated ``.grad``.
        lr: learning rate.
        batch_size: minibatch size used to scale the gradient.
    """
    for param in params:
        # Update .data directly so the step is not tracked by autograd.
        param.data -= lr * param.grad / batch_size
#?######################3#####?3.5?#############################
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST labels (0-9) to their text names.

    Args:
        labels: iterable of ints (or anything ``int()`` accepts, e.g. tensors).

    Returns:
        List of label strings, one per input element.
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images, labels):
    """Plot a row of Fashion-MNIST images with their text labels.

    Args:
        images: iterable of tensors, each reshapeable to 28x28.
        labels: iterable of title strings, parallel to ``images``.
    """
    use_svg_display()
    # `_` is the unused Figure handle returned by subplots.
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        # Hide axis ticks — these are images, not charts.
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    # plt.show()
#?5.6?修改
#?def?load_data_fashion_mnist(batch_size?root=‘~/Datase
- 上一篇: 29種常用的運(yùn)算放大器-2018
- 下一篇: 編譯原理_第2版_張素琴
評(píng)論
共有 條評(píng)論