154 lines
5.5 KiB
Python
154 lines
5.5 KiB
Python
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|||
|
#
|
|||
|
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|||
|
# and proprietary rights in and to this software, related documentation
|
|||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|||
|
# distribution of this software and related documentation without an express
|
|||
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|||
|
|
|||
|
"""Main API for computing and reporting quality metrics."""
|
|||
|
|
|||
|
import os
|
|||
|
import time
|
|||
|
import json
|
|||
|
import torch
|
|||
|
import dnnlib
|
|||
|
|
|||
|
from . import metric_utils
|
|||
|
from . import frechet_inception_distance
|
|||
|
from . import kernel_inception_distance
|
|||
|
from . import precision_recall
|
|||
|
from . import perceptual_path_length
|
|||
|
from . import inception_score
|
|||
|
from . import equivariance
|
|||
|
|
|||
|
#----------------------------------------------------------------------------
|
|||
|
|
|||
|
# Metric registry: maps each function's __name__ to the function itself.
# Populated at import time by the @register_metric decorator below.
_metric_dict = {}

def register_metric(fn):
    """Add *fn* to the metric registry, keyed by ``fn.__name__``.

    Meant to be used as a decorator: *fn* is returned unchanged. Registering a
    second callable with the same ``__name__`` replaces the earlier entry.
    Raises ``AssertionError`` if *fn* is not callable.
    """
    assert callable(fn)
    name = fn.__name__
    _metric_dict[name] = fn
    return fn
|
|||
|
|
|||
|
def is_valid_metric(metric):
    """Return ``True`` when *metric* is the name of a registered metric function."""
    registry = _metric_dict
    return metric in registry
|
|||
|
|
|||
|
def list_valid_metrics():
    """Return a new list of all registered metric names, in registration order."""
    return list(_metric_dict)
|
|||
|
|
|||
|
#----------------------------------------------------------------------------
|
|||
|
|
|||
|
def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    """Run a registered metric and wrap its results with timing metadata.

    Args:
        metric: Name of a metric registered via ``register_metric``.
        **kwargs: Forwarded verbatim to ``metric_utils.MetricOptions``; see that
            class for the full list of accepted arguments.

    Returns:
        ``dnnlib.EasyDict`` with keys ``results`` (an ``EasyDict`` of the values
        returned by the metric function), ``metric``, ``total_time`` (elapsed
        seconds), ``total_time_str`` (``total_time`` formatted by
        ``dnnlib.util.format_time``) and ``num_gpus``.

    Raises:
        AssertionError: if *metric* is not registered.
    """
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate. perf_counter() is monotonic, so the measured duration cannot be
    # distorted (or go negative) if the wall clock is adjusted during a long run.
    start_time = time.perf_counter()
    results = _metric_dict[metric](opts)
    total_time = time.perf_counter() - start_time

    # Broadcast results. With multiple GPUs, torch.distributed.broadcast is a
    # collective call: rank 0's value overwrites the tensor on the other ranks,
    # so every process reports identical numbers. This path requires each value
    # to be a scalar (float() of a multi-element tensor raises). With a single
    # GPU, values are kept as returned by the metric function.
    for key, value in list(results.items()):
        if opts.num_gpus > 1:
            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results         = dnnlib.EasyDict(results),
        metric          = metric,
        total_time      = total_time,
        total_time_str  = dnnlib.util.format_time(total_time),
        num_gpus        = opts.num_gpus,
    )
