autocast
This commit is contained in:
parent
893c351761
commit
62308078cf
1 changed files with 52 additions and 41 deletions
|
@ -9,6 +9,8 @@ from einops import rearrange
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
import time
|
import time
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
@ -178,6 +180,13 @@ def main():
|
||||||
default=42,
|
default=42,
|
||||||
help="the seed (for reproducible sampling)",
|
help="the seed (for reproducible sampling)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
help="evaluate at this precision",
|
||||||
|
choices=["full", "autocast"],
|
||||||
|
default="autocast"
|
||||||
|
)
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
seed_everything(opt.seed)
|
seed_everything(opt.seed)
|
||||||
|
|
||||||
|
@ -217,7 +226,9 @@ def main():
|
||||||
if opt.fixed_code:
|
if opt.fixed_code:
|
||||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||||
|
|
||||||
|
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
|
|
Loading…
Reference in a new issue