import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttention(nn.Module):
    # Implementing the attention module of Bahdanau et al. 2015 where
    # score(h_j, s_(i-1)) = v . tanh(W_1 h_j + W_2 s_(i-1))
    def __init__(self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None):
        super(AdditiveAttention, self).__init__()

        if internal_dim is None:
            internal_dim = int((encoder_hidden_state_dim + decoder_hidden_state_dim) / 2)

        self.w1 = nn.Linear(encoder_hidden_state_dim, internal_dim, bias=False)
        self.w2 = nn.Linear(decoder_hidden_state_dim, internal_dim, bias=False)
        self.v = nn.Linear(internal_dim, 1, bias=False)

    def score(self, encoder_state, decoder_state):
        # encoder_state is of shape (batch, enc_dim)
        # decoder_state is of shape (batch, dec_dim)
        # return value should be of shape (batch, 1)
        return self.v(torch.tanh(self.w1(encoder_state) + self.w2(decoder_state)))

    def forward(self, encoder_states, decoder_state):
        # encoder_states is of shape (batch, num_enc_states, enc_dim)
        # decoder_state is of shape (batch, dec_dim)
        score_vec = torch.cat([self.score(encoder_states[:, i], decoder_state)
                               for i in range(encoder_states.shape[1])],
                              dim=1)
        # score_vec is of shape (batch, num_enc_states)

        attention_probs = torch.unsqueeze(F.softmax(score_vec, dim=1), dim=2)
        # attention_probs is of shape (batch, num_enc_states, 1)

        final_context_vec = torch.sum(attention_probs * encoder_states, dim=1)
        # final_context_vec is of shape (batch, enc_dim)

        return final_context_vec, attention_probs

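# Usage sketch (illustrative, not part of the original file): exercises
# AdditiveAttention and checks the shapes promised in the comments above.
# The dimensions enc_dim=16, dec_dim=12, batch=4, num_enc_states=10 are
# arbitrary choices for the demo.
def _demo_additive_attention():
    attention = AdditiveAttention(encoder_hidden_state_dim=16,
                                  decoder_hidden_state_dim=12)
    encoder_states = torch.randn(4, 10, 16)  # (batch, num_enc_states, enc_dim)
    decoder_state = torch.randn(4, 12)       # (batch, dec_dim)
    context, probs = attention(encoder_states, decoder_state)
    assert context.shape == (4, 16)          # (batch, enc_dim)
    assert probs.shape == (4, 10, 1)         # (batch, num_enc_states, 1)
    return context, probs
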
class TemporallyBatchedAdditiveAttention(AdditiveAttention):
    # Implementing the attention module of Bahdanau et al. 2015 where
    # score(h_j, s_(i-1)) = v . tanh(W_1 h_j + W_2 s_(i-1))
    # Unlike AdditiveAttention, this variant scores every decoding time step
    # in one batched pass instead of looping over encoder states.
    def __init__(self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None):
        super(TemporallyBatchedAdditiveAttention, self).__init__(encoder_hidden_state_dim,
                                                                 decoder_hidden_state_dim,
                                                                 internal_dim)

    def score(self, encoder_state, decoder_state):
        # encoder_state is of shape (batch, num_enc_states, max_time, enc_dim)
        # decoder_state is of shape (batch, max_time, dec_dim)
        # return value should be of shape (batch, num_enc_states, max_time, 1)
        return self.v(torch.tanh(self.w1(encoder_state) + torch.unsqueeze(self.w2(decoder_state), dim=1)))

    def forward(self, encoder_states, decoder_state):
        # encoder_states is of shape (batch, num_enc_states, max_time, enc_dim)
        # decoder_state is of shape (batch, max_time, dec_dim)
        score_vec = self.score(encoder_states, decoder_state)
        # score_vec is of shape (batch, num_enc_states, max_time, 1)

        attention_probs = F.softmax(score_vec, dim=1)
        # attention_probs is of shape (batch, num_enc_states, max_time, 1)

        final_context_vec = torch.sum(attention_probs * encoder_states, dim=1)
        # final_context_vec is of shape (batch, max_time, enc_dim)

        return final_context_vec, torch.squeeze(torch.transpose(attention_probs, 1, 2), dim=3)

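# Usage sketch (illustrative, not part of the original file): exercises
# TemporallyBatchedAdditiveAttention, which attends over all max_time decoder
# states at once. Dimensions are arbitrary demo choices: batch=4,
# num_enc_states=10, max_time=7, enc_dim=16, dec_dim=12.
def _demo_temporally_batched_attention():
    attention = TemporallyBatchedAdditiveAttention(encoder_hidden_state_dim=16,
                                                   decoder_hidden_state_dim=12)
    encoder_states = torch.randn(4, 10, 7, 16)  # (batch, num_enc_states, max_time, enc_dim)
    decoder_states = torch.randn(4, 7, 12)      # (batch, max_time, dec_dim)
    context, probs = attention(encoder_states, decoder_states)
    assert context.shape == (4, 7, 16)          # (batch, max_time, enc_dim)
    assert probs.shape == (4, 7, 10)            # (batch, max_time, num_enc_states)
    return context, probs


if __name__ == '__main__':
    _demo_additive_attention()
    _demo_temporally_batched_attention()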