資源簡介
MOOC 上北大老師講授的《TensorFlow 筆記》課程中的手寫數字識別（MNIST）代碼，適合初學者使用（無誤版）。
代碼片段和文件信息
# Create the file "mnist_forward.py" with the code below.
# mnist_forward.py
import tensorflow as tf

# Layer sizes of the two-layer fully connected network built by forward().
INPUT_NODE = 784    # input features per example (presumably a flattened 28x28 image — confirm)
OUTPUT_NODE = 10    # number of output logits; must match the one-hot label width
layer1_NODE = 500   # width of the single hidden (ReLU) layer
def get_weight(shape, regularizer):
    """Create a trainable weight variable, optionally registering an L2 penalty.

    Args:
        shape: Shape of the weight tensor, e.g. ``[in_units, out_units]``.
        regularizer: L2 regularization scale passed to
            ``tf.contrib.layers.l2_regularizer``; ``None`` disables the penalty.

    Returns:
        A new ``tf.Variable`` initialised from a truncated normal
        distribution with standard deviation 0.1.
    """
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        # Store the penalty in the "losses" collection; the training code sums
        # this collection with tf.add_n and adds it to the cross-entropy loss.
        tf.add_to_collection("losses", tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w
def get_bias(shape):
    """Create a trainable bias variable of the given shape, initialised to zeros.

    Args:
        shape: Shape of the bias tensor, e.g. ``[out_units]``.

    Returns:
        A new zero-initialised ``tf.Variable``.
    """
    b = tf.Variable(tf.zeros(shape))
    return b
def forward(x, regularizer):
    """Build the two-layer fully connected network and return its logits.

    Args:
        x: Input tensor. It is multiplied by an ``[INPUT_NODE, layer1_NODE]``
            weight, so it must be ``[batch, INPUT_NODE]``.
        regularizer: L2 scale forwarded to ``get_weight``; ``None`` disables
            the penalty.

    Returns:
        Un-normalised logits of shape ``[batch, OUTPUT_NODE]``.
    """
    # Hidden layer: affine transform followed by ReLU.
    w1 = get_weight([INPUT_NODE, layer1_NODE], regularizer)
    b1 = get_bias([layer1_NODE])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    # Output layer is left linear (no softmax): the training loss applies
    # softmax itself via sparse_softmax_cross_entropy_with_logits.
    w2 = get_weight([layer1_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias([OUTPUT_NODE])
    y = tf.matmul(y1, w2) + b2
    return y
# From here on, create a separate file named "mnist_backward.py".
# mnist_backward.py
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

# Training hyper-parameters used by backward().
BATCH_SIZE = 200             # mini-batch size; also sets decay_steps = num_examples / BATCH_SIZE
LEARNING_RATE_base = 0.1     # initial learning rate for tf.train.exponential_decay
LEARNING_RATE_DECAY = 0.99   # multiplicative decay applied once per decay_steps (staircase=True)
REGULARIZER = 0.0001         # L2 scale passed to mnist_forward.forward
STEPS = 50000                # number of iterations of the training loop
MOVING_AVERAGE_DECAY = 0.99  # decay for tf.train.ExponentialMovingAverage over trainable variables
MODEL_SAVE_PATH = "./model/"  # presumably the checkpoint directory for tf.train.Saver — confirm
MODEL_NAME = "mnist_model"    # presumably the checkpoint file-name prefix — confirm
def?backward(mnist):
????x?=?tf.placeholder(tf.float32?[None?mnist_forward.INPUT_NODE])
????y_?=?tf.placeholder(tf.float32?[None?mnist_forward.OUTPUT_NODE])
????y?=?mnist_forward.forward(x?REGULARIZER)
????global_step?=?tf.Variable(0?trainable=False)
????ce?=?tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y?labels=tf.argmax(y_?1))
????cem?=?tf.reduce_mean(ce)
????loss?=?cem?+?tf.add_n(tf.get_collection(“losses“))
????learning_rate?=?tf.train.exponential_decay(
????????LEARNING_RATE_base
????????global_step
????????mnist.train.num_examples?/?BATCH_SIZE
????????LEARNING_RATE_DECAY
????????staircase=True)
????train_step?=?tf.train.GradientDescentOptimizer(learning_rate).minimize(loss?global_step=global_step)
????ema?=?tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY?global_step)
????ema_op?=?ema.apply(tf.trainable_variables())
????with?tf.control_dependencies([train_step?ema_op]):
????????train_op?=?tf.no_op(name=“train“)
????saver?=?tf.train.Saver()
????with?tf.Session()?as?sess:
????????init_op?=?tf.global_variables_initializer()
????????sess.run(init_op)
????????for?i?in?range(STEPS):
????????????xs
- 上一篇:某網Python3.6+電商實戰+Vue+Django
- 下一篇:文本查重系統
評論
共有 條評論