|
|||
|
|
|||
|
#----------------------------------------------------------------------------
|
|||
|
|
|||
|
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as one JSON line and optionally append it to a log.

    Args:
        result_dict: Mapping with at least a ``'metric'`` key naming a
            registered metric (for example, the value returned by ``calc_metric``).
        run_dir: Optional directory. If it is not ``None`` and exists, the JSON
            line is appended to ``metric-<metric>.jsonl`` inside it.
        snapshot_pkl: Optional snapshot path stored in the record. When
            *run_dir* is also given, it is rewritten relative to *run_dir*.

    Raises:
        AssertionError: if ``result_dict['metric']`` is not registered.
    """
    metric = result_dict['metric']
    assert is_valid_metric(metric)

    has_run_dir = run_dir is not None
    if has_run_dir and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    # Copy so the caller's mapping is not modified; the extra keys are added last.
    record = dict(result_dict)
    record.update(snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)

    if not (has_run_dir and os.path.isdir(run_dir)):
        return
    log_path = os.path.join(run_dir, f'metric-{metric}.jsonl')
    with open(log_path, 'at') as log_file:
        log_file.write(f'{jsonl_line}\n')
|
|||
|
|
|||
|
#----------------------------------------------------------------------------
|
|||
|
# Recommended metrics.
|
|||
|
|
|||
|
@register_metric
def fid50k_full(opts):
    """FID of 50k generated images against the dataset with ``max_real=None``.

    Modifies ``opts.dataset_kwargs`` in place (``max_size=None``,
    ``xflip=False``) before calling ``frechet_inception_distance.compute_fid``.

    Returns:
        ``{'fid50k_full': <value returned by compute_fid>}``.
    """
    # NOTE(review): max_size=None / max_real=None presumably mean "no cap on
    # the number of real images" -- confirm in the dataset class and compute_fid.
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    score = frechet_inception_distance.compute_fid(opts, num_gen=50000, max_real=None)
    return {'fid50k_full': score}
|
|||
|
|
|||
|
@register_metric
def kid50k_full(opts):
    """KID of 50k generated images against up to 1,000,000 real images.

    Modifies ``opts.dataset_kwargs`` in place (``max_size=None``,
    ``xflip=False``), then calls ``kernel_inception_distance.compute_kid`` with
    ``num_subsets=100`` and ``max_subset_size=1000``.

    Returns:
        ``{'kid50k_full': <value returned by compute_kid>}``.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    score = kernel_inception_distance.compute_kid(
        opts,
        max_real=1_000_000,
        num_gen=50_000,
        num_subsets=100,
        max_subset_size=1000,
    )
    return {'kid50k_full': score}
