Consistency fix w.r.t. metrics computed during training.
parent 1406ab6f1c
commit 1c6608208c
1 changed file with 12 additions and 11 deletions
@@ -183,9 +183,8 @@ def training_loop(
     if rank == 0:
         print(f'Distributing across {num_gpus} GPUs...')
     for module in [G, D, G_ema, augment_pipe]:
-        if module is not None:
+        if module is not None and num_gpus > 1:
             for param in misc.params_and_buffers(module):
-                if param.numel() > 0 and num_gpus > 1:
-                    torch.distributed.broadcast(param, src=0)
+                torch.distributed.broadcast(param, src=0)
 
     # Setup training phases.
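Note on the hunk above: hoisting the num_gpus check to the module level lets single-GPU runs skip the broadcast loop entirely, and the per-parameter numel() filter is no longer needed. The broadcast touches every parameter and buffer of each module; below is a minimal sketch of what a helper like misc.params_and_buffers is assumed to return (the actual helper in the repository may differ):

import torch

def params_and_buffers(module):
    # Hypothetical stand-in for misc.params_and_buffers: every tensor that
    # makes up the module's state, so rank 0 can broadcast all of it.
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())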
@@ -281,7 +280,7 @@ def training_loop(
 
             # Update weights.
             with torch.autograd.profiler.record_function(phase.name + '_opt'):
-                params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None]
+                params = [param for param in phase.module.parameters() if param.grad is not None]
                 if len(params) > 0:
                     flat = torch.cat([param.grad.flatten() for param in params])
                     if num_gpus > 1:
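Note on the hunk above: the param.numel() > 0 filter is dropped, so the list is keyed purely on whether a gradient exists. The flattened tensor built from those gradients is typically averaged across ranks before being split back per parameter; a minimal sketch of that pattern, assuming a standard all-reduce (the lines that actually follow in the file are outside this diff):

import torch

def average_grads_across_gpus(params, num_gpus):
    # Minimal sketch: fuse all gradients into one tensor, average it across
    # ranks with a single all-reduce, then write the slices back. Assumes
    # torch.distributed is initialized whenever num_gpus > 1.
    flat = torch.cat([param.grad.flatten() for param in params])
    if num_gpus > 1:
        torch.distributed.all_reduce(flat)
        flat /= num_gpus
    grads = flat.split([param.numel() for param in params])
    for param, grad in zip(params, grads):
        param.grad = grad.reshape(param.shape)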
@@ -358,14 +357,16 @@ def training_loop(
         snapshot_pkl = None
         snapshot_data = None
         if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
-            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
-            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
-                if module is not None:
+            snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs))
+            for key, value in snapshot_data.items():
+                if isinstance(value, torch.nn.Module):
+                    value = copy.deepcopy(value).eval().requires_grad_(False)
                     if num_gpus > 1:
-                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)')
-                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
-                snapshot_data[name] = module
-                del module # conserve memory
+                        misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)')
+                        for param in misc.params_and_buffers(value):
+                            torch.distributed.broadcast(param, src=0)
+                    snapshot_data[key] = value.cpu()
+                del value # conserve memory
             snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
             if rank == 0:
                 with open(snapshot_pkl, 'wb') as f:
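Note on the hunk above: each module copy is now checked for DDP consistency and broadcast from rank 0 before being moved to the CPU and pickled, so every rank snapshots identical weights and metrics computed during training match metrics computed later from the saved file. A minimal sketch of a consumer of that pickle, assuming the key layout built in this hunk (the file name below is illustrative):

import pickle
import torch

# Minimal sketch: load a snapshot written by the loop above. The keys mirror
# the dict built in this commit ('G', 'D', 'G_ema', 'augment_pipe',
# 'training_set_kwargs'); modules were saved as CPU copies with gradients
# disabled, so they are ready for evaluation-only use.
with open('network-snapshot-000000.pkl', 'rb') as f:  # illustrative path
    snapshot_data = pickle.load(f)

G_ema = snapshot_data['G_ema']
assert isinstance(G_ema, torch.nn.Module)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
G_ema = G_ema.to(device).eval()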