資源簡介
使用LSTM預測鐵路客運時間序列,代碼在tensorflow 1.3上調試運行通過,解決了網絡相關資源的一些坑(很多資源不是tensorflow的版本太舊就是數據與鐵路客運數據不太匹配,難以直接運行測試),不需要建立train模型和test模型(被繞了半天,這種測試應該毫無必要),可對不同的超參數組合比較訓練誤差和學習誤差!
代碼片段和文件信息
# coding: utf-8
"""
2017-8-20
Predict railway passenger traffic with a basic LSTM network.

HParams:    hyper-parameters of the model
TS_LSTM:    LSTM model trained on the time series
train_test: training and testing routine

@ Jiang Qiuhua
"""
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
from collections import namedtuple

# Hyper-parameter bundle for the simple time-series LSTM model:
#   seq_size      - length of each input window (time steps per sample)
#   hidden_size   - number of LSTM units
#   learning_rate - Adam learning rate
HParams = namedtuple('HParams',
                     'seq_size hidden_size learning_rate')
class TS_LSTM(object):
    """Single-layer LSTM for one-step-ahead time-series prediction.

    Builds a TF1 graph that maps an input window X of shape
    [batch, seq_size, 1] to a prediction of shape [batch, seq_size];
    the target Y is the input series shifted forward by one step.
    """

    def __init__(self, hps):
        """Build the graph from an HParams tuple (seq_size, hidden_size, learning_rate)."""
        # Placeholders: input windows and their one-step-shifted targets.
        self._X = X = tf.placeholder(tf.float32, [None, hps.seq_size, 1])
        self._Y = Y = tf.placeholder(tf.float32, [None, hps.seq_size])
        # Output projection: one weight per hidden unit plus a scalar bias.
        W = tf.Variable(tf.random_normal([hps.hidden_size, 1]), name='W')
        b = tf.Variable(tf.random_normal([1]), name='b')
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hps.hidden_size)  # test cost 1.3809
        # lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hps.hidden_size, activation=tf.nn.relu)  # test cost 261.078, far worse
        outputs, states = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32)
        # Tile W across the batch dimension so xw_plus_b performs a batched matmul
        # over the per-step LSTM outputs.
        W_repeated = tf.tile(tf.expand_dims(W, 0), [tf.shape(X)[0], 1, 1])
        output = tf.nn.xw_plus_b(outputs, W_repeated, b)
        self._output = output = tf.squeeze(output)
        # Mean squared error across every step of every window.
        self._cost = cost = tf.reduce_mean(tf.square(output - Y))
        self._train_op = tf.train.AdamOptimizer(hps.learning_rate).minimize(cost)

    @property
    def X(self):
        return self._X

    @property
    def Y(self):
        return self._Y

    @property
    def cost(self):
        return self._cost

    @property
    def output(self):
        return self._output

    @property
    def train_op(self):
        return self._train_op
def?train_test(hps?data):
?????#訓(xùn)練數(shù)據(jù)準(zhǔn)備
????train_data_len?=?len(data)*2//3
????train_x?train_y?=?[]?[]??
????for?i?in?range(train_data_len?-?hps.seq_size?-?1):??
????????train_x.append(np.expand_dims(data[i?:?i?+?hps.seq_size]?axis=1).tolist())??
????????train_y.append(data[i?+?1?:?i?+?hps.seq_size?+?1].tolist())??
????#測(cè)試數(shù)據(jù)準(zhǔn)備????
????test_data_len?=?len(data)//3
????test_x?test_y?=?[]?[]??
????for?i?in?range(train_data_len
???????????????????train_data_len+test_data_len?-?hps.seq_size?-?1):??
????????test_x.append(np.expand_dims(data[i?:?i?+?hps.seq_size]?axis=1).tolist())??
????????test_y.append(data[i?+?1?:?i?+?hps.seq_size?+?1].tolist())??
?????
????with?tf.Graph().as_default()?tf.Session()?as?sess:??
????????with?tf.variable_scope(‘model‘reuse=None):
????????????m_train?=?TS_LSTM(hps)
#????????with?tf.variable_scope(‘model‘reuse=True):????#建立的測(cè)試模型沒有什么用!
#????????????m_test?=?TS_LSTM(Falseconfig)???????????
????????????????????????
????????#訓(xùn)練
????????tf.global_variables_initializer().run()
????????for?step?in?range(20000):??
????????????_?train_cost?=?sess.run([m_train.train_op?m_train.cost]?
??????????????????????????????feed_dict={m_train.X:?train_x?m_train.Y:?train_y})??
????????????????
????????#預(yù)測(cè)?
????????test_cost?output?=?sess.run([m_train.cost?m_train.output]
??????????????????feed_dict={m_tra
評(píng)論
共有 條評(píng)論