|
|||
|
|
|||
|
@register_metric
def pr50k3_full(opts):
    """Precision/recall of 50k generated images against up to 200k real images.

    Modifies ``opts.dataset_kwargs`` in place (``max_size=None``,
    ``xflip=False``), then calls ``precision_recall.compute_pr`` with
    ``nhood_size=3`` and 10000-row/column batches.

    Returns:
        ``{'pr50k3_full_precision': ..., 'pr50k3_full_recall': ...}`` taken from
        the two values returned by ``compute_pr``.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    pr_pair = precision_recall.compute_pr(
        opts,
        max_real=200_000,
        num_gen=50_000,
        nhood_size=3,
        row_batch_size=10_000,
        col_batch_size=10_000,
    )
    precision, recall = pr_pair
    return {'pr50k3_full_precision': precision, 'pr50k3_full_recall': recall}
|
|||
|
|
|||
|
@register_metric
def ppl2_wend(opts):
    """Perceptual path length over 50k samples (``space='w'``, ``sampling='end'``).

    Calls ``perceptual_path_length.compute_ppl`` with ``epsilon=1e-4``,
    ``crop=False`` and ``batch_size=2``. Unlike the dataset-based metrics, it
    does not modify ``opts.dataset_kwargs``.

    Returns:
        ``{'ppl2_wend': <value returned by compute_ppl>}``.
    """
    path_length = perceptual_path_length.compute_ppl(
        opts,
        num_samples=50_000,
        epsilon=1e-4,
        space='w',
        sampling='end',
        crop=False,
        batch_size=2,
    )
    return {'ppl2_wend': path_length}
|
|||
|
|
|||
|
@register_metric
def eqt50k_int(opts):
    """Equivariance metric over 50k samples with ``compute_eqt_int=True``.

    Sets ``force_fp32=True`` in ``opts.G_kwargs`` in place (presumably disables
    reduced-precision execution in the generator; confirm in the generator
    code), then calls ``equivariance.compute_equivariance_metrics``.

    Returns:
        ``{'eqt50k_int': <value returned by compute_equivariance_metrics>}``
        (presumably a PSNR; confirm in the equivariance module).
    """
    opts.G_kwargs.update(force_fp32=True)
    score = equivariance.compute_equivariance_metrics(
        opts, num_samples=50_000, batch_size=4, compute_eqt_int=True,
    )
    return {'eqt50k_int': score}
|
|||
|
|
|||
|
@register_metric
def eqt50k_frac(opts):
    """Equivariance metric over 50k samples with ``compute_eqt_frac=True``.

    Sets ``force_fp32=True`` in ``opts.G_kwargs`` in place (presumably disables
    reduced-precision execution in the generator; confirm in the generator
    code), then calls ``equivariance.compute_equivariance_metrics``.

    Returns:
        ``{'eqt50k_frac': <value returned by compute_equivariance_metrics>}``
        (presumably a PSNR; confirm in the equivariance module).
    """
    opts.G_kwargs.update(force_fp32=True)
    score = equivariance.compute_equivariance_metrics(
        opts, num_samples=50_000, batch_size=4, compute_eqt_frac=True,
    )
    return {'eqt50k_frac': score}
|
|||
|
|
|||
|
@register_metric
def eqr50k(opts):
    """Equivariance metric over 50k samples with ``compute_eqr=True``.

    Sets ``force_fp32=True`` in ``opts.G_kwargs`` in place (presumably disables
    reduced-precision execution in the generator; confirm in the generator
    code), then calls ``equivariance.compute_equivariance_metrics``.

    Returns:
        ``{'eqr50k': <value returned by compute_equivariance_metrics>}``
        (presumably a PSNR; confirm in the equivariance module).
    """
    opts.G_kwargs.update(force_fp32=True)
    score = equivariance.compute_equivariance_metrics(
        opts, num_samples=50_000, batch_size=4, compute_eqr=True,
    )
    return {'eqr50k': score}
|
|||
|
|
|||
|
#----------------------------------------------------------------------------
|
|||
|
# Legacy metrics.
|
|||
|
|
|||
|
@register_metric
def fid50k(opts):
    """Legacy FID: 50k generated images against at most 50k real images.

    Sets only ``max_size=None`` in ``opts.dataset_kwargs`` (in place). Unlike
    ``fid50k_full``, ``xflip`` is left at whatever the caller configured.

    Returns:
        ``{'fid50k': <value returned by compute_fid>}``.
    """
    opts.dataset_kwargs.update(max_size=None)
    score = frechet_inception_distance.compute_fid(opts, num_gen=50_000, max_real=50_000)
    return {'fid50k': score}
|
|||
|
|
|||
|
@register_metric
def kid50k(opts):
    """Legacy KID: 50k generated images against at most 50k real images.

    Sets only ``max_size=None`` in ``opts.dataset_kwargs`` (in place), leaving
    ``xflip`` as configured, then calls ``kernel_inception_distance.compute_kid``
    with ``num_subsets=100`` and ``max_subset_size=1000``.

    Returns:
        ``{'kid50k': <value returned by compute_kid>}``.
    """
    opts.dataset_kwargs.update(max_size=None)
    score = kernel_inception_distance.compute_kid(
        opts,
        max_real=50_000,
        num_gen=50_000,
        num_subsets=100,
        max_subset_size=1000,
    )
    return {'kid50k': score}
|
|||
|
|
|||
|
@register_metric
def pr50k3(opts):
    """Legacy precision/recall: 50k generated vs. at most 50k real images.

    Sets only ``max_size=None`` in ``opts.dataset_kwargs`` (in place), leaving
    ``xflip`` as configured, then calls ``precision_recall.compute_pr`` with
    ``nhood_size=3`` and 10000-row/column batches.

    Returns:
        ``{'pr50k3_precision': ..., 'pr50k3_recall': ...}`` taken from the two
        values returned by ``compute_pr``.
    """
    opts.dataset_kwargs.update(max_size=None)
    pr_pair = precision_recall.compute_pr(
        opts,
        max_real=50_000,
        num_gen=50_000,
        nhood_size=3,
        row_batch_size=10_000,
        col_batch_size=10_000,
    )
    precision, recall = pr_pair
    return {'pr50k3_precision': precision, 'pr50k3_recall': recall}
|
|||
|
|
|||
|
@register_metric
def is50k(opts):
    """Legacy Inception Score over 50k generated images, split into 10 groups.

    Modifies ``opts.dataset_kwargs`` in place (``max_size=None``,
    ``xflip=False``), then calls ``inception_score.compute_is``.

    Returns:
        ``{'is50k_mean': ..., 'is50k_std': ...}`` taken from the two values
        returned by ``compute_is``.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    is_pair = inception_score.compute_is(opts, num_gen=50_000, num_splits=10)
    mean, std = is_pair
    return {'is50k_mean': mean, 'is50k_std': std}
|
|||
|
|
|||
|
#----------------------------------------------------------------------------
|