diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py index 388778f..156b6b2 100644 --- a/torch_utils/ops/conv2d_gradfix.py +++ b/torch_utils/ops/conv2d_gradfix.py @@ -11,6 +11,7 @@ arbitrarily high order gradients with zero performance penalty.""" import contextlib import torch +from pkg_resources import parse_version # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -20,6 +21,7 @@ import torch enabled = False # Enable the custom op by setting this to true. weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. +_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 @contextlib.contextmanager def no_weight_gradients(disable=True): @@ -48,6 +50,9 @@ def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) if (not enabled) or (not torch.backends.cudnn.enabled): return False + if _use_pytorch_1_11_api: + # The work-around code doesn't work on PyTorch 1.11.0 onwards + return False if input.device.type != 'cuda': return False return True diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index 979ee83..441b379 100644 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -12,6 +12,7 @@ Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" import torch +from pkg_resources import parse_version # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -20,6 +21,7 @@ import torch #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. +_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 #---------------------------------------------------------------------------- @@ -56,7 +58,11 @@ class _GridSample2dBackward(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input, grid): op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') - grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + if _use_pytorch_1_11_api: + output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) + else: + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) ctx.save_for_backward(grid) return grad_input, grad_grid