-
大小: 8KB文件類型: .py金幣: 1下載: 1 次發(fā)布日期: 2021-06-18
- 語言: Python
- 標簽: 深度學(xué)習(xí)??
資源簡介
實現(xiàn)堆疊降噪自編碼器功能,以tensorflow中的mnist數(shù)據(jù)集為例,python2.7
代碼片段和文件信息
#coding:UTF-8
import?tensorflow?as?tf
import?numpy?as?np
from?tensorflow.examples.tutorials.mnist?import?input_data
#創(chuàng)建降噪自編碼器類
class?Denosing_AutoEncoder():
????def?__init__(selfn_hiddeninput_datacorruption_level=0.3):
????????self.W=None#輸入層到隱含層的權(quán)重
????????self.b=None#輸入層到隱含層的偏置
????????self.encode_r=None#隱含層的輸出5
????????self.layer_size=n_hidden#隱含層節(jié)點的個數(shù)
????????self.input_data=input_data#輸入樣本
????????self.keep_prob=1-corruption_level#特征保持不變的比例
????????self.W_eval=None#權(quán)重W的值
????????self.b_eval=None#偏置b的值
????#降噪自編碼器的訓(xùn)練
????def?fit(self):
????????#輸入層節(jié)點的個數(shù)
????????n_visible=(self.input_data).shape[1]
????????#輸入一張圖片用28*28=784的向量表示
????????X=tf.placeholder(“float“[Nonen_visible]name=‘X‘)
????????#用于將部分輸入數(shù)據(jù)置為0
????????mask=tf.placeholder(“float“[Nonen_visible]name=‘mask‘)
????????#創(chuàng)建權(quán)重和偏置
????????W_init_max=4*np.sqrt(6./(n_visible+self.layer_size))
????????W_init=tf.random_uniform(shape=[n_visibleself.layer_size]minval=-W_init_maxmaxval=W_init_max)
????????#編碼器
????????self.W=tf.Variable(W_initname=‘W‘)#784*500
????????self.b=tf.Variable(tf.zeros([self.layer_size])name=‘b‘)#隱含層的偏置
????????#解碼器
????????W_prime=tf.transpose(self.W)
????????b_prime=tf.Variable(tf.zeros([n_visible])name=‘b_prime‘)
????????tilde_X=mask*X#對輸入樣本加入噪聲
????????Y=tf.nn.sigmoid(tf.matmul(tilde_Xself.W)+self.b)#隱含層的輸出
????????Z=tf.nn.sigmoid(tf.matmul(YW_prime)+b_prime)#重構(gòu)輸出
????????cost=tf.reduce_mean(tf.pow(X-Z2))#均方誤差
????????train_op=tf.train.GradientDescentOptimizer(0.01).minimize(cost)#最小化均方誤差
????????trX=self.input_data
????????#開始訓(xùn)練
????????with?tf.Session()?as?sess:
????????????#初始化所有的參數(shù)
????????????tf.initialize_all_variables().run()
????????????for?i?in?range(30):
????????????????for?startend?in?zip(range(0len(trX)128)range(128len(trX+1)128)):#len(trX)矩陣秩,start=0128256...end=128256...
????????????????????input_=trX[start:end]#設(shè)置輸入每次輸入128個,分批操作
????????????????????mask_np=np.random.binomial(1self.keep_probinput_.shape)#設(shè)置mask基于二項分布(npsize),當(dāng)n=1時,為伯努利分布(0-1分布)
????????????????????#開始訓(xùn)練
????????????????????sess.run(train_opfeed_dict={X:input_mask:mask_np})
????????????????if?i%5.0==0:#每隔5次輸出一次mask=[11...1]時的loss
????????????????????mask_np=np.random.binomial(11trX.shape)#此時mask尺寸大小與原來輸入尺寸大小一致,因此是trX
????????????????????print(“l(fā)oss?function?at?step?%s?is?%s“%(isess.run(costfeed_dict={X:trXmask:mask_np})))
????????????#保存好輸入層到隱含層的參數(shù)
????????????self.W_eval=(self.W).eval()
????????????self.b_eval=(self.b).eval()
????????????mask_np=np.random.binomial(11trX.shape)
????????????self.encode_r=Y.eval({X:trXmask:mask_np})
????#取得降噪自編碼器的參數(shù)
????def?get_value(self):
????????return?self.W_evalself.b_evalself.encode_r
#創(chuàng)建堆疊降噪自編碼器類
class?Stacked_Denosing_AutoEncoder():
????def?__init__(selfhidden_listinput_data_trainX
?????????????????input_data_trainYinput_data_validX
?????????????????input_data_validYinput_data_testX
?????????????????input_data_testYc
評論
共有 條評論