資源簡介
人工智能算法實現mnist手寫數字識別
代碼片段和文件信息
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPool2D,
)

# Print arrays in full (no "..." elision); this matters when trainable
# weights are later written to a text file via str(v.numpy()).
np.set_printoptions(threshold=np.inf)

mnist = tf.keras.datasets.mnist
# fashion = tf.keras.datasets.fashion_mnist  # alternative dataset toggle
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities by 1/255 (presumably uint8 0-255 -> [0, 1]).
x_train, x_test = x_train / 255.0, x_test / 255.0
print("x_train.shape", x_train.shape)

# Add a trailing channel axis so the data matches the convolutional network's
# expected input layout: (samples, 28, 28, 1).
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
print("x_train.shape", x_train.shape)
class baseline(Model):
    """Baseline CNN: one convolutional stage followed by a dense classifier.

    Layer order applied in ``call``::

        Conv2D(6 filters, 5x5, 'same') -> BatchNormalization -> ReLU
        -> MaxPool2D(2x2, stride 2, 'same') -> Dropout(0.2)
        -> Flatten -> Dense(128, ReLU) -> Dropout(0.2) -> Dense(10, softmax)

    The final layer applies softmax, so the model returns a probability
    distribution over 10 classes, not logits. Any loss used with it must
    therefore be configured with ``from_logits=False``. In this script the
    model is fed images reshaped to (N, 28, 28, 1).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # batch-normalization layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer
        # Classifier head.
        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass on the image batch ``x``; return class probabilities.

        Dropout and BatchNormalization receive the training/inference mode
        from the Keras call context, so no explicit ``training`` flag is
        threaded through here.
        """
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y
model = baseline()

# The model's last layer applies softmax, so the loss receives probabilities
# rather than logits (from_logits=False).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

checkpoint_save_path = "./checkpoint/baseline.ckpt"
# A TF-format weight checkpoint includes a '<path>.index' file; if it exists,
# resume from the previously saved weights instead of starting from scratch.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored metric improves
# (ModelCheckpoint's default monitor is 'val_loss').
# NOTE(review): validation_data below is the test split, so the "best"
# checkpoint is selected on test data; hold out a separate validation split
# if an unbiased test score is needed.
# NOTE(review): Keras 3 requires weight paths ending in '.weights.h5'; this
# '.ckpt' path assumes the TF2 / Keras 2 checkpoint format.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()

# Dump every trainable variable (name, shape, full values) to a text file.
# The `with` block guarantees the file is closed even if a write fails.
with open('./weights.txt', 'w', encoding='utf-8') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

###############################################    show   ###############################################
# Plot training vs. validation accuracy and loss curves.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc?label=‘Validation?Accuracy
 屬性            大小     日期    時間   名稱
----------- ---------  ---------- -----  ----
     文件       3330  2020-07-11 03:54  mnist實驗報告\ML_MNIST\CNN.py
     文件       1495  2020-07-11 04:48  mnist實驗報告\ML_MNIST\FCN.py
     文件       1843  2020-07-11 05:38  mnist實驗報告\ML_MNIST\RNN.py
     文件     256512  2020-11-13 22:11  mnist實驗報告\mnist實驗報告.doc
     目錄          0  2020-07-11 06:00  mnist實驗報告\ML_MNIST
     目錄          0  2020-11-13 22:11  mnist實驗報告
----------- ---------  ---------- -----  ----
               263180                    6
- 上一篇:用Python學微積分.azw3
- 下一篇:小說閱讀項目源碼(附數據庫腳本)
評論
共有 條評論