Trajectron-plus-plus/trajectron/model/components/graph_attention.py

import math

import torch
import torch.nn as nn
from torch.nn import init, Parameter


class GraphMultiTypeAttention(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, bias=True, types=1):
        super(GraphMultiTypeAttention, self).__init__()
        self.types = types
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        # Self-loop weight projects the node's own features (size in_features[0]) to the hidden size.
        self.node_self_loop_weight = Parameter(torch.Tensor(hidden_features, in_features[0]))
        # One projection matrix per edge/neighbor type; in_features[i] is that type's feature size.
        self.weight_per_type = nn.ParameterList()
        for i in range(types):
            self.weight_per_type.append(Parameter(torch.Tensor(hidden_features, in_features[i])))
        if bias:
            self.bias = Parameter(torch.Tensor(hidden_features))
        else:
            self.register_parameter('bias', None)
        self.linear_to_out = nn.Linear(hidden_features, out_features, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform fan-in initialization, matching the bound used by nn.Linear's default init.
        for weight in self.weight_per_type:
            bound = 1 / math.sqrt(weight.size(1))
            init.uniform_(weight, -bound, bound)
        bound = 1 / math.sqrt(self.node_self_loop_weight.size(1))
        init.uniform_(self.node_self_loop_weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, inputs, types, edge_weights):
        # Stack the edge-weighted, averaged per-type projections for each neighbor input,
        # followed by the self-loop projection for the node's own features.
        weight_list = list()
        for i, node_type in enumerate(types):
            weight_list.append((edge_weights[i] / len(edge_weights)) * self.weight_per_type[node_type].T)
        weight_list.append(self.node_self_loop_weight.T)
        weight = torch.cat(weight_list, dim=0)
        # Concatenate the matching inputs along the feature dimension and project with one matmul.
        stacked_input = torch.cat(inputs, dim=-1)
        output = stacked_input.matmul(weight)
        if self.bias is not None:
            output += self.bias
        return torch.relu(self.linear_to_out(torch.relu(output)))

    def extra_repr(self):
        return 'in_features={}, hidden_features={}, out_features={}, types={}, bias={}'.format(
            self.in_features, self.hidden_features, self.out_features, self.types, self.bias is not None
        )
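

# Hedged usage sketch (not from the original file): a minimal smoke test showing one way
# this layer might be driven, assuming two neighbor types with feature sizes 32 and 16
# and the node's own state matching in_features[0] = 32. All shapes, names, and the unit
# edge weights below are illustrative assumptions, not the Trajectron++ training setup.
if __name__ == "__main__":
    layer = GraphMultiTypeAttention(in_features=[32, 16], hidden_features=64,
                                    out_features=48, bias=True, types=2)

    batch_size = 4
    neighbor_a = torch.randn(batch_size, 32)  # aggregated features from type-0 neighbors
    neighbor_b = torch.randn(batch_size, 16)  # aggregated features from type-1 neighbors
    own_state = torch.randn(batch_size, 32)   # node's own features (pairs with node_self_loop_weight)

    # forward() concatenates `inputs` along the last dimension, so their order must match
    # the per-type weights selected by `types`, followed by the node's own state.
    out = layer(inputs=[neighbor_a, neighbor_b, own_state],
                types=[0, 1],
                edge_weights=[1.0, 1.0])
    print(out.shape)  # expected: torch.Size([4, 48])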