-
大小: 4KB | 文件類型: .py | 金幣: 1 | 下載: 0 次 | 發布日期: 2021-06-02
- 語言: Python
- 標簽: tensorflow2. python mnist
資源簡介
基于python3.7版本的tensorflow2.0實現mnist手寫數字識別代碼
代碼片段和文件信息
import warnings

# Installed before TensorFlow is imported so the filter is already active
# during the import itself.
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, datasets, Sequential, optimizers, metrics

# Load the MNIST train/test split as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Data preprocessing
def train_preprocess(x_train, y_train):
    """Prepare one training example for the network.

    Used as the ``map_func`` of the training ``tf.data.Dataset``.

    Args:
        x_train: Image tensor; presumably 8-bit pixel values (0-255), which
            the division by 255 rescales to [0, 1] -- confirm against the
            data source.
        y_train: Integer class label(s).

    Returns:
        A ``(features, labels)`` tuple: the image cast to ``tf.float32`` and
        divided by 255, and the label one-hot encoded with depth 10.
    """
    x_train = tf.cast(x=x_train, dtype=tf.float32) / 255.
    y_train = tf.cast(x=y_train, dtype=tf.int32)
    # One-hot labels are what the categorical cross-entropy loss consumes.
    y_train = tf.one_hot(indices=y_train, depth=10)
    return x_train, y_train
def test_preprocess(x_test, y_test):
    """Prepare one test example for evaluation.

    Used as the ``map_func`` of the test ``tf.data.Dataset``. Unlike the
    training variant, the labels are left as integer class ids (no one-hot
    encoding).

    Args:
        x_test: Image tensor; presumably 8-bit pixel values (0-255), which
            the division by 255 rescales to [0, 1] -- confirm against the
            data source.
        y_test: Integer class label(s).

    Returns:
        A ``(features, labels)`` tuple: the image cast to ``tf.float32`` and
        divided by 255, and the label cast to ``tf.int32``.
    """
    x_test = tf.cast(x=x_test, dtype=tf.float32) / 255.
    y_test = tf.cast(x=y_test, dtype=tf.int32)
    return x_test, y_test
# Training pipeline: slice into single examples, preprocess each one,
# shuffle with a 1000-element buffer, then group into batches of 128.
train_db = tf.data.Dataset.from_tensor_slices(tensors=(x_train, y_train))
train_db = train_db.map(map_func=train_preprocess).shuffle(buffer_size=1000).batch(batch_size=128)

# Test pipeline: same slicing and batching, but no shuffling.
test_db = tf.data.Dataset.from_tensor_slices(tensors=(x_test, y_test))
test_db = test_db.map(map_func=test_preprocess).batch(batch_size=128)
# Build the network model: a stack of five fully connected layers.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=512, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=256, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=128, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=32, activation=tf.nn.relu),
    # No activation on the last layer: it emits one raw score per class.
    tf.keras.layers.Dense(units=10),
])
# Inputs are flattened 28 * 28 vectors; None leaves the batch size open.
model.build(input_shape=[None, 28 * 28])
model.summary()
# Adam optimizer with a fixed learning rate of 1e-4.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
def?main():
????for?epoch?in?range(100):
????????for?step?(x_train?y_train)?in?enumerate(train_db):
????????????x_train?=?tf.reshape(tensor?=?x_train?shape?=?[-1?28?*?28])
????????????with?tf.GradientTape()?as?tape:
????????????????logits?=?model(x_train)
????????????????loss?=?tf.losses.categorical_crossentropy(y_true?=?y_train?y_pr
- 上一篇:中間代碼生成代碼中綴表達式轉換為四元式
- 下一篇:python QQ第三方登陸
評論
共有 條評論