import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


pi = torch.tensor(3.14159265358979323846)


def deg2rad(tensor: torch.Tensor) -> torch.Tensor:
    r"""Function that converts angles from degrees to radians.

    Args:
        tensor (torch.Tensor): Tensor of arbitrary shape.

    Returns:
        torch.Tensor: tensor with the same shape as the input.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(tensor)))

    return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.
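# Usage sketch (illustrative, values easy to check by hand): `deg2rad` is a
# plain elementwise op, so any shape broadcasts through unchanged.
#
#   >>> deg2rad(torch.tensor([0., 90., 180.]))
#   tensor([0.0000, 1.5708, 3.1416])

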
def angle_to_rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
    r"""Creates a rotation matrix out of angles in degrees.

    Args:
        angle (torch.Tensor): tensor of angles in degrees, any shape.

    Returns:
        torch.Tensor: tensor of 2x2 rotation matrices.

    Shape:
        - Input: :math:`(*)`
        - Output: :math:`(*, 2, 2)`

    Example:
        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = kornia.angle_to_rotation_matrix(input)  # Nx3x2x2
    """
    ang_rad = deg2rad(angle)
    cos_a: torch.Tensor = torch.cos(ang_rad)
    sin_a: torch.Tensor = torch.sin(ang_rad)
    return torch.stack([cos_a, sin_a, -sin_a, cos_a], dim=-1).view(*angle.shape, 2, 2)
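# Worked sketch: for a 90 degree angle, cos ~ 0 and sin = 1, so
# angle_to_rotation_matrix(torch.tensor([90.])) is approximately
# [[0, 1], [-1, 0]] (up to float rounding), matching the
# [cos, sin; -sin, cos] layout stacked above.

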
def get_rotation_matrix2d(
        center: torch.Tensor,
        angle: torch.Tensor,
        scale: torch.Tensor) -> torch.Tensor:
    r"""Calculates an affine matrix of 2D rotation.

    The function calculates the following matrix:

    .. math::
        \begin{bmatrix}
            \alpha & \beta & (1 - \alpha) \cdot \text{x} - \beta \cdot \text{y} \\
            -\beta & \alpha & \beta \cdot \text{x} + (1 - \alpha) \cdot \text{y}
        \end{bmatrix}

    where

    .. math::
        \alpha = \text{scale} \cdot \cos(\text{angle}) \\
        \beta = \text{scale} \cdot \sin(\text{angle})

    The transformation maps the rotation center to itself. If this is not
    the target, adjust the shift.

    Args:
        center (torch.Tensor): center of the rotation in the source image.
        angle (torch.Tensor): rotation angle in degrees. Positive values mean
            counter-clockwise rotation (the coordinate origin is assumed to
            be the top-left corner).
        scale (torch.Tensor): isotropic scale factor.

    Returns:
        torch.Tensor: the affine matrix of 2D rotation.

    Shape:
        - Input: :math:`(B, 2)`, :math:`(B)` and :math:`(B)`
        - Output: :math:`(B, 2, 3)`

    Example:
        >>> center = torch.zeros(1, 2)
        >>> scale = torch.ones(1)
        >>> angle = 45. * torch.ones(1)
        >>> kornia.get_rotation_matrix2d(center, angle, scale)
        tensor([[[ 0.7071,  0.7071,  0.0000],
                 [-0.7071,  0.7071,  0.0000]]])
    """
    if not torch.is_tensor(center):
        raise TypeError("Input center type is not a torch.Tensor. Got {}"
                        .format(type(center)))
    if not torch.is_tensor(angle):
        raise TypeError("Input angle type is not a torch.Tensor. Got {}"
                        .format(type(angle)))
    if not torch.is_tensor(scale):
        raise TypeError("Input scale type is not a torch.Tensor. Got {}"
                        .format(type(scale)))
    if not (len(center.shape) == 2 and center.shape[1] == 2):
        raise ValueError("Input center must be a Bx2 tensor. Got {}"
                         .format(center.shape))
    if not len(angle.shape) == 1:
        raise ValueError("Input angle must be a B tensor. Got {}"
                         .format(angle.shape))
    if not len(scale.shape) == 1:
        raise ValueError("Input scale must be a B tensor. Got {}"
                         .format(scale.shape))
    if not (center.shape[0] == angle.shape[0] == scale.shape[0]):
        raise ValueError("Inputs must have same batch size dimension. Got {} {} {}"
                         .format(center.shape, angle.shape, scale.shape))

    # convert the angle to a rotation matrix and apply the isotropic scale
    scaled_rotation: torch.Tensor = angle_to_rotation_matrix(angle) * scale.view(-1, 1, 1)
    alpha: torch.Tensor = scaled_rotation[:, 0, 0]
    beta: torch.Tensor = scaled_rotation[:, 0, 1]

    # unpack the center to x, y coordinates
    x: torch.Tensor = center[..., 0]
    y: torch.Tensor = center[..., 1]

    # create output tensor
    batch_size: int = center.shape[0]
    M: torch.Tensor = torch.zeros(
        batch_size, 2, 3, device=center.device, dtype=center.dtype)
    M[..., 0:2, 0:2] = scaled_rotation
    M[..., 0, 2] = (1.0 - alpha) * x - beta * y
    M[..., 1, 2] = beta * x + (1.0 - alpha) * y
    return M
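# Worked sketch (hand-checkable values): for a 90 degree rotation about
# center (x, y) = (1, 2) with unit scale, alpha = cos(90) = 0 and
# beta = sin(90) = 1, so the translation column becomes
# (1 - 0) * 1 - 1 * 2 = -1 and 1 * 1 + (1 - 0) * 2 = 3:
#
#   >>> get_rotation_matrix2d(torch.tensor([[1., 2.]]),
#   ...                       torch.tensor([90.]),
#   ...                       torch.ones(1))
#   # ~ [[ 0., 1., -1.],
#   #    [-1., 0.,  3.]]

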
def convert_points_to_homogeneous(points: torch.Tensor) -> torch.Tensor:
    r"""Function that converts points from Euclidean to homogeneous space.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = kornia.convert_points_to_homogeneous(input)  # BxNx4
    """
    if not isinstance(points, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if len(points.shape) < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    # append a column of ones as the homogeneous coordinate
    return torch.nn.functional.pad(points, [0, 1], "constant", 1.0)
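# Quick sketch: a trailing 1 is appended along the last dimension, e.g.
#
#   >>> convert_points_to_homogeneous(torch.tensor([[1., 2.]]))
#   tensor([[1., 2., 1.]])

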
def convert_points_from_homogeneous(
        points: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Function that converts points from homogeneous to Euclidean space.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = kornia.convert_points_from_homogeneous(input)  # BxNx2
    """
    if not isinstance(points, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))

    if len(points.shape) < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    # we check for points at infinity
    z_vec: torch.Tensor = points[..., -1:]

    # set the results of division by zero/near-zero to 1.0
    # follow the convention of opencv:
    # https://github.com/opencv/opencv/pull/14411/files
    mask: torch.Tensor = torch.abs(z_vec) > eps
    scale: torch.Tensor = torch.ones_like(z_vec).masked_scatter_(
        mask, torch.tensor(1.0).to(points.device) / z_vec[mask])

    return scale * points[..., :-1]
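# Sketch of the two branches above: finite points are divided by their last
# coordinate, while (near-)zero last coordinates keep a scale of 1.0:
#
#   >>> convert_points_from_homogeneous(torch.tensor([[2., 4., 2.]]))
#   tensor([[1., 2.]])
#   >>> convert_points_from_homogeneous(torch.tensor([[1., 2., 0.]]))
#   tensor([[1., 2.]])

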
def transform_points(trans_01: torch.Tensor,
                     points_1: torch.Tensor) -> torch.Tensor:
    r"""Function that applies transformations to a set of points.

    Args:
        trans_01 (torch.Tensor): tensor for transformations of shape
            :math:`(B, D+1, D+1)`.
        points_1 (torch.Tensor): tensor of points of shape :math:`(B, N, D)`.

    Returns:
        torch.Tensor: tensor of N-dimensional points.

    Shape:
        - Output: :math:`(B, N, D)`

    Examples:

        >>> points_1 = torch.rand(2, 4, 3)  # BxNx3
        >>> trans_01 = torch.eye(4).view(1, 4, 4)  # Bx4x4
        >>> points_0 = kornia.transform_points(trans_01, points_1)  # BxNx3
    """
    if not torch.is_tensor(trans_01) or not torch.is_tensor(points_1):
        raise TypeError("Input type is not a torch.Tensor")
    if not trans_01.device == points_1.device:
        raise TypeError("Tensors must be on the same device")
    if not trans_01.shape[0] == points_1.shape[0] and trans_01.shape[0] != 1:
        raise ValueError("Input batch size must be the same for both tensors or 1")
    if not trans_01.shape[-1] == (points_1.shape[-1] + 1):
        raise ValueError("Last input dimensions must differ by one unit")
    # to homogeneous
    points_1_h = convert_points_to_homogeneous(points_1)  # BxNx(D+1)
    # transform coordinates
    points_0_h = torch.matmul(
        trans_01.unsqueeze(1), points_1_h.unsqueeze(-1))
    points_0_h = torch.squeeze(points_0_h, dim=-1)
    # to euclidean
    points_0 = convert_points_from_homogeneous(points_0_h)  # BxNxD
    return points_0
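# Sketch with a pure translation t = (1, 2) in homogeneous form; the point
# (0, 0) should land on (1, 2):
#
#   >>> trans = torch.tensor([[[1., 0., 1.],
#   ...                        [0., 1., 2.],
#   ...                        [0., 0., 1.]]])
#   >>> transform_points(trans, torch.tensor([[[0., 0.]]]))
#   tensor([[[1., 2.]]])

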
def multi_linspace(a, b, num, endpoint=True, device='cpu', dtype=torch.float):
    """This function is just like np.linspace, but will create linearly
    spaced vectors from a start to an end vector.

    Inputs:
        a - Start vector.
        b - End vector.
        num - Number of samples to generate. Must be above 0.
        endpoint - If True, b is the last sample.
            Otherwise, it is not included. Default is True.
    """
    # note: `endpoint` is used as an int here (True == 1), mirroring
    # np.linspace: with an endpoint the step is (b - a) / (num - 1)
    return a[..., None] + (b - a)[..., None] / (num - endpoint) * torch.arange(num, device=device, dtype=dtype)
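# Sketch: with a = (0, 1), b = (1, 2) and num = 3, each row is its own
# linspace:
#
#   >>> multi_linspace(torch.tensor([0., 1.]), torch.tensor([1., 2.]), 3)
#   tensor([[0.0000, 0.5000, 1.0000],
#           [1.0000, 1.5000, 2.0000]])

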
def create_batched_meshgrid(
        x_min: torch.Tensor,
        y_min: torch.Tensor,
        x_max: torch.Tensor,
        y_max: torch.Tensor,
        height: int,
        width: int,
        device: Optional[torch.device] = torch.device('cpu')) -> torch.Tensor:
    """Generates a batched coordinate grid for an image.

    Each batch element gets its own grid spanning the per-sample window
    `[x_min[i], x_max[i]] x [y_min[i], y_max[i]]`, which is consistent with
    the normalized coordinates expected by the PyTorch function grid_sample:
    http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample

    Args:
        x_min (torch.Tensor): left coordinate of each window, shape :math:`(B)`.
        y_min (torch.Tensor): top coordinate of each window, shape :math:`(B)`.
        x_max (torch.Tensor): right coordinate of each window, shape :math:`(B)`.
        y_max (torch.Tensor): bottom coordinate of each window, shape :math:`(B)`.
        height (int): the image height (rows).
        width (int): the image width (cols).
        device (Optional[torch.device]): device on which to create the grid.

    Return:
        torch.Tensor: a grid tensor with shape :math:`(B, H, W, 2)`.
    """
    # generate coordinates
    xs = multi_linspace(x_min, x_max, width, device=device, dtype=torch.float)
    ys = multi_linspace(y_min, y_max, height, device=device, dtype=torch.float)

    # generate grid by stacking coordinates
    bs = x_min.shape[0]
    batched_grid_i_list = list()
    for i in range(bs):
        batched_grid_i_list.append(
            torch.stack(torch.meshgrid([xs[i], ys[i]])).transpose(1, 2))  # 2xHxW
    batched_grid: torch.Tensor = torch.stack(batched_grid_i_list, dim=0)
    return batched_grid.permute(0, 2, 3, 1)  # BxHxWx2
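# Shape sketch: one window per batch element, e.g. two 2x3 grids.
#
#   >>> grid = create_batched_meshgrid(
#   ...     x_min=torch.tensor([-1., 0.]), y_min=torch.tensor([-1., 0.]),
#   ...     x_max=torch.tensor([1., 1.]), y_max=torch.tensor([1., 1.]),
#   ...     height=2, width=3)
#   >>> grid.shape
#   torch.Size([2, 2, 3, 2])

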
def homography_warp(patch_src: torch.Tensor,
                    centers: torch.Tensor,
                    dst_homo_src: torch.Tensor,
                    dsize: Tuple[int, int],
                    mode: str = 'bilinear',
                    padding_mode: str = 'zeros') -> torch.Tensor:
    r"""Function that warps image patches or tensors by homographies.

    See :class:`~kornia.geometry.warp.HomographyWarper` for details.

    Args:
        patch_src (torch.Tensor): The image or tensor to warp. Should be from
            source of shape :math:`(N, C, H, W)`.
        centers (torch.Tensor): The centers, in pixel coordinates, of the
            windows to sample, of shape :math:`(N, 2)`.
        dst_homo_src (torch.Tensor): The homography or stack of homographies
            from source to destination of shape :math:`(N, 3, 3)`.
        dsize (Tuple[int, int]): The height and width of the image to warp.
        mode (str): interpolation mode to calculate output values
            'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
            'zeros' | 'border' | 'reflection'. Default: 'zeros'.

    Return:
        torch.Tensor: Patch sampled at locations from source to destination.

    Example:
        >>> input = torch.rand(1, 3, 32, 32)
        >>> centers = torch.tensor([[16., 16.]])
        >>> homography = torch.eye(3).view(1, 3, 3)
        >>> output = kornia.homography_warp(input, centers, homography, (32, 32))
    """
    out_height, out_width = dsize
    image_height, image_width = patch_src.shape[-2:]
    # express the crop window around each center in normalized [-1, 1]
    # coordinates of the source image
    x_min = 2. * (centers[..., 0] - out_width / 2) / image_width - 1.
    y_min = 2. * (centers[..., 1] - out_height / 2) / image_height - 1.
    x_max = 2. * (centers[..., 0] + out_width / 2) / image_width - 1.
    y_max = 2. * (centers[..., 1] + out_height / 2) / image_height - 1.
    warper = HomographyWarper(x_min, y_min, x_max, y_max, out_height, out_width, mode, padding_mode)
    return warper(patch_src, dst_homo_src)
def normal_transform_pixel(height, width):
    """Computes the 3x3 transform that maps pixel coordinates in
    `[0, width - 1] x [0, height - 1]` to normalized coordinates
    in `[-1, 1] x [-1, 1]`.
    """
    tr_mat = torch.Tensor([[1.0, 0.0, -1.0],
                           [0.0, 1.0, -1.0],
                           [0.0, 0.0, 1.0]])  # 3x3

    tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / (width - 1.0)
    tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / (height - 1.0)

    tr_mat = tr_mat.unsqueeze(0)  # 1x3x3

    return tr_mat
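# Worked sketch: for width = height = 5, pixel 0 maps to -1, pixel 2 maps to
# 0 and pixel 4 maps to 1, since x' = 2 * x / (5 - 1) - 1:
#
#   >>> normal_transform_pixel(5, 5)
#   # ~ [[[0.5, 0.0, -1.0],
#   #     [0.0, 0.5, -1.0],
#   #     [0.0, 0.0,  1.0]]]

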
def src_norm_to_dst_norm(dst_pix_trans_src_pix: torch.Tensor,
                         dsize_src: Tuple[int, int], dsize_dst: Tuple[int, int]) -> torch.Tensor:
    """Converts a pixel-space transform into normalized [-1, 1] space by
    conjugating it with the pixel-to-normalized transforms of the source
    and destination sizes.
    """
    # source and destination sizes
    src_h, src_w = dsize_src
    dst_h, dst_w = dsize_dst
    # the devices and types
    device: torch.device = dst_pix_trans_src_pix.device
    dtype: torch.dtype = dst_pix_trans_src_pix.dtype
    # compute the transformation pixel/norm for src/dst
    src_norm_trans_src_pix: torch.Tensor = normal_transform_pixel(
        src_h, src_w).to(device, dtype)
    src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)
    dst_norm_trans_dst_pix: torch.Tensor = normal_transform_pixel(
        dst_h, dst_w).to(device, dtype)
    # compute chain transformations
    dst_norm_trans_src_norm: torch.Tensor = (
        dst_norm_trans_dst_pix @ (dst_pix_trans_src_pix @ src_pix_trans_src_norm)
    )
    return dst_norm_trans_src_norm
def transform_warp_impl(src: torch.Tensor, centers: torch.Tensor, dst_pix_trans_src_pix: torch.Tensor,
                        dsize_src: Tuple[int, int], dsize_dst: Tuple[int, int],
                        grid_mode: str, padding_mode: str) -> torch.Tensor:
    """Compute the transform in normalized coordinates and perform the warping.
    """
    # note: both normalizations use the source size here, since the sampling
    # grid built from `centers` lives in normalized source-image coordinates
    dst_norm_trans_src_norm: torch.Tensor = src_norm_to_dst_norm(
        dst_pix_trans_src_pix, dsize_src, dsize_src)

    src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm)
    return homography_warp(src, centers, src_norm_trans_dst_norm, dsize_dst, grid_mode, padding_mode)
class HomographyWarper(nn.Module):
    r"""Warps image patches or tensors by homographies.

    .. math::
        X_{dst} = H_{src}^{\{dst\}} * X_{src}

    Args:
        x_min (torch.Tensor): left bound of the per-sample window, shape :math:`(B)`.
        y_min (torch.Tensor): top bound of the per-sample window, shape :math:`(B)`.
        x_max (torch.Tensor): right bound of the per-sample window, shape :math:`(B)`.
        y_max (torch.Tensor): bottom bound of the per-sample window, shape :math:`(B)`.
        height (int): The height of the image to warp.
        width (int): The width of the image to warp.
        mode (str): interpolation mode to calculate output values
            'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
            'zeros' | 'border' | 'reflection'. Default: 'zeros'.
    """

    def __init__(
            self,
            x_min: torch.Tensor,
            y_min: torch.Tensor,
            x_max: torch.Tensor,
            y_max: torch.Tensor,
            height: int,
            width: int,
            mode: str = 'bilinear',
            padding_mode: str = 'zeros') -> None:
        super(HomographyWarper, self).__init__()
        self.width: int = width
        self.height: int = height
        self.mode: str = mode
        self.padding_mode: str = padding_mode

        # create base grid to compute the flow
        self.grid: torch.Tensor = create_batched_meshgrid(
            x_min, y_min, x_max, y_max, height, width)

    def warp_grid(self, dst_homo_src: torch.Tensor) -> torch.Tensor:
        r"""Computes the grid to warp the coordinates grid by an homography.

        Args:
            dst_homo_src (torch.Tensor): Homography or homographies (stacked) to
                transform all points in the grid. Shape of the
                homography has to be :math:`(N, 3, 3)`.

        Returns:
            torch.Tensor: the transformed grid of shape :math:`(N, H, W, 2)`.
        """
        batch_size: int = dst_homo_src.shape[0]
        device: torch.device = dst_homo_src.device
        dtype: torch.dtype = dst_homo_src.dtype
        # expand grid to match the input batch size
        grid: torch.Tensor = self.grid
        if len(dst_homo_src.shape) == 3:  # local homography case
            dst_homo_src = dst_homo_src.view(batch_size, 1, 3, 3)  # Nx1x3x3
        # perform the actual grid transformation,
        # the grid is copied to input device and casted to the same type
        flow: torch.Tensor = transform_points(
            dst_homo_src, grid.to(device).to(dtype))  # NxHxWx2
        return flow.view(batch_size, self.height, self.width, 2)  # NxHxWx2

    def forward(  # type: ignore
            self,
            patch_src: torch.Tensor,
            dst_homo_src: torch.Tensor) -> torch.Tensor:
        r"""Warps an image or tensor from source into reference frame.

        Args:
            patch_src (torch.Tensor): The image or tensor to warp.
                Should be from source.
            dst_homo_src (torch.Tensor): The homography or stack of homographies
                from source to destination. The homography assumes normalized
                coordinates [-1, 1].

        Return:
            torch.Tensor: Patch sampled at locations from source to destination.

        Shape:
            - Input: :math:`(N, C, H, W)` and :math:`(N, 3, 3)`
            - Output: :math:`(N, C, H, W)`

        Example:
            >>> input = torch.rand(1, 3, 32, 32)
            >>> homography = torch.eye(3).view(1, 3, 3)
            >>> x_min, y_min = -torch.ones(1), -torch.ones(1)
            >>> x_max, y_max = torch.ones(1), torch.ones(1)
            >>> warper = HomographyWarper(x_min, y_min, x_max, y_max, 32, 32)
            >>> output = warper(input, homography)  # NxCxHxW
        """
        if not dst_homo_src.device == patch_src.device:
            raise TypeError("Patch and homography must be on the same device. "
                            "Got patch.device: {} dst_H_src.device: {}."
                            .format(patch_src.device, dst_homo_src.device))

        return F.grid_sample(patch_src, self.warp_grid(dst_homo_src),  # type: ignore
                             mode=self.mode, padding_mode=self.padding_mode, align_corners=True)


def warp_affine_crop(src: torch.Tensor, centers: torch.Tensor, M: torch.Tensor,
                     dsize: Tuple[int, int], flags: str = 'bilinear',
                     padding_mode: str = 'zeros') -> torch.Tensor:
    r"""Applies an affine transformation to a tensor.

    The function warp_affine transforms the source tensor using
    the specified matrix:

    .. math::
        \text{dst}(x, y) = \text{src} \left( M_{11} x + M_{12} y + M_{13} ,
        M_{21} x + M_{22} y + M_{23} \right )

    Args:
        src (torch.Tensor): input tensor of shape :math:`(B, C, H, W)`.
        centers (torch.Tensor): centers, in pixel coordinates, of the output
            crop windows, of shape :math:`(B, 2)`.
        M (torch.Tensor): affine transformation of shape :math:`(B, 2, 3)`.
        dsize (Tuple[int, int]): size of the output image (height, width).
        flags (str): interpolation mode to calculate output values
            'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
            'zeros' | 'border' | 'reflection'. Default: 'zeros'.

    Returns:
        torch.Tensor: the warped tensor.

    Shape:
        - Output: :math:`(B, C, H, W)`

    .. note::
        See a working example `here <https://kornia.readthedocs.io/en/latest/
        tutorials/warp_affine.html>`__.
    """
    if not torch.is_tensor(src):
        raise TypeError("Input src type is not a torch.Tensor. Got {}"
                        .format(type(src)))

    if not torch.is_tensor(M):
        raise TypeError("Input M type is not a torch.Tensor. Got {}"
                        .format(type(M)))

    if not len(src.shape) == 4:
        raise ValueError("Input src must be a BxCxHxW tensor. Got {}"
                         .format(src.shape))

    if not (len(M.shape) == 3 and M.shape[-2:] == (2, 3)):
        raise ValueError("Input M must be a Bx2x3 tensor. Got {}"
                         .format(M.shape))

    # we generate a 3x3 transformation matrix from 2x3 affine
    M_3x3: torch.Tensor = F.pad(M, [0, 0, 0, 1, 0, 0],
                                mode="constant", value=0)
    M_3x3[:, 2, 2] += 1.0

    # launches the warper
    h, w = src.shape[-2:]
    return transform_warp_impl(src, centers, M_3x3, (h, w), dsize, flags, padding_mode)
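# End-to-end sketch (hypothetical values): rotate a batch of images by 45
# degrees about per-sample centers and crop 16x16 windows around those
# centers, using only functions defined in this module:
#
#   >>> src = torch.rand(2, 3, 32, 32)
#   >>> centers = torch.tensor([[16., 16.], [10., 20.]])
#   >>> angle = 45. * torch.ones(2)
#   >>> M = get_rotation_matrix2d(centers, angle, torch.ones(2))
#   >>> out = warp_affine_crop(src, centers, M, dsize=(16, 16))
#   >>> out.shape
#   torch.Size([2, 3, 16, 16])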