diff --git a/training/training_loop.py b/training/training_loop.py index 1fd7970..ddd0c15 100644 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -183,10 +183,9 @@ def training_loop( if rank == 0: print(f'Distributing across {num_gpus} GPUs...') for module in [G, D, G_ema, augment_pipe]: - if module is not None: + if module is not None and num_gpus > 1: for param in misc.params_and_buffers(module): - if param.numel() > 0 and num_gpus > 1: - torch.distributed.broadcast(param, src=0) + torch.distributed.broadcast(param, src=0) # Setup training phases. if rank == 0: @@ -281,7 +280,7 @@ def training_loop( # Update weights. with torch.autograd.profiler.record_function(phase.name + '_opt'): - params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None] + params = [param for param in phase.module.parameters() if param.grad is not None] if len(params) > 0: flat = torch.cat([param.grad.flatten() for param in params]) if num_gpus > 1: @@ -358,14 +357,16 @@ def training_loop( snapshot_pkl = None snapshot_data = None if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): - snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) - for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]: - if module is not None: + snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs)) + for key, value in snapshot_data.items(): + if isinstance(value, torch.nn.Module): + value = copy.deepcopy(value).eval().requires_grad_(False) if num_gpus > 1: - misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)') - module = copy.deepcopy(module).eval().requires_grad_(False).cpu() - snapshot_data[name] = module - del module # conserve memory + misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)') + for param in misc.params_and_buffers(value): + torch.distributed.broadcast(param, src=0) + snapshot_data[key] = value.cpu() + del value # conserve memory snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') if rank == 0: with open(snapshot_pkl, 'wb') as f: