-
大小: 2.19MB 文件類型: .rar 金幣: 2 下載: 0 次 發布日期: 2023-08-13
- 語言: 其他
- 標簽: RtFrecords
資源簡介
tensorflow下 自制tfrecords數據集,采用one-hot編碼做圖像分類的源碼

代碼片段和文件信息
# -*- coding: utf-8 -*-
"""
Train a small CNN image classifier on a TFRecords dataset with one-hot labels.

Created on Sat Feb 23 23:21:44 2019
@author: Administrator
"""
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import OneHotEncoder

from RTFrcord_read_data import read_and_decode

# Spatial size of the input images. `weight` is used as the third dimension of
# the input placeholder, i.e. it plays the role of the image width; the name is
# kept because the rest of the script refers to it.
height = 100
weight = 100

# Number of examples per training batch; also the fixed leading dimension of
# the input placeholders.
batch_size = 432
?
# Helper functions that create the initial weight and bias variables
def weight_variable(shape):
    """Return a TensorFlow variable of the given shape for use as weights.

    The initial values are drawn from a normal distribution with a standard
    deviation of 0.01.

    Args:
        shape: Shape of the variable to create (list of ints).
    """
    return tf.Variable(tf.random_normal(shape, stddev=0.01))
def bias_variable(shape):
    """Return a TensorFlow variable of the given shape for use as biases.

    Every element is initialised to the constant 0.1.

    Args:
        shape: Shape of the variable to create (list of ints).
    """
    return tf.Variable(tf.constant(0.1, shape=shape))
# Placeholders for the input data and for dropout.
# X:        one batch of 3-channel images, [batch_size, height, weight, 3].
# y_:       one-hot labels for the batch, 8 columns (one per class).
# keep_pro: dropout keep probability fed at run time.
X = tf.placeholder(tf.float32, [batch_size, height, weight, 3])
y_ = tf.placeholder(tf.float32, [batch_size, 8])
keep_pro = tf.placeholder(tf.float32)
?
# Build the network
def model(X, keep_pro):
    """Build the CNN and return its softmax output.

    Architecture: two blocks of (5x5 convolution + ReLU + 4x4 max-pooling with
    stride 4), followed by a 1024-unit fully connected layer with dropout and
    an 8-way softmax output layer.

    Args:
        X: Input image batch tensor with 3 channels; its leading dimension
            must equal the module-level ``batch_size`` because the flatten
            step reshapes to ``[batch_size, -1]``.
        keep_pro: Dropout keep probability applied after the first fully
            connected layer.

    Returns:
        Tensor of per-class probabilities with 8 columns.
    """
    # First convolution block: 3 input channels -> 32 feature maps.
    w1 = weight_variable([5, 5, 3, 32])
    b1 = bias_variable([32])
    conv1 = tf.nn.relu(
        tf.nn.conv2d(X, w1, strides=[1, 1, 1, 1], padding='SAME') + b1)
    pool1 = tf.nn.max_pool(
        conv1, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='SAME')

    # Second convolution block: 32 -> 64 feature maps.
    w2 = weight_variable([5, 5, 32, 64])
    b2 = bias_variable([64])
    conv2 = tf.nn.relu(
        tf.nn.conv2d(pool1, w2, strides=[1, 1, 1, 1], padding='SAME') + b2)
    pool2 = tf.nn.max_pool(
        conv2, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='SAME')

    # Flatten to one row per example; the row length is read from the static
    # shape so the dense layer's weights can be sized accordingly.
    tensor = tf.reshape(pool2, [batch_size, -1])
    dim = tensor.get_shape()[1].value

    # Fully connected layer with dropout.
    w3 = weight_variable([dim, 1024])
    b3 = bias_variable([1024])
    fc1 = tf.nn.relu(tf.matmul(tensor, w3) + b3)
    h_fc1 = tf.nn.dropout(fc1, keep_pro)

    # Output layer: 8 classes, softmax probabilities.
    w4 = weight_variable([1024, 8])
    b4 = bias_variable([8])
    y_conv = tf.nn.softmax(tf.matmul(h_fc1, w4) + b4)
    return y_conv
?
# Build the network, then define the loss function and the optimiser.
y_conv = model(X, keep_pro)
# Cross-entropy between the one-hot labels and the softmax output, averaged
# over the batch. The probabilities are clipped away from zero before the log
# so that a saturated softmax cannot turn the loss into NaN/inf.
cost = tf.reduce_mean(
    -tf.reduce_sum(
        y_ * tf.log(tf.clip_by_value(y_conv, 1e-10, 1.0)),
        reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(0.001).minimize(cost)
# Accuracy: fraction of examples whose most probable class matches the label.
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Tensors yielding one (image, label) pair read from the tfrecords file.
image, label = read_and_decode("train1.tfrecords")
# Open a session and run the training loop.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue-runner threads under a coordinator.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Host-side buffers for one batch of training images and their labels.
    example = np.zeros((batch_size, height, weight, 3))
    labels = np.zeros((batch_size, 1))
    try:
        # Pull batch_size examples one at a time into the buffers.
        for idx in range(batch_size):
            example[idx], labels[idx] = sess.run([image, label])
        print(labels)

        # Convert the integer labels to one-hot rows.
        # NOTE(review): the encoder emits one column per distinct label seen
        # in this batch; y_ expects exactly 8 columns, so the batch is assumed
        # to contain all 8 classes - confirm against the dataset.
        enc = OneHotEncoder()
        labels = enc.fit_transform(labels)
        labels = labels.toarray()
        print(labels)

        for i in range(100):
            # One optimisation step on the batch, with dropout enabled.
            sess.run(train_step,
                     feed_dict={X: example, y_: labels, keep_pro: 0.5})
            if i % 10 == 0:
                # Dropout is disabled (keep probability 1.0) when measuring
                # accuracy so the reported value is deterministic and reflects
                # the full network.
                print('train step', '%04d ' % (i + 1), 'Accuracy=',
                      sess.run(accuracy,
                               feed_dict={X: example, y_: labels,
                                          keep_pro: 1.0}))
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # NOTE(review): the original listing is truncated right after
        # `finally:`. The usual coordinator shutdown is assumed here - confirm
        # against the full data_classification.py.
        coord.request_stop()
        coord.join(threads)
 屬性            大小     日期    時間   名稱
----------- ---------  ---------- -----  ----
     文件       3234  2019-02-23 23:47  RTFrcords\data_classification.py
     文件       1743  2019-02-23 23:47  RTFrcords\RTFrcord_read_data.py
     文件       1819  2019-02-23 23:08  RTFrcords\RTFrcord_save_data.py
     文件   13016413  2019-02-23 23:09  RTFrcords\train1.tfrecords
     文件        893  2019-02-23 23:27  RTFrcords\__pycache__\RTFrcord_read_data.cpython-36.pyc
     目錄          0  2019-02-23 23:27  RTFrcords\__pycache__
     目錄          0  2019-02-23 23:27  RTFrcords
----------- ---------  ---------- -----  ----
             13024102                    7
評論
共有 條評論