# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)
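
# Example usage (a minimal sketch; the shapes below are illustrative, not part
# of the original file). fma() follows normal broadcasting rules, so c may be
# smaller than a * b:
#
#   a = torch.randn(4, 3, requires_grad=True)
#   b = torch.randn(4, 3, requires_grad=True)
#   c = torch.randn(1, 3, requires_grad=True)  # broadcast against a * b
#   out = fma(a, b, c)                         # same values as a * b + c
#   out.sum().backward()                       # gradients reach a, b, and c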

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b) # c enters additively, so only its shape is needed in backward()
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc
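
    # These gradients follow the product rule for out = a * b + c:
    # d(out)/da = b, d(out)/db = a, and d(out)/dc = 1. _unbroadcast() then
    # sums each term back down to the shape its input had before broadcasting,
    # which is what lets fma() accept inputs of mismatched shapes.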

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Sum over every axis that broadcasting added or expanded to size > 1.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Fold away the leading dims absent from the target shape; they are
        # all size 1 after the sum above.
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x
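
# Worked example (illustrative shapes, not from the original file): for
# x.shape == (4, 3, 5) and shape == (3, 1), extra_dims == 1; dims 0 (extra)
# and 2 (size 1 in the target) are summed with keepdim, giving (1, 3, 1),
# and the reshape folds the leading extra dim away, yielding (3, 1).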

#----------------------------------------------------------------------------
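
# Minimal self-test sketch (assumed usage, not part of the original module):
# checks that fma() matches a * b + c and that the hand-written backward
# agrees with numerical gradients under broadcasting; float64 keeps
# torch.autograd.gradcheck numerically stable.

if __name__ == "__main__":
    a = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
    b = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
    c = torch.randn(1, 3, dtype=torch.float64, requires_grad=True)
    assert torch.allclose(fma(a, b, c), a * b + c)
    torch.autograd.gradcheck(fma, (a, b, c))
    print("fma: forward and backward checks passed")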