資源簡介
代碼使用 PyTorch 框架實現，採用循環神經網絡（RNN，具體為 LSTM）在 MNIST 手寫數字數據集上訓練識別模型。
代碼片段和文件信息
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Hyperparameters
batch_size = 100      # mini-batch size
learning_rate = 0.01  # optimizer learning rate
num_epoches = 20      # number of passes over the training set

# ToTensor converts each image to a float tensor in [0, 1]; Normalize with
# mean 0.5 / std 0.5 then maps it to [-1, 1] via (x - 0.5) / 0.5.
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
# download=True does not re-download files that already exist; it keeps the
# test split loadable even when it is constructed without the train split.
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
class Rnn(nn.Module):
    """LSTM-based sequence classifier.

    A stacked LSTM reads the input sequence, and the LSTM output at the final
    time step is mapped to per-class scores by a linear layer.

    Args:
        in_dim: number of features per time step.
        hidden_dim: LSTM hidden size.
        n_layer: number of stacked LSTM layers.
        n_class: number of output classes.
    """

    def __init__(self, in_dim, hidden_dim, n_layer, n_class):
        super().__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        # batch_first=True: input and output tensors are (batch, seq_len, features).
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_class)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, n_class).

        Args:
            x: tensor of shape (batch, seq_len, in_dim).
        """
        out, _ = self.lstm(x)  # out: (batch, seq_len, hidden_dim)
        # Keep only the output of the last time step as the sequence summary.
        out = out[:, -1, :]
        return self.classifier(out)
# Images are 28x28; each image is fed as a sequence of 28 rows of 28 pixels.
# Arguments: in_dim=28, hidden_dim=128, n_layer=2, n_class=10.
model = Rnn(28, 128, 2, 10)
# The training loop moves each batch to the GPU when CUDA is available, so the
# model parameters must be moved there too (before building the optimizer),
# otherwise the forward pass fails with a device mismatch.
if torch.cuda.is_available():
    model = model.cuda()

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training
for epoch in range(num_epoches):
    running_loss = 0.0
    running_acc = 0.0
    for i, data in enumerate(train_loader, 1):
        img, label = data
        # Drop the singleton channel axis so each image becomes a sequence of
        # rows for the batch_first LSTM: (batch, 1, 28, 28) -> (batch, 28, 28).
        img = img.squeeze(1)
        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
        # The original CPU branch wrapped img in torch.autograd.Variable; since
        # PyTorch 0.4 that wrapper just returns the tensor, so it is omitted.
        # NOTE(review): the source excerpt ends here. The forward pass, the
        # criterion/backward/optimizer step and the running_loss/running_acc
        # updates are not present and must be supplied for this loop to train.
評論
共有 條評論