資源簡介
採用全連接神經網絡對 0–5（012345）的手勢圖片進行識別，基於 PyTorch 編寫，附 h5 格式數據集。
代碼片段和文件信息
import os
import warnings

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
# Allow the Intel OpenMP runtime to be loaded more than once in the same process.
# NOTE(review): presumably a workaround for the "OMP: Error #15 ... libiomp5 already
# initialized" abort seen when torch and numpy/matplotlib each ship their own copy;
# confirm it is still needed on the target environment.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Silences *all* warnings for the rest of the run. This also hides genuine problems
# (e.g. deprecations); consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')
class Net(nn.Module):
    """Fully connected classifier with two hidden layers.

    Architecture::

        Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d -> ReLU -> Linear

    The last layer emits raw logits (no softmax), which is the input format
    ``nn.CrossEntropyLoss`` expects.

    Args:
        n_inputs: number of input features per sample (e.g. a flattened image).
        n_hidden_1: width of the first hidden layer.
        n_hidden_2: width of the second hidden layer.
        n_outputs: number of output classes.

    Note:
        Because of ``BatchNorm1d``, a forward pass in training mode needs a
        batch of more than one sample; call ``model.eval()`` to score a single one.
    """

    def __init__(self, n_inputs: int, n_hidden_1: int, n_hidden_2: int, n_outputs: int):
        super().__init__()
        # The attribute name ``fc`` and the layer order are kept so that
        # state_dict keys (fc.0.weight, fc.1.running_mean, ...) remain stable.
        self.fc = nn.Sequential(
            nn.Linear(n_inputs, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(),
            nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(),
            nn.Linear(n_hidden_2, n_outputs),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Map a ``(batch, n_inputs)`` float tensor to ``(batch, n_outputs)`` logits."""
        return self.fc(x)
# Load the hand-gesture dataset from the HDF5 files that ship next to this script.
# `with` closes each file as soon as its arrays are copied out (np.array(...) makes
# an in-memory copy), instead of leaving both handles open for the whole run.
# Labels are cast to int64 because nn.CrossEntropyLoss requires long class indices.
with h5py.File('train_signs.h5', "r") as train_dataset:
    train_set_x_orig = t.from_numpy(np.array(train_dataset["train_set_x"]))
    train_set_y = t.tensor(np.array(train_dataset["train_set_y"])).long()
with h5py.File('test_signs.h5', "r") as test_dataset:
    test_set_x_orig = t.from_numpy(np.array(test_dataset["test_set_x"]))
    test_set_y = t.tensor(np.array(test_dataset["test_set_y"])).long()
    # Read while the file is still open; the names of the 6 gesture classes.
    classes = np.array(test_dataset["list_classes"])

print(train_set_x_orig.shape)  # (1080, 64, 64, 3)
print(test_set_x_orig.shape)   # (120, 64, 64, 3)
print(train_set_y.shape)       # (1080,)
print(test_set_y.shape)        # (120,)
print(classes)                 # [0 1 2 3 4 5]
# Sanity check: show the first six training images in a 2x3 grid with their labels.
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.imshow(train_set_x_orig[i])
    plt.title("Labels: {}".format(classes[int(train_set_y[i])]))
    # Hide the axis ticks; they carry no information for an image.
    plt.xticks([])
    plt.yticks([])
plt.show()
num_train = train_set_x_orig.shape[0]  # 1080
num_test = test_set_x_orig.shape[0]    # 120
print("訓練集包含 %d 張圖片; 測試集包含 %d 張圖片。" % (num_train, num_test))

# Flatten every image into one row vector (64*64*3 = 12288 values) so it can feed a
# fully connected network, and scale the 0-255 pixel values into [0, 1].
train_set_x = train_set_x_orig.reshape(num_train, -1) / 255.0
test_set_x = test_set_x_orig.reshape(num_test, -1) / 255.0
print("訓練集圖片尺寸變為: ", train_set_x.shape)  # (1080, 64, 64, 3) --> (1080, 12288)
print("測試集圖片尺寸變為: ", test_set_x.shape)   # (120, 64, 64, 3) --> (120, 12288)
# Input width and class count come from the data rather than being hard-coded
# (12288 flattened features and 6 classes for the bundled dataset); the hidden
# layers have 200 and 50 units.
model = Net(train_set_x.shape[1], 200, 50, len(classes))
print(model)
print("model包含{}個參數。".format(sum(x.numel() for x in model.parameters())))

# CrossEntropyLoss applies log-softmax itself, so the model outputs raw logits.
cost_func = nn.CrossEntropyLoss()
optimizer = t.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)
epochs = 200
for epoch in range(epochs):
    # One full-batch gradient step per epoch over the entire training set.
    model.train()  # BatchNorm uses batch statistics and updates its running averages
    y_out = model(train_set_x)
    train_loss = cost_func(y_out, train_set_y)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    with t.no_grad():
        # Training accuracy is measured on the logits computed *before* this
        # epoch's parameter update.
        _, y_pred = y_out.max(1)
        num_correct_train = (y_pred == train_set_y).sum().item()
        acc_rate_train = (num_correct_train / num_train) * 100.0

        # eval() switches BatchNorm to its running statistics for testing.
        model.eval()
        y_out = model(test_set_x)
        _, y_pred = y_out.max(1)
        num_correct_test = (y_pred == test_set_y).sum().item()
        acc_rate_test = num_correct_test * 100.0 / num_test
        print("世代數: %d 訓練集正確率: %.1f%% 測試集正確率: %.1f%%"
              % (epoch + 1, acc_rate_train, acc_rate_test))
 屬性            大小     日期    時間   名稱
-----------  ---------  ---------- -----  ----
     文件        3140  2020-08-12 10:09  012345手勢識別-FCNet\Fully_connected_Net.py
     文件     1477712  2019-07-09 23:16  012345手勢識別-FCNet\test_signs.h5
     文件    13281872  2019-07-09 23:16  012345手勢識別-FCNet\train_signs.h5
     目錄           0  2020-09-13 10:55  012345手勢識別-FCNet\
- 上一篇:貓-非貓圖二分類識別
- 下一篇：Python 繪製彩色蜂蜜窩（基於 turtle）
評論
共有 條評論