-
大小: 3KB文件類型: .rar金幣: 2下載: 0 次發(fā)布日期: 2021-06-11
- 語言: 其他
- 標簽: 鳶尾花??softmax??tensorflow??實例??
資源簡介
鳶尾花經(jīng)典案例,采用 softmax 分類,使用tensorflow實現(xiàn)

代碼片段和文件信息
#?-*-?coding:utf-8?-*-
#tensorflow?1.2?python?3.6
import?tensorflow?as?tf#導入TensorFlow庫
import?os#導入OS庫
W?=?tf.Variable(tf.zeros([4?3])?name=“weights“)#變量權值,矩陣,每個特征權值列對應一個輸出類別
b?=?tf.Variable(tf.zeros([3]?name=“bias“))#模型偏置,每個偏置對應一個輸出類別
def?combine_inputs(X):#輸入值合并
????print(“function:?combine_inputs“)
????return?tf.matmul(X?W)?+?b
def?inference(X):#計算返回推斷模型輸出(數(shù)據(jù)X)
????print(“function:?inference“)
????return?tf.nn.softmax(combine_inputs(X))#調(diào)用softmax分類函數(shù)
def?loss(X?Y):#計算損失(訓練數(shù)據(jù)X及期望輸出Y)
????print(“function:?loss“)
????return?tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=combine_inputs(X)?labels=Y))#求平均值,針對每個樣本只對應單個類別優(yōu)化
def?read_csv(batch_size?file_name?record_defaults):#從csv文件讀取數(shù)據(jù),加載解析,創(chuàng)建批次讀取張量多行數(shù)據(jù)
????filename_queue?=?tf.train.string_input_producer([os.path.join(os.getcwd()?file_name)])
????reader?=?tf.TextLineReader(skip_header_lines=1)
????key?value?=?reader.read(filename_queue)
????decoded?=?tf.decode_csv(value?record_defaults=record_defaults)#字符串(文本行)轉(zhuǎn)換到指定默認值張量列元組,為每列設置數(shù)據(jù)類型
????return?tf.train.shuffle_batch(decoded?batch_size=batch_size?capacity=batch_size?*?50?min_after_dequeue=batch_size)#讀取文件,加載張量batch_size行
def?inputs():#讀取或生成訓練數(shù)據(jù)X及期望輸出Y
????print(“function:?inputs“)
????#數(shù)據(jù)來源:https://archive.ics.uci.edu/ml/datasets/Iris
????#iris.data改為iris.csv,增加sepal_length?sepal_width?petal_length?petal_width?label字段行首行
????sepal_length?sepal_width?petal_length?petal_width?label?=\
????????read_csv(100?“iris.csv“?[[0.0]?[0.0]?[0.0]?[0.0]?[““]])
????#轉(zhuǎn)換屬性數(shù)據(jù)
????label_number?=?tf.to_int32(tf.argmax(tf.to_int32(tf.stack([
????????tf.equal(label?[“Iris-setosa“])
????????tf.equal(label?[“Iris-versicolor“])
????????tf.equal(label?[“Iris-virginica“])
????]))?0))#將類名稱轉(zhuǎn)抽象為從0開始的類別索引
????features?=?tf.transpose(tf.stack([sepal_length?sepal_width?petal_length?petal_width]))#特征裝入矩陣,轉(zhuǎn)置,每行一樣本,每列一特征
????return?features?label_number
def?train(total_loss):#訓練或調(diào)整模型參數(shù)(計算總損失)
????print(“function:?train“)
????learning_rate?=?0.01
????return?tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)
def?evaluate(sess?X?Y):#評估訓練模型
????print(“function:?evaluate“)
????predicted?=?tf.cast(tf.argmax(inference(X)?1)?tf.int32)#選擇預測輸出值最大概率類別
????print(sess.run(tf.reduce_mean(tf.cast(tf.equal(predicted?Y)?tf.float32))))#統(tǒng)計所有正確預測樣本數(shù),除以批次樣本總數(shù),得到正確預測百分比
with?tf.Session()?as?sess:#會話對象啟動數(shù)據(jù)流圖,搭建流程
????print(“Session:?start“)
????tf.global_variables_initializer().run()
????X?Y?=?inputs()
????total_loss?=?loss(X?Y)
????train_op?=?train(total_loss)
????coord?=?tf.train.Coordinator()
????threads?=?tf.train.start_queue_runners(sess=sess?coord=coord)
????training_steps?=?1000#實際訓練迭代次數(shù)
????for?step?in?range(training_steps):#實際訓練閉環(huán)
????????sess.run([train_op])
????????if?step?%?10?==?0:#查看訓練過程損失遞減
????????????print(str(step)+?“?loss:?“?sess.run([total_loss]))
????print(str(training_steps)?+?“?final?loss:?“?sess.run([total_loss]))
????evaluate(sess?X?Y
?屬性????????????大小?????日期????時間???名稱
-----------?---------??----------?-----??----
?????文件???????3765??2017-11-23?10:31??softmax_iris.py
?????文件???????4619??2017-11-23?10:24??iris.csv
-----------?---------??----------?-----??----
?????????????????8384????????????????????2
評論
共有 條評論