# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch
import warnings

from .. import custom_ops
from .. import misc
from . import upfirdn2d
from . import bias_act

#----------------------------------------------------------------------------

_plugin = None

def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='filtered_lrelu_plugin',
            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert 1 <= f.ndim <= 2
    return f.shape[-1], f.shape[0] # width, height

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1

#----------------------------------------------------------------------------

def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Add channel-specific bias if provided (`b`).

    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    3. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    5. Multiply each value by the provided gain factor (`gain`).

    6. Apply leaky ReLU activation function to each value.

    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.

    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
       it so that the footprint of all output pixels lies within the input image.

    9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same
                     type as `x`. The length of the vector must match the channel
                     dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
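
# Example usage (a minimal sketch; the sizes, filter taps, and padding below are
# illustrative assumptions, not values required by this op):
#
#   x = torch.randn([4, 3, 16, 16])          # [batch_size, num_channels, height, width]
#   f = torch.tensor([1., 3., 3., 1.]) / 8   # 4-tap separable lowpass filter, sums to 1.
#   b = torch.zeros([3])                     # Per-channel bias.
#   y = filtered_lrelu(x, fu=f, fd=f, b=b, up=2, down=2, padding=3, impl='ref')
#   # With these sizes, y.shape == [4, 3, 16, 16]; see the output size calculation
#   # in _filtered_lrelu_ref() below.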

#----------------------------------------------------------------------------

@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
    assert x.dtype == in_dtype
    return x

#----------------------------------------------------------------------------

_filtered_lrelu_cuda_cache = dict()

def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
    """
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    gain = float(gain)
    assert slope == float(slope) and slope >= 0
    slope = float(slope)
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
    clamp = float(clamp if clamp is not None else 'inf')

    # Lookup from cache.
    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
    if key in _filtered_lrelu_cuda_cache:
        return _filtered_lrelu_cuda_cache[key]

    # Forward op.
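    # Note: the class below closes over the canonicalized parameters above (up, down,
    # px0..py1, gain, slope, clamp, flip_filter), so one specialized autograd.Function
    # subclass is constructed and cached per unique configuration.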
    class FilteredLReluCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4

            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
            if fu is None:
                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if fd is None:
                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert 1 <= fu.ndim <= 2
            assert 1 <= fd.ndim <= 2

            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
                fu = fu.square()[None]
            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
                fd = fd.square()[None]

            # Missing sign input tensor.
            if si is None:
                si = torch.empty([0])

            # Missing bias tensor.
            if b is None:
                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)

            # Construct internal sign tensor only if gradients are needed.
            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)

            # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
            if any(a < b for a, b in zip(strides[:-1], strides[1:])):
                warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)

            # Call C++/Cuda plugin if datatype is supported.
            if x.dtype in [torch.float16, torch.float32]:
                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
                    warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
            else:
                return_code = -1

            # No Cuda kernel found? Fall back to generic implementation. Still more memory
            # efficient than the reference implementation because only the bit-packed sign
            # tensor is retained for gradient computation.
            if return_code < 0:
                warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.

            # Prepare for gradient computation.
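            # Only the filters and the bit-packed sign tensor are saved; the
            # pre-activation feature map itself is not retained for backward.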
            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
            ctx.x_shape = x.shape
            ctx.y_shape = y.shape
            ctx.s_ofs = sx, sy
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            fu, fd, si = ctx.saved_tensors
            _, _, xh, xw = ctx.x_shape
            _, _, yh, yw = ctx.y_shape
            sx, sy = ctx.s_ofs
            dx  = None # 0
            dfu = None; assert not ctx.needs_input_grad[1]
            dfd = None; assert not ctx.needs_input_grad[2]
            db  = None # 3
            dsi = None; assert not ctx.needs_input_grad[4]
            dsx = None; assert not ctx.needs_input_grad[5]
            dsy = None; assert not ctx.needs_input_grad[6]

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
                pp = [
                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
                    xw * up - yw * down + px0 - (up - 1),
                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
                    xh * up - yh * down + py0 - (up - 1),
                ]
                gg = gain * (up ** 2) / (down ** 2)
                ff = (not flip_filter)
                sx = sx - (fu.shape[-1] - 1) + px0
                sy = sy - (fu.shape[0] - 1) + py0
                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)

            if ctx.needs_input_grad[3]:
                db = dx.sum([0, 2, 3])

            return dx, dfu, dfd, db, dsi, dsx, dsy

    # Add to cache.
    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
    return FilteredLReluCuda

#----------------------------------------------------------------------------
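
# Quick consistency check between the two implementations (an illustrative sketch,
# not part of the module; assumes a CUDA device and a working build toolchain for
# the custom plugin):
#
#   x = torch.randn([4, 8, 32, 32], device='cuda')
#   f = torch.tensor([1., 3., 3., 1.], device='cuda') / 8
#   y_cuda = filtered_lrelu(x, fu=f, fd=f, up=2, down=2, padding=3, impl='cuda')
#   y_ref  = filtered_lrelu(x, fu=f, fd=f, up=2, down=2, padding=3, impl='ref')
#   print((y_cuda - y_ref).abs().max())  # Expected to be small (float32 round-off).
#
#----------------------------------------------------------------------------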