# 資源簡介 (Resource overview)
# Time series forecasting using LSTM.
# 代碼片段和文件信息 (Code snippet and file information)
from math import sqrt

from keras.layers import Dense
from keras.layers import LSTM
from keras.models import Sequential
from matplotlib import pyplot
from numpy import concatenate
import pandas as pd
from pandas import DataFrame
from pandas import concat
from pandas import read_csv
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
# Convert a series to a supervised-learning table.
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
    """Frame a time series as a supervised-learning table via lagged copies.

    For each variable in ``data`` the result holds ``n_in`` lagged columns
    (``varJ(t-n_in)`` ... ``varJ(t-1)``), followed by ``n_out`` forecast
    columns (``varJ(t)``, ``varJ(t+1)`` ... ``varJ(t+n_out-1)``). Each is
    built with ``DataFrame.shift``. Variables are numbered from 1 in column
    order.

    Args:
        data: The observations. A flat list, 1-D array or Series is treated
            as a single variable. A 2-D array, list of rows or DataFrame
            supplies one variable per column.
        n_in: Number of lag steps to emit as input columns.
        n_out: Number of forecast steps (starting at t) to emit as output
            columns.
        dropnan: If True, drop every row containing NaN. Shifting introduces
            NaN at the start and end of the series.

    Returns:
        A pandas DataFrame with ``n_vars * (n_in + n_out)`` columns.
    """
    df = DataFrame(data)
    # Take the variable count from the constructed frame rather than from
    # ``data`` itself. This handles flat lists, 1-D arrays, Series and
    # lists of rows consistently.
    n_vars = df.shape[1]
    cols, names = [], []
    # Input sequence (t-n, ... t-1).
    for i in range(n_in, 0, -1):
        cols.append(df.shift(i))
        names += [f"var{j + 1}(t-{i})" for j in range(n_vars)]
    # Forecast sequence (t, t+1, ... t+n).
    for i in range(0, n_out):
        cols.append(df.shift(-i))
        if i == 0:
            names += [f"var{j + 1}(t)" for j in range(n_vars)]
        else:
            names += [f"var{j + 1}(t+{i})" for j in range(n_vars)]
    # Put it all together.
    agg = concat(cols, axis=1)
    agg.columns = names
    # Drop rows with NaN values introduced by shifting.
    if dropnan:
        agg = agg.dropna()
    return agg
def plot_results(predicted_data, true_data):
    """Plot predictions against the true series on one white-background figure.

    Args:
        predicted_data: The model's predicted values, plotted as 'Prediction'.
        true_data: The observed values, plotted as 'True Data'.

    Blocks on ``pyplot.show()`` until the window is closed.
    """
    fig = pyplot.figure(facecolor='white')
    ax = fig.add_subplot(111)
    # Draw both series and the legend through the same Axes object instead of
    # mixing in pyplot's implicit "current axes" state.
    ax.plot(true_data, label='True Data')
    ax.plot(predicted_data, label='Prediction')
    ax.legend()
    pyplot.show()
# Load the dataset; the first CSV column is used as the index.
dataset = read_csv('SP500_data.csv', header=0, index_col=0)
values = dataset.values
# Ensure all data is float.
values = values.astype('float32')
# Normalize every feature to [0, 1].
# NOTE(review): the scaler is fit on all rows, including those that later
# become the test set, which leaks test statistics into training. Confirm
# this is acceptable.
scaler = MinMaxScaler(feature_range=(0, 1))
scaled = scaler.fit_transform(values)
# Number of lag steps used as inputs and forecast steps used as outputs.
time_step = 14
# Frame as supervised learning.
reframed = series_to_supervised(scaled, time_step, time_step)
print(reframed.head())
# Split into train and test sets.
values = reframed.values
print(values)
print('len', len(values))
# NOTE(review): the training split is a hard-coded row count. The original
# comment stating its intended ratio is garbled; confirm the intended split.
n_train_days = 2983
if not 0 < n_train_days < len(values):
    # Fail fast instead of silently producing an empty train or test set.
    raise ValueError(
        f'n_train_days={n_train_days} leaves an empty train or test set '
        f'for {len(values)} rows'
    )
train = values[:n_train_days, :]
test = values[n_train_days:, :]
# Split into input and output.
# 評論 (Comments)
# 共有 條評論 (There are [count missing from the scraped page] comments)