Resource overview
Sliced recurrent neural networks (SRNN) run 135 times faster than the standard RNN structure, without any change to the recurrent unit.
Code snippet and file information
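To make the source of the speedup concrete, here is a minimal sketch that only counts sequential recurrent steps. It assumes that all subsequences on the same level can be processed in parallel, and the helper names `minimal_subsequence_length` and `sequential_steps_srnn` are hypothetical, introduced only for this illustration. It is a step count, not a timing measurement, and it does not reproduce the 135x figure.

```python
def minimal_subsequence_length(seq_len, n, k):
    """Length of the shortest subsequences after slicing k times into n parts."""
    if seq_len % (n ** k) != 0:
        raise ValueError("seq_len must be divisible by n ** k")
    return seq_len // (n ** k)


def sequential_steps_srnn(seq_len, n, k):
    """Sequential steps, assuming every subsequence on a level runs in parallel."""
    # Bottom level: one pass over a minimal subsequence.
    # Each of the k higher levels: one pass over n lower-level outputs.
    return minimal_subsequence_length(seq_len, n, k) + n * k


seq_len, n, k = 512, 8, 2  # the values used by the snippet on this page
print(minimal_subsequence_length(seq_len, n, k))  # 8
print(sequential_steps_srnn(seq_len, n, k))       # 24
print(seq_len)                                    # 512 steps for a standard RNN
```

Under that assumption, a 512-token sequence needs 24 sequential steps instead of 512; the actual speedup depends on hardware and on how well the parallel work is batched.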
'''
Author: Zeping Yu
Sliced Recurrent Neural Network (SRNN).
SRNN achieves much faster speeds than a standard RNN by slicing the sequences into many subsequences.
This work was accepted at COLING 2018.
The code is written in Keras using the TensorFlow backend. We implement SRNN(8,2) here, and the Yelp 2013 dataset is used.
If you have any questions, please contact me at zepingyu@foxmail.com.
'''
import pandas as pd
import numpy as np
from keras.utils.np_utils import to_categorical
from keras.preprocessing.text import Tokenizer, text_to_word_sequence
from keras.preprocessing.sequence import pad_sequences
from keras.models import Model
from keras.layers import Input, Embedding, GRU, TimeDistributed, Dense
# load data
df = pd.read_csv("yelp_2013.csv")
# df = df.sample(5000)
# map star ratings to zero-based class ids, then one-hot encode them
Y = df.stars.values - 1
Y = to_categorical(Y, num_classes=5)
X = df.text.values
# set hyperparameters
MAX_NUM_WORDS = 30000
EMBEDDING_DIM = 200
VALIDATION_SPLIT = 0.1
TEST_SPLIT = 0.1
NUM_FILTERS = 50
MAX_LEN = 512
Batch_size = 100
EPOCHS = 10
# shuffle the data (fixed seed for reproducibility)
indices = np.arange(X.shape[0])
np.random.seed(2018)
np.random.shuffle(indices)
X = X[indices]
Y = Y[indices]
# training set, validation set and testing set
# nb_validation_samples_val counts validation + test samples together;
# nb_validation_samples_test counts the test samples only
nb_validation_samples_val = int((VALIDATION_SPLIT + TEST_SPLIT) * X.shape[0])
nb_validation_samples_test = int(TEST_SPLIT * X.shape[0])
x_train = X[:-nb_validation_samples_val]
y_train = Y[:-nb_validation_samples_val]
x_val = X[-nb_validation_samples_val:-nb_validation_samples_test]
y_val = Y[-nb_validation_samples_val:-nb_validation_samples_test]
x_test = X[-nb_validation_samples_test:]
y_test = Y[-nb_validation_samples_test:]
# use tokenizer to build vocab
tokenizer1 = Tokenizer(num_words=MAX_NUM_WORDS)
tokenizer1.fit_on_texts(df.text)
vocab = tokenizer1.word_index
x_train_word_ids = tokenizer1.texts_to_sequences(x_train)
x_test_word_ids = tokenizer1.texts_to_sequences(x_test)
x_val_word_ids = tokenizer1.texts_to_sequences(x_val)
# pad sequences into the same length
x_train_padded_seqs = pad_sequences(x_train_word_ids, maxlen=MAX_LEN)
x_test_padded_seqs = pad_sequences(x_test_word_ids, maxlen=MAX_LEN)
x_val_padded_seqs = pad_sequences(x_val_word_ids, maxlen=MAX_LEN)
# slice sequences into many subsequences
# each 512-token sequence -> 8 slices of 64 tokens -> 8 slices of 8 tokens each
x_test_padded_seqs_split = []
for i in range(x_test_padded_seqs.shape[0]):
    split1 = np.split(x_test_padded_seqs[i], 8)
    a = []
    for j in range(8):
        s = np.split(split1[j], 8)
        a.append(s)
    x_test_padded_seqs_split.append(a)

x_val_padded_seqs_split = []
for i in range(x_val_padded_seqs.shape[0]):
    split1 = np.split(x_val_padded_seqs[i], 8)
    a = []
    for j in range(8):
        s = np.split(split1[j], 8)
        a.append(s)
    x_val_padded_seqs_split.append(a)
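
x_train_padded_seqs_split = []
for i in range(x_train_padded_seqs.shape[0]):
    split1 = np.split(x_train_padded_seqs[i], 8)
    a = []
    for j in range(8):
        s = np.split(split1[j], 8)
        # the snippet is cut off here; the next two lines complete the loop
        # by mirroring the test and validation loops above
        a.append(s)
    x_train_padded_seqs_split.append(a)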
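
The nested `np.split` loops in the snippet can also be written as a single reshape. The sketch below is not part of the author's code: `slice_sequences` is a hypothetical helper, and the toy `demo` array stands in for the padded Yelp sequences. It assumes the input is a 2-D integer array whose second dimension is divisible by `n ** k`.

```python
import numpy as np


def slice_sequences(padded, n=8, k=2):
    """Reshape (num_samples, seq_len) into (num_samples, n, ..., n, seq_len // n**k)."""
    padded = np.asarray(padded)
    num_samples, seq_len = padded.shape
    if seq_len % (n ** k) != 0:
        raise ValueError("seq_len must be divisible by n ** k")
    # k axes of size n, followed by the minimal subsequence length
    new_shape = (num_samples,) + (n,) * k + (seq_len // n ** k,)
    return padded.reshape(new_shape)


# Toy stand-in for padded sequences: 2 samples of 512 token ids each.
demo = np.arange(2 * 512).reshape(2, 512)
sliced = slice_sequences(demo)
print(sliced.shape)  # (2, 8, 8, 8)

# Same result as the nested np.split loops used in the snippet.
nested = np.array([[np.split(part, 8) for part in np.split(row, 8)] for row in demo])
print(np.array_equal(sliced, nested))  # True
```

Because the reshape works on the whole array at once, it avoids the per-sample Python loops and returns an array directly rather than nested lists.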