import numpy as np
import torch


def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """
    Reshape v with n_dim_to_prepend singleton dimensions prepended to its
    shape and n_dim_to_append singleton dimensions appended.
    """
    return v.reshape(
        torch.Size([1] * n_dim_to_prepend)
        + v.shape
        + torch.Size([1] * n_dim_to_append))
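
# A quick illustration of attach_dim (added here for clarity, not part of
# the original snippet): one singleton dimension in front, two behind.
# >>> attach_dim(torch.ones(3, 2), 1, 2).shape
# torch.Size([1, 3, 2, 1, 1])
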
def block_diag(m):
    """
    Make a block-diagonal matrix from the matrices stacked along dim=-3.

    EXAMPLE:
    block_diag(torch.ones(4, 3, 2))
    should give a 12 x 8 matrix with four 3 x 2 blocks of ones on the
    diagonal. Dimensions before dim=-3 are treated as batch dimensions.
    You can also give a list of matrices.

    :type m: torch.Tensor, list
    :rtype: torch.Tensor
    """
    if isinstance(m, list):
        # Stack the listed matrices along dim=-3 first.
        m = torch.cat([m1.unsqueeze(-3) for m1 in m], -3)

    d = m.dim()
    n = m.shape[-3]       # number of blocks
    siz0 = m.shape[:-3]   # batch dimensions
    siz1 = m.shape[-2:]   # shape of each block
    m2 = m.unsqueeze(-2)
    # Multiplying by a broadcast identity zeroes every off-diagonal block;
    # the final reshape then lays the n blocks out along the diagonal.
    eye = attach_dim(torch.eye(n, device=m.device).unsqueeze(-2), d - 3, 1)
    return (m2 * eye).reshape(siz0 + torch.Size(torch.tensor(siz1) * n))
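
# Usage sketch for block_diag, checking the example from the docstring
# above (added for illustration):
# >>> block_diag(torch.ones(4, 3, 2)).shape
# torch.Size([12, 8])
# >>> block_diag(torch.ones(5, 4, 3, 2)).shape  # leading batch dimension
# torch.Size([5, 12, 8])
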
def tile(a, dim, n_tile, device='cpu'):
    """
    Repeat each slice of a along dim n_tile times, keeping the copies of a
    slice adjacent (the behavior of np.repeat, not torch.Tensor.repeat).
    """
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    # Reorder so the n_tile copies of each original slice end up adjacent.
    order_index = torch.LongTensor(np.concatenate(
        [init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index)
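

# Minimal smoke test (added for illustration; not part of the original
# snippet). Prints shapes/values that match the docstrings above.
if __name__ == '__main__':
    x = torch.arange(6).reshape(2, 3)
    print(attach_dim(x, 1, 1).shape)              # torch.Size([1, 2, 3, 1])
    print(block_diag(torch.ones(4, 3, 2)).shape)  # torch.Size([12, 8])
    print(tile(x, dim=1, n_tile=2))               # each column repeated twice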