2021-10-13 12:00:23 +02:00
|
|
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2021-10-07 11:55:26 +02:00
|
|
|
#
|
|
|
|
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
|
|
# and proprietary rights in and to this software, related documentation
|
|
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
|
|
# distribution of this software and related documentation without an express
|
|
|
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
|
|
|
|
|
|
"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
|
|
|
|
"Alias-Free Generative Adversarial Networks"."""
|
|
|
|
|
|
|
|
import copy
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.fft
|
|
|
|
from torch_utils.ops import upfirdn2d
|
|
|
|
from . import metric_utils
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Utilities.
|
|
|
|
|
|
|
|
def sinc(x):
    """Normalized sinc, sin(pi * x) / (pi * x), evaluated elementwise.

    The removable singularity at x == 0 (more precisely, wherever
    |pi * x| < 1e-30) is filled with 1.
    """
    arg = (x * np.pi).abs()
    # Clamp the divisor away from zero; those entries are replaced below anyway.
    ratio = arg.sin() / arg.clamp(min=1e-30)
    near_zero = arg < 1e-30
    return torch.where(near_zero, torch.ones_like(x), ratio)
|
|
|
|
|
|
|
|
def lanczos_window(x, a):
    """Lanczos window of radius ``a``, evaluated elementwise.

    Equals ``sinc(x / a)`` where ``|x| < a`` and 0 elsewhere.
    """
    scaled = x.abs() / a
    inside = scaled < 1
    return torch.where(inside, sinc(scaled), torch.zeros_like(scaled))
|
|
|
|
|
|
|
|
def rotation_matrix(angle):
    """Return a 3x3 matrix rotating by ``angle`` radians in its upper-left 2x2 block.

    The 2x2 block is ``[[cos, sin], [-sin, cos]]``; the remaining entries are
    those of the identity. The matrix lives on the same device as ``angle``
    (after conversion to a float32 tensor).
    """
    theta = torch.as_tensor(angle).to(torch.float32)
    cos_t = theta.cos()
    sin_t = theta.sin()
    out = torch.eye(3, device=theta.device)
    out[0, 0] = cos_t
    out[0, 1] = sin_t
    out[1, 0] = -sin_t
    out[1, 1] = cos_t
    return out
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Apply integer translation to a batch of 2D images. Corresponds to the
|
|
|
|
# operator T_x in Appendix E.1.
|
|
|
|
|
|
|
|
def apply_integer_translation(x, tx, ty):
    """Shift a batch of images by a whole number of pixels.

    Args:
        x: Image batch, unpacked as ``[N, C, H, W]``.
        tx: Horizontal offset as a fraction of the width; ``tx * W`` is rounded
            to the nearest integer. Positive values move content toward higher
            column indices.
        ty: Vertical offset as a fraction of the height; ``ty * H`` is rounded
            likewise.

    Returns:
        ``(z, m)``, both shaped like ``x``. ``z`` is the shifted image with
        vacated pixels set to zero. ``m`` is 1 where ``z`` holds data copied
        from ``x`` and 0 elsewhere. If the shift is at least the full
        width/height, both are all zeros.
    """
    _N, _C, H, W = x.shape
    dx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device).round().to(torch.int64)
    dy = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device).round().to(torch.int64)

    out = torch.zeros_like(x)
    mask = torch.zeros_like(x)
    if abs(dx) >= W or abs(dy) >= H:
        # Everything was shifted out of frame.
        return out, mask

    everything = slice(None)
    src = (everything, everything,
           slice(max(-dy, 0), H + min(-dy, 0)),
           slice(max(-dx, 0), W + min(-dx, 0)))
    dst = (everything, everything,
           slice(max(dy, 0), H + min(dy, 0)),
           slice(max(dx, 0), W + min(dx, 0)))
    out[dst] = x[src]
    mask[dst] = 1
    return out, mask
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Apply fractional translation to a batch of 2D images. Corresponds to the
|
|
|
|
# operator T_x in Appendix E.2.
|
|
|
|
|
|
|
|
def apply_fractional_translation(x, tx, ty, a=3):
    """Shift a batch of images by a sub-pixel offset using a Lanczos kernel.

    The offset is split into an integer part (``floor``) and a fractional
    part. The image is filtered separably: first horizontally, then
    vertically, each with a normalized ``2*a``-tap kernel
    ``sinc(t) * sinc(t / a)`` centered on the fractional offset. The result is
    then cropped and pasted according to the integer part.

    Args:
        x: Image batch, unpacked as ``[N, C, H, W]``.
        tx: Horizontal offset as a fraction of the width.
        ty: Vertical offset as a fraction of the height.
        a: Lanczos kernel radius.

    Returns:
        ``(z, m)``, both shaped like ``x``. ``z`` is the translated image.
        ``m`` is 1 on an interior region of the output, inset from the shifted
        image edges by the kernel support, and 0 elsewhere. NOTE(review): this
        presumably marks pixels whose filter footprint saw only valid input.
    """
    _N, _C, H, W = x.shape
    shift_x = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    shift_y = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = shift_x.floor().to(torch.int64)
    iy = shift_y.floor().to(torch.int64)
    fx = shift_x - ix
    fy = shift_y - iy
    b = a - 1

    def lanczos_taps(frac):
        # 2*a taps at offsets -b..a, shifted by the fractional part, normalized to unit sum.
        t = torch.arange(a * 2, device=x.device) - b - frac
        k = sinc(t) * sinc(t / a)
        return k / k.sum()

    # Translated image.
    out = torch.zeros_like(x)
    ox0, ox1 = max(ix - b, 0), min(ix + a, 0) + W
    oy0, oy1 = max(iy - b, 0), min(iy + a, 0) + H
    if ox0 < ox1 and oy0 < oy1:
        y = upfirdn2d.filter2d(x, lanczos_taps(fx).unsqueeze(0), padding=[b, a, 0, 0])
        y = upfirdn2d.filter2d(y, lanczos_taps(fy).unsqueeze(1), padding=[0, 0, b, a])
        rows = slice(max(b - iy, 0), H + b + a + min(-iy - a, 0))
        cols = slice(max(b - ix, 0), W + b + a + min(-ix - a, 0))
        out[:, :, oy0:oy1, ox0:ox1] = y[:, :, rows, cols]

    # Validity mask.
    mask = torch.zeros_like(x)
    mx0, mx1 = max(ix + a, 0), min(ix - b, 0) + W
    my0, my1 = max(iy + a, 0), min(iy - b, 0) + H
    if mx0 < mx1 and my0 < my1:
        mask[:, :, my0:my1, mx0:mx1] = 1
    return out, mask
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Construct an oriented low-pass filter that applies the appropriate
|
|
|
|
# bandlimit with respect to the input and output of the given affine 2D
|
|
|
|
# image transformation.
|
|
|
|
|
|
|
|
def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Build a 2D FIR low-pass filter for resampling under an affine transform.

    The filter is the elementwise product of two terms:
      * the (circular, FFT-based) convolution of a separable sinc in input
        coordinates with a separable sinc in output coordinates, and
      * the analogous convolution of two separable Lanczos windows of radius ``a``.
    Output coordinates are the input sample grid mapped through the linear
    (upper-left 2x2) part of ``mat``.

    Args:
        mat: Affine matrix (presumably 3x3); only ``mat[:2, :2]`` is used.
        a: Lanczos window radius.
        amax: Half-width, in input pixels, of the returned filter's support.
        aflt: Half-width, in input pixels, of the larger intermediate grid the
            filter is computed on before cropping. Must satisfy
            ``a <= amax < aflt``.
        up: Number of filter taps per input pixel.
        cutoff_in: Scale applied to input coordinates inside the input-space sinc.
        cutoff_out: Scale applied to output coordinates inside the output-space sinc.

    Returns:
        Square tensor of shape ``[amax*2*up - 1, amax*2*up - 1]``. Each of the
        ``up*up`` polyphase components (taps sharing the same row and column
        index modulo ``up``) is normalized to sum to ``1 / up**2``.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    # `taps` holds 2*aflt*up - 1 samples spaced 1/up apart, symmetric about 0.
    # They are rolled so that the zero tap sits at index 0, the origin
    # convention of the FFTs below.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    # NOTE(review): meshgrid defaults to 'ij' indexing; newer PyTorch versions warn when `indexing=` is omitted.
    yi, xi = torch.meshgrid(taps, taps)
    # Map every (xi, yi) sample through mat[:2, :2] to get output-space coordinates.
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters.
    fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
    fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = lanczos_window(xi, a) * lanczos_window(yi, a)
    wo = lanczos_window(xo, a) * lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    # Undo the earlier roll so that the zero tap is centered, then crop the
    # [-aflt, aflt] grid down to [-amax, amax] (c taps removed on each side).
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    # Pad to an even size divisible by `up` so the taps can be grouped into up x up polyphase components.
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    # Normalize each polyphase component to sum to 1/up**2.
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    # Flatten back and drop the zero padding, restoring an odd, centered size.
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Apply the given affine transformation to a batch of 2D images.
|
|
|
|
|
|
|
|
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample a batch of images under an affine transformation.

    The image is upsampled by ``up`` with the band-limiting filter from
    ``construct_affine_bandlimit_filter``. It is then bilinearly sampled on an
    ``affine_grid`` built from the inverse of ``mat``.

    Args:
        x: Image batch, unpacked as ``[N, C, H, W]``.
        mat: Square affine matrix (presumably 3x3 homogeneous). Its
            translation column appears to be a fraction of the image size,
            since it is doubled to match ``affine_grid``'s [-1, 1] normalized
            coordinates. TODO confirm.
        up: Upsampling factor used before resampling.
        **filter_kwargs: Forwarded to ``construct_affine_bandlimit_filter``.

    Returns:
        ``(z, m)``: the resampled images, and a mask that is 1 where the
        nearest-sampled source location lies inside an interior region of the
        upsampled image, inset from its border by ``2*p + 1`` samples. Both
        have spatial size ``H x W`` (the grid is built for ``x.shape``).
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    # Filter half-width in upsampled samples; also used as the upsampling padding below.
    p = f.shape[0] // 2

    # Construct sampling grid.
    # The statement order below matters: the offsets are added before the
    # rows are rescaled.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    # NOTE(review): 1/(up*W) is half of one upsampled pixel in normalized
    # coordinates; this presumably compensates an offset introduced by the
    # upsampling. TODO confirm.
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    # Rescale so the output grid addresses the padded upsampled image, which
    # is presumably W + 2*p/up input pixels wide (and likewise for H).
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    # Mark the interior of the upsampled image, then carry it through the same
    # grid with nearest sampling.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Apply fractional rotation to a batch of 2D images. Corresponds to the
|
|
|
|
# operator R_\alpha in Appendix E.3.
|
|
|
|
|
|
|
|
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    """Rotate a batch of images by ``angle`` radians.

    The rotation is done with ``apply_affine_transformation``, using a filter
    with Lanczos radius ``a`` and support half-width ``2*a``. Extra keyword
    arguments are forwarded to the filter construction.

    Returns:
        ``(z, m)``, as returned by ``apply_affine_transformation``.
    """
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    rot = rotation_matrix(theta)
    return apply_affine_transformation(x, rot, a=a, amax=a * 2, **filter_kwargs)
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Modify the frequency content of a batch of 2D images as if they had undergone
|
|
|
|
# fractional rotation -- but without actually rotating them. Corresponds to
|
|
|
|
# the operator R^*_\alpha in Appendix E.3.
|
|
|
|
|
|
|
|
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    """Band-limit a batch of images as a rotation by ``angle`` would, without rotating.

    The filter is built for the rotation matrix of ``-angle`` at one tap per
    pixel (``up=1``) and applied with ``upfirdn2d.filter2d``.

    Returns:
        ``(y, m)``: the filtered images, and a mask that is 0 on a border of
        ``filter_size // 2`` pixels and 1 inside.
    """
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    kernel = construct_affine_bandlimit_filter(rotation_matrix(-theta), a=a, amax=a * 2, up=1, **filter_kwargs)
    filtered = upfirdn2d.filter2d(x=x, f=kernel)

    border = kernel.shape[0] // 2
    mask = torch.zeros_like(filtered)
    mask[:, :, border:-border, border:-border] = 1
    return filtered, mask
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
# Compute the selected equivariance metrics for the given generator.
|
|
|
|
|
|
|
|
def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
    """Compute the selected equivariance metrics for the generator ``opts.G``.

    For every batch of random latents, the generator is run with an identity
    input transform to get a reference image. It is then run once per
    selected metric with a random transform written into
    ``G.synthesis.input.transform``. Each transformed output is compared
    against the reference transformed in image space, restricted to the
    pixels where the returned masks are 1. Masked squared errors and mask
    sizes are summed over all batches (and across GPUs) before being
    converted to PSNR.

    Args:
        opts: Metric options. Reads ``opts.G``, ``opts.G_kwargs``,
            ``opts.device``, ``opts.num_gpus`` and ``opts.progress``.
        num_samples: Number of samples to evaluate. The loop steps by
            ``batch_size * opts.num_gpus``, so the count is effectively
            rounded up to a whole number of steps.
        batch_size: Number of latents generated per step by this process.
        translate_max: Maximum absolute translation per axis, as a fraction
            of the image size.
        rotate_max: Maximum absolute rotation angle, as a multiple of pi.
        compute_eqt_int: Compute EQ-T (integer-pixel translation).
        compute_eqt_frac: Compute EQ-T_frac (sub-pixel translation).
        compute_eqr: Compute EQ-R (rotation).

    Returns:
        PSNR in dB for each selected metric, in the order EQ-T, EQ-T_frac,
        EQ-R. A single value if exactly one metric is selected, else a tuple.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is not positive, or if
            the generator has no ``synthesis.input.transform``.
    """
    assert compute_eqt_int or compute_eqt_frac or compute_eqr
    # Without at least one batch, `sums` would stay None and fail later with
    # an opaque AttributeError, so reject these inputs up front.
    if num_samples <= 0:
        raise ValueError(f'num_samples must be positive, got {num_samples}')
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    # Setup generator and labels.
    # Use a deep copy, because the loop overwrites the input-transform buffer
    # and the noise buffers in place.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    I = torch.eye(3, device=opts.device)
    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
    if M is None:
        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    # For each selected metric, `sums` accumulates the pair
    # (sum of masked squared error, sum of mask) in float64.
    sums = None
    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        s = []

        # Randomize noise buffers, if any.
        for name, buf in G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Run mapping network.
        z = torch.randn([batch_size, G.z_dim], device=opts.device)
        c = next(c_iter)
        ws = G.mapping(z=z, c=c)

        # Generate reference image.
        M[:] = I
        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)

        # Integer translation (EQ-T).
        if compute_eqt_int:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            # Snap to a whole number of pixels at G.img_resolution.
            t = (t * G.img_resolution).round() / G.img_resolution
            M[:] = I
            # The input transform receives -t; the image-space reference is shifted by +t.
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_integer_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Fractional translation (EQ-T_frac).
        if compute_eqt_frac:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_fractional_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Rotation (EQ-R).
        if compute_eqr:
            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
            M[:] = rotation_matrix(-angle)
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            # `ref` is the reference rotated in image space. `pseudo` is the
            # generated image band-limited as if rotated, but not rotated.
            # Compare only where both masks are valid.
            ref, ref_mask = apply_fractional_rotation(orig, angle)
            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
            mask = ref_mask * pseudo_mask
            s += [(ref - pseudo).square() * mask, mask]

        # Accumulate results.
        s = torch.stack([x.to(torch.float64).sum() for x in s])
        sums = sums + s if sums is not None else s
    progress.update(num_samples)

    # Compute PSNRs.
    if opts.num_gpus > 1:
        torch.distributed.all_reduce(sums)
    sums = sums.cpu()
    # Mean squared error over the masked elements, per metric.
    mses = sums[0::2] / sums[1::2]
    # PSNR with a peak-to-peak signal range of 2 (images presumably in [-1, 1]).
    psnrs = np.log10(2) * 20 - mses.log10() * 10
    psnrs = tuple(psnrs.numpy())
    return psnrs[0] if len(psnrs) == 1 else psnrs
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------------
|