Trajectron-plus-plus/trajectron/model/components/additive_attention.py
2020-04-05 21:43:49 -04:00

67 lines
3.3 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class AdditiveAttention(nn.Module):
    """Additive attention (Bahdanau et al., 2015) over a set of encoder states.

    The unnormalised alignment score between an encoder state ``h_j`` and the
    decoder state ``s_(i-1)`` is::

        score(h_j, s_(i-1)) = v . tanh(W_1 h_j + W_2 s_(i-1))

    ``W_1``, ``W_2`` and ``v`` are bias-free linear layers.
    """

    def __init__(self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None):
        """Build the three projection layers.

        Args:
            encoder_hidden_state_dim: size of the last dimension of each encoder state.
            decoder_hidden_state_dim: size of the last dimension of the decoder state.
            internal_dim: size of the shared space that W_1 and W_2 project into.
                When None, defaults to the integer mean of the two input sizes.
        """
        super(AdditiveAttention, self).__init__()
        if internal_dim is None:
            internal_dim = int((encoder_hidden_state_dim + decoder_hidden_state_dim) / 2)
        self.w1 = nn.Linear(encoder_hidden_state_dim, internal_dim, bias=False)
        self.w2 = nn.Linear(decoder_hidden_state_dim, internal_dim, bias=False)
        self.v = nn.Linear(internal_dim, 1, bias=False)

    def score(self, encoder_state, decoder_state):
        """Return unnormalised scores ``v . tanh(W_1 h + W_2 s)``.

        The linear layers act on the last dimension only, so any leading
        dimensions that broadcast against each other are accepted:

        - per-state form: encoder_state (batch, enc_dim) and decoder_state
          (batch, dec_dim) give a result of shape (batch, 1);
        - batched form used by ``forward``: encoder_state
          (batch, num_enc_states, enc_dim) and decoder_state (batch, 1, dec_dim)
          give a result of shape (batch, num_enc_states, 1).
        """
        return self.v(torch.tanh(self.w1(encoder_state) + self.w2(decoder_state)))

    def forward(self, encoder_states, decoder_state):
        """Attend over ``encoder_states`` using ``decoder_state`` as the query.

        Args:
            encoder_states: tensor of shape (batch, num_enc_states, enc_dim).
            decoder_state: tensor of shape (batch, dec_dim).

        Returns:
            A tuple ``(final_context_vec, attention_probs)``:
            final_context_vec has shape (batch, enc_dim) and is the
            attention-weighted sum of the encoder states; attention_probs has
            shape (batch, num_enc_states, 1) and holds the softmax weights.
        """
        # Score all encoder states in one call. The singleton axis on the
        # decoder state makes W_2 s broadcast across every encoder state, so
        # the decoder projection is computed once rather than once per state
        # (the previous per-index loop recomputed it num_enc_states times).
        score_vec = self.score(encoder_states, decoder_state.unsqueeze(1)).squeeze(-1)
        # score_vec is of shape (batch, num_enc_states)
        attention_probs = F.softmax(score_vec, dim=1).unsqueeze(2)
        # attention_probs is of shape (batch, num_enc_states, 1)
        final_context_vec = torch.sum(attention_probs * encoder_states, dim=1)
        # final_context_vec is of shape (batch, enc_dim)
        return final_context_vec, attention_probs
class TemporallyBatchedAdditiveAttention(AdditiveAttention):
    """Additive attention (Bahdanau et al., 2015) evaluated for all time steps at once.

    Uses the same score, ``v . tanh(W_1 h_j + W_2 s_(i-1))``, and the same
    projection layers as ``AdditiveAttention``. Here, however, both the encoder
    states and the decoder state carry a time axis, so every time step is
    scored in a single pass.
    """

    def __init__(self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None):
        # No extra parameters: w1, w2 and v are created by the parent class.
        super().__init__(encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim)

    def score(self, encoder_state, decoder_state):
        """Unnormalised scores for every (encoder state, time step) pair.

        encoder_state: (batch, num_enc_states, max_time, enc_dim)
        decoder_state: (batch, max_time, dec_dim)
        returns:       (batch, num_enc_states, max_time, 1)
        """
        keys = self.w1(encoder_state)
        # Singleton axis at dim 1 lets the decoder projection broadcast over
        # all encoder states: (batch, 1, max_time, internal_dim).
        query = self.w2(decoder_state).unsqueeze(1)
        return self.v(torch.tanh(keys + query))

    def forward(self, encoder_states, decoder_state):
        """Compute per-time-step context vectors and attention weights.

        Args:
            encoder_states: (batch, num_enc_states, max_time, enc_dim)
            decoder_state: (batch, max_time, dec_dim)

        Returns:
            context: (batch, max_time, enc_dim) weighted sum over encoder states.
            weights: (batch, max_time, num_enc_states) softmax attention weights.
        """
        # Softmax over dim 1 normalises across encoder states independently
        # for each batch element and time step; shape (batch, num_enc_states, max_time, 1).
        weights = F.softmax(self.score(encoder_states, decoder_state), dim=1)
        context = (weights * encoder_states).sum(dim=1)
        return context, weights.transpose(1, 2).squeeze(3)