資源簡介
使用 Keras 在 MNIST 數據集上進行訓練、分類與測試，可以得到較好的結果（資源中包含 MNIST 數據集）。親測有效，自行編寫，歡迎共同學習。

代碼片段和文件信息
import matplotlib.image as processimage
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import SVG

from keras import backend as k
from keras.datasets import mnist
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Dense
from keras.layers.core import Flatten
from keras.models import Sequential
from keras.optimizers import SGD, RMSprop, Adam
from keras.utils import np_utils
from keras.utils.vis_utils import model_to_dot
# from mydata import load_data
# Network definition.
class LeNet:
    """LeNet-style convolutional network for small grayscale images.

    Architecture: two CONV => RELU => POOL stages (20 and 50 filters with
    5x5 kernels, "same" padding, 2x2 max-pooling with stride 2), followed by
    a 500-unit fully connected ReLU layer and a softmax classifier.
    """

    @staticmethod
    def build(input_shape, classes):
        """Build and return an uncompiled LeNet ``Sequential`` model.

        Args:
            input_shape: Shape of one input sample, without the batch axis.
                Its layout must match the backend image data format in use
                (this script configures channels-first, e.g. ``(1, 28, 28)``).
            classes: Number of output classes, i.e. the width of the final
                softmax layer.

        Returns:
            A ``Sequential`` model; the caller is responsible for compiling it.
        """
        model = Sequential()
        # Stage 1: CONV => RELU => POOL
        model.add(Conv2D(20, kernel_size=5, padding="same",
                         input_shape=input_shape))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        # Stage 2: CONV => RELU => POOL
        model.add(Conv2D(50, kernel_size=5, padding="same"))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        # Flatten the feature maps => fully connected RELU layer
        model.add(Flatten())
        model.add(Dense(500))
        model.add(Activation("relu"))
        # Softmax classifier producing one probability per class
        model.add(Dense(classes))
        model.add(Activation("softmax"))
        return model
# --- Training configuration ---
NB_EPOCH = 1              # number of passes over the training data
BATCH_SIZE = 128          # samples per gradient update
VERBOSE = 2               # passed as `verbose=` to fit() and evaluate()
OPTIMIZER = Adam()
VALIDATION_SPLIT = 0.2    # fraction of the training data held out for validation

# --- Input geometry ---
IMG_ROWS, IMG_COLS = 28, 28
# When training on a custom dataset (see the commented-out `mydata` import),
# set the image size accordingly, e.g. IMG_ROWS, IMG_COLS = 512, 384.
NB_CLASS = 10             # number of output classes (MNIST digits)
# Channels-first shape of one sample: (channels, rows, cols).
INPUT_SHAPE = (1, IMG_ROWS, IMG_COLS)
# Load MNIST; load_data() returns a fixed training/test split.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# To use a custom dataset instead, enable the `from mydata import load_data`
# import and replace the line above with:
# (X_train, y_train), (X_test, y_test) = load_data()

# Use the channels-first layout (N, channels, rows, cols) to match INPUT_SHAPE.
# set_image_data_format("channels_first") is the non-deprecated equivalent of
# set_image_dim_ordering("th").
# NOTE(review): some backends (e.g. TensorFlow on CPU) may not support
# channels-first convolution/pooling; confirm for your setup, or switch to
# "channels_last" together with INPUT_SHAPE = (IMG_ROWS, IMG_COLS, 1).
k.set_image_data_format("channels_first")

# Convert to float32 and scale the 8-bit pixel values into [0, 1].
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
X_train /= 255
X_test /= 255

# Insert a channel axis: (N, rows, cols) -> (N, 1, rows, cols), the shape the
# convolutional network expects (e.g. 60000 x 1 x 28 x 28 for training).
X_train = X_train[:, np.newaxis, :, :]
X_test = X_test[:, np.newaxis, :, :]
print("X_train shape:", X_train.shape)
print(X_train.shape[0], "train sample")
print(X_test.shape[0], "test sample")

# Convert integer class labels into one-hot (binary class matrix) rows.
y_train = np_utils.to_categorical(y_train, NB_CLASS)
y_test = np_utils.to_categorical(y_test, NB_CLASS)
# Build the model, then attach the optimizer, loss and metric.
model = LeNet.build(input_shape=INPUT_SHAPE, classes=NB_CLASS)
# Optional, in a notebook: render the architecture as SVG.
# SVG(model_to_dot(model).create(prog="dot", format="svg"))
model.compile(loss="categorical_crossentropy", optimizer=OPTIMIZER,
              metrics=["accuracy"])

# Train, holding out VALIDATION_SPLIT of the training data for validation.
history = model.fit(X_train, y_train, batch_size=BATCH_SIZE,
                    epochs=NB_EPOCH, verbose=VERBOSE,
                    validation_split=VALIDATION_SPLIT)

# score[0] is the test loss; score[1] is the accuracy metric requested above.
score = model.evaluate(X_test, y_test, verbose=VERBOSE)
print("test score:", score[0])
print("test accuracy:", score[1])
# Inspect and classify one test image (the last one, index 9999).
testrun = X_test[9999].reshape(1, 784)   # flattened 28*28 pixels, for printing
testlabel = y_test[9999]                 # one-hot encoded true label
print(testrun)
print(testlabel)
plt.imshow(testrun.reshape([28, 28]), cmap="gray")
plt.show()

# NOTE(review): the source listing was truncated at `pred = model.pr`; this
# completes it as a prediction on the same image. The model expects a batch
# shaped like its input, (N, 1, 28, 28), not the flattened (1, 784) vector,
# so slice X_test to keep the batch and channel axes.
pred = model.predict(X_test[9999:10000])
print("predicted digit:", np.argmax(pred, axis=1)[0])
print("true digit:", np.argmax(testlabel))
 屬性      大小        日期        時間    名稱
-----  ---------  ----------  -----  ----------
 文件   11490434  2019-07-04  14:37  mnist.npz
 文件       3753  2019-07-29  15:18  ttlenet.py
- 上一篇:SYC8P1228 SSOP28中文數據手冊V1.00
- 下一篇:CCS破解安裝包
評論
共有 條評論