-
大小: 4KB | 文件類型: .py | 金幣: 1 | 下載: 0 次 | 發布日期: 2021-05-12
- 語言: Python
- 標簽: CNN, TensorFlow, 深度學習
資源簡介
CNN卷積神經網絡tensorflow代碼,使用MNIST數據集,安裝好python和TensorFlow可直接運行
代碼片段和文件信息
from __future__ import print_function

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

# Reads the MNIST data set from /tmp/data/ with labels requested in
# one-hot form (one_hot=True).
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
# Parameters
learning_rate = 0.001
training_iters = 200000
batch_size = 128
display_step = 10

# Network Parameters
n_input = 784  # MNIST data input (img shape: 28*28)
n_classes = 10  # MNIST total classes (0-9 digits)
dropout = 0.75  # Dropout, probability to keep units
# tf Graph input
# x holds flattened images (n_input values per row) and y the matching
# labels (n_classes values per row); the leading dimension is left as None.
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)
# Create some wrappers for simplicity
def conv2d(x, W, b, strides=1):
    """Conv2D wrapper with bias and ReLU activation.

    Args:
        x: Input tensor passed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        b: Bias tensor added to the convolution output.
        strides: Single stride value placed in the two middle positions of
            the 4-element strides list (presumably height and width under
            the default data layout -- confirm).

    Returns:
        The ReLU activation of the biased convolution output.
    """
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)
def maxpool2d(x, k=2):
    """MaxPool2D wrapper.

    Args:
        x: Input tensor passed to ``tf.nn.max_pool``.
        k: Value used for both the pooling window and the stride in the two
            middle positions of the 4-element ``ksize`` and ``strides`` lists.

    Returns:
        The max-pooled tensor, computed with 'SAME' padding.
    """
    return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='SAME')
# Create model
def conv_net(x, weights, biases, dropout):
    """Build the CNN graph and return the output-layer tensor.

    The network is: reshape -> conv + max-pool -> conv + max-pool ->
    fully connected + ReLU -> dropout -> linear output layer.

    Args:
        x: Input tensor; it is reshaped to ``[-1, 28, 28, 1]`` first.
        weights: Mapping with the keys 'wc1', 'wc2', 'wd1' and 'out'.
        biases: Mapping with the keys 'bc1', 'bc2', 'bd1' and 'out'.
        dropout: Forwarded as the second positional argument of
            ``tf.nn.dropout`` (the file's comments describe it as the
            probability to keep units -- confirm against the installed
            TensorFlow version).

    Returns:
        The result of the final matmul-plus-bias; no activation is applied
        to it inside this function.
    """
    # Reshape input picture
    x = tf.reshape(x, shape=[-1, 28, 28, 1])

    # Convolution layer
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    # Max Pooling (down-sampling)
    conv1 = maxpool2d(conv1, k=2)

    # Convolution layer
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    # Max Pooling (down-sampling)
    conv2 = maxpool2d(conv2, k=2)

    # Fully connected layer
    # Reshape conv2 output to fit fully connected layer input: the row
    # width is taken from the first dimension of the 'wd1' weight matrix.
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    # Apply Dropout
    fc1 = tf.nn.dropout(fc1, dropout)

    # Output, class prediction
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out
# Store layers weight & bias
weights?=?{
????#?5x5?conv?1?input?32?outputs
????‘wc1‘:?tf.Variable(tf.random_normal([5?5?1?32]))
????#?5x5?conv?32?inputs?64?outputs
????‘wc2‘:?tf.Variable(tf.random_normal
評論
共有 條評論