import numpy as np
import torch


def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """Return ``v`` reshaped with extra singleton dimensions on either side.

    :param v: tensor to reshape.
    :param n_dim_to_prepend: number of size-1 dims to add in front
        (values <= 0 add none).
    :param n_dim_to_append: number of size-1 dims to add at the end
        (values <= 0 add none).
    :rtype: torch.Tensor
    """
    return v.reshape(
        torch.Size([1] * n_dim_to_prepend)
        + v.shape
        + torch.Size([1] * n_dim_to_append))


def block_diag(m):
    """Make a block diagonal matrix from the blocks stacked along dim=-3.

    EXAMPLE: block_diag(torch.ones(4, 3, 2)) gives a 12 x 8 matrix whose
    diagonal holds four 3 x 2 blocks of ones (zeros elsewhere).
    Leading batch dimensions before dim=-3 are preserved.
    You can also give a list (or tuple) of equally-shaped matrices.

    :param m: tensor of shape (*batch, n, r, c), or a sequence of n
        tensors of shape (*batch, r, c).
    :type m: torch.Tensor, list, tuple
    :return: tensor of shape (*batch, n * r, n * c) with the same dtype
        and device as the input blocks.
    :rtype: torch.Tensor
    """
    if isinstance(m, (list, tuple)):
        m = torch.stack(list(m), dim=-3)

    n_blocks = m.shape[-3]
    batch_shape = m.shape[:-3]
    rows, cols = m.shape[-2:]

    # Identity mask of shape (1, ..., 1, n, 1, n, 1). Created in m's dtype
    # so the multiply below does not promote integer/bool/half inputs.
    eye = torch.eye(n_blocks, dtype=m.dtype, device=m.device)
    eye = attach_dim(eye.unsqueeze(-2), m.dim() - 3, 1)

    # (..., n, r, 1, c) * (..., n, 1, n, 1) -> (..., n, r, n, c); entry
    # [i, a, j, b] is m[i, a, b] when i == j and 0 otherwise, so flattening
    # (n, r) and (n, c) yields the block diagonal layout.
    blocks = m.unsqueeze(-2) * eye
    return blocks.reshape(
        batch_shape + torch.Size([n_blocks * rows, n_blocks * cols]))


def tile(a, dim, n_tile, device='cpu'):
    """Repeat each slice of ``a`` along ``dim`` ``n_tile`` times in place.

    Consecutive repetition, e.g. along dim 0: [x, y] -> [x, x, y, y]
    (numpy.repeat semantics, not numpy.tile).

    :param a: input tensor.
    :param dim: dimension to expand (negative values allowed).
    :param n_tile: number of repetitions of each slice.
    :param device: retained for backward compatibility only; the result is
        always produced on ``a.device`` (passing a device different from
        ``a``'s previously made ``index_select`` fail).
    :rtype: torch.Tensor
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)