# Trajectron-plus-plus/trajectron/utils/matrix_utils.py
# (41 lines, 1.2 KiB, Python — header captured from the repository web viewer)

import numpy as np
import torch
def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """
    Return ``v`` reshaped with extra size-1 dimensions at both ends.

    :param v: tensor to reshape.
    :param n_dim_to_prepend: number of singleton dimensions to put in front
        of ``v.shape``.
    :param n_dim_to_append: number of singleton dimensions to put after
        ``v.shape``.
    :return: ``v`` viewed/reshaped to
        ``(1,) * n_dim_to_prepend + v.shape + (1,) * n_dim_to_append``.
    """
    target_shape = (1,) * n_dim_to_prepend + tuple(v.shape) + (1,) * n_dim_to_append
    return v.reshape(target_shape)


def block_diag(m):
    """
    Make a block diagonal matrix along dim=-3.

    The matrices stacked along dimension -3 of ``m`` are placed on the
    diagonal of a larger matrix; every off-diagonal block is zero.

    EXAMPLE:
        block_diag(torch.ones(4, 3, 2))
    gives a 12 x 8 matrix with four 3 x 2 blocks of ones on its diagonal.

    Leading batch dimensions are preserved: an input of shape
    ``(*batch, n, r, c)`` yields an output of shape ``(*batch, n * r, n * c)``.
    You can also give a list (or tuple) of ``n`` matrices that all have shape
    ``(*batch, r, c)``; they are stacked along dim=-3 first.

    :param m: tensor of shape ``(*batch, n, r, c)`` or a list/tuple of tensors
        of shape ``(*batch, r, c)``.
    :type m: torch.Tensor, list
    :return: block diagonal tensor with the same dtype and device as ``m``.
    :rtype: torch.Tensor
    """
    if isinstance(m, (list, tuple)):
        # Same as concatenating each m1.unsqueeze(-3) along dim=-3.
        m = torch.stack(list(m), dim=-3)
    n_batch_dims = m.dim() - 3
    n = m.shape[-3]
    batch_shape = m.shape[:-3]
    rows, cols = m.shape[-2:]

    # (*batch, n, r, 1, c): add a slot for the block-column index j.
    m2 = m.unsqueeze(-2)
    # Identity mask shaped (1, ..., 1, n, 1, n, 1): the product keeps block
    # (i, j) only when i == j. It is built in m's dtype so the result is not
    # promoted to the default float dtype (e.g. for integer or half inputs).
    eye = torch.eye(n, dtype=m.dtype, device=m.device).reshape(
        (1,) * n_batch_dims + (n, 1, n, 1))

    # (*batch, n, r, n, c) -> (*batch, n * r, n * c)
    return (m2 * eye).reshape(*batch_shape, n * rows, n * cols)


def tile(a, dim, n_tile, device='cpu'):
    """
    Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, consecutively.

    For slices ``[x0, x1]`` along ``dim`` and ``n_tile = 2`` the result is
    ``[x0, x0, x1, x1]`` (interleaved repetition, unlike ``Tensor.repeat``,
    which would give ``[x0, x1, x0, x1]``).

    :param a: input tensor.
    :param dim: dimension to expand (negative values are allowed).
    :param n_tile: number of consecutive copies of each slice.
    :param device: ignored; kept for backward compatibility. The result is
        always on ``a.device``.
    :return: tensor whose size along ``dim`` is ``a.size(dim) * n_tile``.
    """
    # repeat_interleave produces exactly the repeat + index_select result.
    # It needs no separately-placed index tensor, so it works for inputs on
    # any device, including CUDA. It also handles an empty ``dim`` without
    # materialising an intermediate copy.
    return torch.repeat_interleave(a, n_tile, dim=dim)