-
大小: 6KB　文件類型: .zip　金幣: 2　下載: 1 次　發布日期: 2021-06-17
- 語言: Python
- 標簽:
資源簡介
在 PyTorch 中實現的、帶注意力機制的神經機器翻譯最小 Seq2Seq 模型

代碼片段和文件信息
import math
import random

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable


class Encoder(nn.Module):
    """Bidirectional GRU encoder over embedded source tokens.

    The forward and backward GRU directions are summed (not concatenated),
    so the per-step output width is ``hidden_size`` rather than
    ``2 * hidden_size``.

    Args:
        input_size: Number of rows in the embedding table (source vocabulary
            size).
        embed_size: Width of each token embedding.
        hidden_size: Width of the GRU hidden state per direction.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability handed to ``nn.GRU``.
    """

    def __init__(self, input_size, embed_size, hidden_size,
                 n_layers=1, dropout=0.5):
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        # The layer class is ``nn.Embedding`` (capital E); ``nn.embedding``
        # does not exist and raises AttributeError at construction time.
        self.embed = nn.Embedding(input_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, n_layers,
                          dropout=dropout, bidirectional=True)

    def forward(self, src, hidden=None):
        """Encode a batch of source token indices.

        Args:
            src: Integer tensor of token indices. The GRU is built without
                ``batch_first``, so this is presumably laid out as
                ``[T, B]`` (time-major) -- confirm against the caller.
            hidden: Optional initial GRU hidden state; ``None`` lets the GRU
                start from zeros.

        Returns:
            A tuple ``(outputs, hidden)`` where ``outputs`` has its last
            dimension reduced to ``hidden_size`` by summing both directions,
            and ``hidden`` is the final GRU state as returned by ``nn.GRU``.
        """
        embedded = self.embed(src)
        outputs, hidden = self.gru(embedded, hidden)
        # Sum the bidirectional outputs: the first ``hidden_size`` features
        # belong to the forward direction, the remainder to the backward one.
        outputs = (outputs[:, :, :self.hidden_size] +
                   outputs[:, :, self.hidden_size:])
        return outputs, hidden


class Attention(nn.Module):
    """Additive attention producing one weight per encoder time step.

    A linear layer projects the concatenation of the decoder state and each
    encoder output to ``hidden_size`` features, a ReLU is applied, and the
    result is reduced to a scalar score by a dot product with the learned
    vector ``v``. Scores are normalised with a softmax over the time axis.

    Args:
        hidden_size: Width of both the decoder hidden state and the encoder
            outputs.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))
        # Re-initialise ``v`` uniformly in [-1/sqrt(H), 1/sqrt(H)].
        stdv = 1. / math.sqrt(self.v.size(0))
        self.v.data.uniform_(-stdv, stdv)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights over the encoder time steps.

        Args:
            hidden: Decoder hidden state, ``[B, H]``.
            encoder_outputs: Encoder outputs, ``[T, B, H]``.

        Returns:
            Attention weights of shape ``[B, 1, T]``; each row is
            non-negative and sums to 1 over ``T``.
        """
        timestep = encoder_outputs.size(0)
        # Tile the decoder state along time so it lines up with every
        # encoder step: [B, H] -> [T, B, H] -> [B, T, H].
        h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
        encoder_outputs = encoder_outputs.transpose(0, 1)  # [B*T*H]
        attn_energies = self.score(h, encoder_outputs)
        # Normalise over the time axis. (``F.relu`` takes no ``dim``
        # argument, so it cannot be used here.)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

    def score(self, hidden, encoder_outputs):
        """Return unnormalised attention scores of shape ``[B, T]``.

        Args:
            hidden: Decoder state tiled over time, ``[B, T, H]``.
            encoder_outputs: Encoder outputs, ``[B, T, H]``.
        """
        # [B*T*2H]->[B*T*H]
        # The non-linearity on the projected features is a ReLU; the softmax
        # belongs on the final scores in ``forward``, not here.
        energy = F.relu(self.attn(torch.cat([hidden, encoder_outputs], 2)))
        energy = energy.transpose(1, 2)  # [B*H*T]
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # [B*1*H]
        energy = torch.bmm(v, energy)  # [B*1*T]
        return energy.squeeze(1)  # [B*T]


class?Decoder(nn.Module):
????def?__init__(self?embed_size?hidden_size?output_size
?????????????????n_layers=1?dropout=0.2):
????????super(Decoder?self).__init__()
????????self.embed_size?=?embed_size
????????self.hidden_size?=?hidden_size
????????self.output_size?=?output_size
????????self.n_layers?=?n_layers
????????self.embed?=?nn.embedding(output_size?embed_size)
????????self.dropout?=?nn.Dropout(dropout?inplace=True)
????????self.attention?=?Attention(hidden_size)
????????self.gru?=?nn.GRU(hidden_size?+?embed_size?hidden_size
??????????????????????????n_layers?dropout=dropout)
????????self.out?=?nn.Linear(hidden_size?*?2?output_size)
????def?forward(self?input?last_hidden?encoder_outputs):
????????#?Get?the?embedding?of?the?current?input?word?(last?output?word)
????????embedded?=?self.embed(input).unsqueeze(0)??#?(1BN)
????????embedded?=?self.dropout(embedded)
????????#?Calculate?attention?weights?and?apply?to?encoder?outputs
????????attn_weights?=?self.attention(last_hidden[-1]?encoder_outputs)
???
 屬性            大小     日期    時間   名稱
----------- ---------  ---------- -----  ----
     目錄           0  2018-10-18 14:14  seq2seq-master\
     文件        1204  2018-10-18 14:14  seq2seq-master\.gitignore
     文件        1061  2018-10-18 14:14  seq2seq-master\LICENSE
     文件        1247  2018-10-18 14:14  seq2seq-master\README.md
     文件        4509  2018-10-18 14:14  seq2seq-master\model.py
     文件        4155  2018-10-18 14:14  seq2seq-master\train.py
     文件        1037  2018-10-18 14:14  seq2seq-master\utils.py
評論
共有 條評論