eval stuff

This commit is contained in:
rromb 2022-06-09 10:56:34 +02:00
parent 6dc939330d
commit 9cda7feafa
9 changed files with 2365 additions and 30 deletions

View file

@ -0,0 +1,676 @@
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple
import yaml
import numpy as np
import requests
import tensorflow.compat.v1 as tf
from scipy import linalg
from tqdm.auto import tqdm
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"
REQUIREMENTS = f"This script has the following requirements: \n" \
'tensorflow-gpu>=2.0' + "\n" + 'scipy' + "\n" + "requests" + "\n" + "tqdm"
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--ref_batch", help="path to reference batch npz file")
parser.add_argument("--sample_batch", help="path to sample batch npz file")
args = parser.parse_args()
config = tf.ConfigProto(
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
)
config.gpu_options.allow_growth = True
evaluator = Evaluator(tf.Session(config=config))
print("warming up TensorFlow...")
# This will cause TF to print a bunch of verbose stuff now rather
# than after the next print(), to help prevent confusion.
evaluator.warmup()
print("computing reference batch activations...")
ref_acts = evaluator.read_activations(args.ref_batch)
print("computing/reading reference batch statistics...")
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
print("computing sample batch activations...")
sample_acts = evaluator.read_activations(args.sample_batch)
print("computing/reading sample batch statistics...")
sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
print("Computing evaluations...")
is_ = evaluator.compute_inception_score(sample_acts[0])
print("Inception Score:", is_)
fid = sample_stats.frechet_distance(ref_stats)
print("FID:", fid)
sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
print("sFID:", sfid)
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
print("Precision:", prec)
print("Recall:", recall)
savepath = '/'.join(args.sample_batch.split('/')[:-1])
results_file = os.path.join(savepath,'evaluation_metrics.yaml')
print(f'Saving evaluation results to "{results_file}"')
results = {
'IS': is_,
'FID': fid,
'sFID': sfid,
'Precision:':prec,
'Recall': recall
}
with open(results_file, 'w') as f:
yaml.dump(results, f, default_flow_style=False)
class InvalidFIDException(Exception):
pass
class FIDStatistics:
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
self.mu = mu
self.sigma = sigma
def frechet_distance(self, other, eps=1e-6):
"""
Compute the Frechet distance between two sets of statistics.
"""
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
mu1, sigma1 = self.mu, self.sigma
mu2, sigma2 = other.mu, other.sigma
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
assert (
sigma1.shape == sigma2.shape
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
class Evaluator:
def __init__(
self,
session,
batch_size=64,
softmax_batch_size=512,
):
self.sess = session
self.batch_size = batch_size
self.softmax_batch_size = softmax_batch_size
self.manifold_estimator = ManifoldEstimator(session)
with self.sess.graph.as_default():
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
self.softmax = _create_softmax_graph(self.softmax_input)
def warmup(self):
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
with open_npz_array(npz_path, "arr_0") as reader:
return self.compute_activations(reader.read_batches(self.batch_size))
def compute_activations(self, batches: Iterable[np.ndarray],silent=False) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute image features for downstream evals.
:param batches: a iterator over NHWC numpy arrays in [0, 255].
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
dimension. The tuple is (pool_3, spatial).
"""
preds = []
spatial_preds = []
it = batches if silent else tqdm(batches)
for batch in it:
batch = batch.astype(np.float32)
pred, spatial_pred = self.sess.run(
[self.pool_features, self.spatial_features], {self.image_input: batch}
)
preds.append(pred.reshape([pred.shape[0], -1]))
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
return (
np.concatenate(preds, axis=0),
np.concatenate(spatial_preds, axis=0),
)
def read_statistics(
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
) -> Tuple[FIDStatistics, FIDStatistics]:
obj = np.load(npz_path)
if "mu" in list(obj.keys()):
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
obj["mu_s"], obj["sigma_s"]
)
return tuple(self.compute_statistics(x) for x in activations)
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
return FIDStatistics(mu, sigma)
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
softmax_out = []
for i in range(0, len(activations), self.softmax_batch_size):
acts = activations[i : i + self.softmax_batch_size]
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
preds = np.concatenate(softmax_out, axis=0)
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
scores = []
for i in range(0, len(preds), split_size):
part = preds[i : i + split_size]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return float(np.mean(scores))
def compute_prec_recall(
self, activations_ref: np.ndarray, activations_sample: np.ndarray
) -> Tuple[float, float]:
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
pr = self.manifold_estimator.evaluate_pr(
activations_ref, radii_1, activations_sample, radii_2
)
return (float(pr[0][0]), float(pr[1][0]))
class ManifoldEstimator:
"""
A helper for comparing manifolds of feature vectors.
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
"""
def __init__(
self,
session,
row_batch_size=10000,
col_batch_size=10000,
nhood_sizes=(3,),
clamp_to_percentile=None,
eps=1e-5,
):
"""
Estimate the manifold of given feature vectors.
:param session: the TensorFlow session.
:param row_batch_size: row batch size to compute pairwise distances
(parameter to trade-off between memory usage and performance).
:param col_batch_size: column batch size to compute pairwise distances.
:param nhood_sizes: number of neighbors used to estimate the manifold.
:param clamp_to_percentile: prune hyperspheres that have radius larger than
the given percentile.
:param eps: small number for numerical stability.
"""
self.distance_block = DistanceBlock(session)
self.row_batch_size = row_batch_size
self.col_batch_size = col_batch_size
self.nhood_sizes = nhood_sizes
self.num_nhoods = len(nhood_sizes)
self.clamp_to_percentile = clamp_to_percentile
self.eps = eps
def warmup(self):
feats, radii = (
np.zeros([1, 2048], dtype=np.float32),
np.zeros([1, 1], dtype=np.float32),
)
self.evaluate_pr(feats, radii, feats, radii)
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
num_images = len(features)
# Estimate manifold of features by calculating distances to k-NN of each sample.
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
for begin1 in range(0, num_images, self.row_batch_size):
end1 = min(begin1 + self.row_batch_size, num_images)
row_batch = features[begin1:end1]
for begin2 in range(0, num_images, self.col_batch_size):
end2 = min(begin2 + self.col_batch_size, num_images)
col_batch = features[begin2:end2]
# Compute distances between batches.
distance_batch[
0 : end1 - begin1, begin2:end2
] = self.distance_block.pairwise_distances(row_batch, col_batch)
# Find the k-nearest neighbor from the current batch.
radii[begin1:end1, :] = np.concatenate(
[
x[:, self.nhood_sizes]
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
],
axis=0,
)
if self.clamp_to_percentile is not None:
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
radii[radii > max_distances] = 0
return radii
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
"""
Evaluate if new feature vectors are at the manifold.
"""
num_eval_images = eval_features.shape[0]
num_ref_images = radii.shape[0]
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
for begin1 in range(0, num_eval_images, self.row_batch_size):
end1 = min(begin1 + self.row_batch_size, num_eval_images)
feature_batch = eval_features[begin1:end1]
for begin2 in range(0, num_ref_images, self.col_batch_size):
end2 = min(begin2 + self.col_batch_size, num_ref_images)
ref_batch = features[begin2:end2]
distance_batch[
0 : end1 - begin1, begin2:end2
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
# If a feature vector is inside a hypersphere of some reference sample, then
# the new sample lies at the estimated manifold.
# The radii of the hyperspheres are determined from distances of neighborhood size k.
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
max_realism_score[begin1:end1] = np.max(
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
)
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
return {
"fraction": float(np.mean(batch_predictions)),
"batch_predictions": batch_predictions,
"max_realisim_score": max_realism_score,
"nearest_indices": nearest_indices,
}
def evaluate_pr(
self,
features_1: np.ndarray,
radii_1: np.ndarray,
features_2: np.ndarray,
radii_2: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Evaluate precision and recall efficiently.
:param features_1: [N1 x D] feature vectors for reference batch.
:param radii_1: [N1 x K1] radii for reference vectors.
:param features_2: [N2 x D] feature vectors for the other batch.
:param radii_2: [N x K2] radii for other vectors.
:return: a tuple of arrays for (precision, recall):
- precision: an np.ndarray of length K1
- recall: an np.ndarray of length K2
"""
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
for begin_1 in range(0, len(features_1), self.row_batch_size):
end_1 = begin_1 + self.row_batch_size
batch_1 = features_1[begin_1:end_1]
for begin_2 in range(0, len(features_2), self.col_batch_size):
end_2 = begin_2 + self.col_batch_size
batch_2 = features_2[begin_2:end_2]
batch_1_in, batch_2_in = self.distance_block.less_thans(
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
)
features_1_status[begin_1:end_1] |= batch_1_in
features_2_status[begin_2:end_2] |= batch_2_in
return (
np.mean(features_2_status.astype(np.float64), axis=0),
np.mean(features_1_status.astype(np.float64), axis=0),
)
class DistanceBlock:
"""
Calculate pairwise distances between vectors.
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
"""
def __init__(self, session):
self.session = session
# Initialize TF graph to calculate pairwise distances.
with session.graph.as_default():
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
distance_block_16 = _batch_pairwise_distances(
tf.cast(self._features_batch1, tf.float16),
tf.cast(self._features_batch2, tf.float16),
)
self.distance_block = tf.cond(
tf.reduce_all(tf.math.is_finite(distance_block_16)),
lambda: tf.cast(distance_block_16, tf.float32),
lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
)
# Extra logic for less thans.
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
def pairwise_distances(self, U, V):
"""
Evaluate pairwise distances between two batches of feature vectors.
"""
return self.session.run(
self.distance_block,
feed_dict={self._features_batch1: U, self._features_batch2: V},
)
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
return self.session.run(
[self._batch_1_in, self._batch_2_in],
feed_dict={
self._features_batch1: batch_1,
self._features_batch2: batch_2,
self._radii1: radii_1,
self._radii2: radii_2,
},
)
def _batch_pairwise_distances(U, V):
"""
Compute pairwise distances between two batches of feature vectors.
"""
with tf.variable_scope("pairwise_dist_block"):
# Squared norms of each row in U and V.
norm_u = tf.reduce_sum(tf.square(U), 1)
norm_v = tf.reduce_sum(tf.square(V), 1)
# norm_u as a column and norm_v as a row vectors.
norm_u = tf.reshape(norm_u, [-1, 1])
norm_v = tf.reshape(norm_v, [1, -1])
# Pairwise squared Euclidean distances.
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
return D
class NpzArrayReader(ABC):
@abstractmethod
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
pass
@abstractmethod
def remaining(self) -> int:
pass
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
def gen_fn():
while True:
batch = self.read_batch(batch_size)
if batch is None:
break
yield batch
rem = self.remaining()
num_batches = rem // batch_size + int(rem % batch_size != 0)
return BatchIterator(gen_fn, num_batches)
class BatchIterator:
def __init__(self, gen_fn, length):
self.gen_fn = gen_fn
self.length = length
def __len__(self):
return self.length
def __iter__(self):
return self.gen_fn()
class StreamingNpzArrayReader(NpzArrayReader):
def __init__(self, arr_f, shape, dtype):
self.arr_f = arr_f
self.shape = shape
self.dtype = dtype
self.idx = 0
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.shape[0]:
return None
bs = min(batch_size, self.shape[0] - self.idx)
self.idx += bs
if self.dtype.itemsize == 0:
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
read_count = bs * np.prod(self.shape[1:])
read_size = int(read_count * self.dtype.itemsize)
data = _read_bytes(self.arr_f, read_size, "array data")
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
def remaining(self) -> int:
return max(0, self.shape[0] - self.idx)
class MemoryNpzArrayReader(NpzArrayReader):
def __init__(self, arr):
self.arr = arr
self.idx = 0
@classmethod
def load(cls, path: str, arr_name: str):
with open(path, "rb") as f:
arr = np.load(f)[arr_name]
return cls(arr)
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.arr.shape[0]:
return None
res = self.arr[self.idx : self.idx + batch_size]
self.idx += batch_size
return res
def remaining(self) -> int:
return max(0, self.arr.shape[0] - self.idx)
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
with _open_npy_file(path, arr_name) as arr_f:
version = np.lib.format.read_magic(arr_f)
if version == (1, 0):
header = np.lib.format.read_array_header_1_0(arr_f)
elif version == (2, 0):
header = np.lib.format.read_array_header_2_0(arr_f)
else:
yield MemoryNpzArrayReader.load(path, arr_name)
return
shape, fortran, dtype = header
if fortran or dtype.hasobject:
yield MemoryNpzArrayReader.load(path, arr_name)
else:
yield StreamingNpzArrayReader(arr_f, shape, dtype)
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
Read from file-like object until size bytes are read.
Raises ValueError if not EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
data = bytes()
while True:
# io files (default in python3) return None or raise on
# would-block, python2 file will truncate, probably nothing can be
# done about that. note that regular files can't be non-blocking
try:
r = fp.read(size - len(data))
data += r
if len(r) == 0 or len(data) == size:
break
except io.BlockingIOError:
pass
if len(data) != size:
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
return data
@contextmanager
def _open_npy_file(path: str, arr_name: str):
with open(path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
if f"{arr_name}.npy" not in zip_f.namelist():
raise ValueError(f"missing {arr_name} in npz file")
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
yield arr_f
def _download_inception_model():
if os.path.exists(INCEPTION_V3_PATH):
return
print("downloading InceptionV3 model...")
with requests.get(INCEPTION_V3_URL, stream=True) as r:
r.raise_for_status()
tmp_path = INCEPTION_V3_PATH + ".tmp"
with open(tmp_path, "wb") as f:
for chunk in tqdm(r.iter_content(chunk_size=8192)):
f.write(chunk)
os.rename(tmp_path, INCEPTION_V3_PATH)
def _create_feature_graph(input_batch):
_download_inception_model()
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
with open(INCEPTION_V3_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
pool3, spatial = tf.import_graph_def(
graph_def,
input_map={f"ExpandDims:0": input_batch},
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
name=prefix,
)
_update_shapes(pool3)
spatial = spatial[..., :7]
return pool3, spatial
def _create_softmax_graph(input_batch):
_download_inception_model()
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
with open(INCEPTION_V3_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
(matmul,) = tf.import_graph_def(
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
)
w = matmul.inputs[1]
logits = tf.matmul(input_batch, w)
return tf.nn.softmax(logits)
def _update_shapes(pool3):
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
ops = pool3.graph.get_operations()
for op in ops:
for o in op.outputs:
shape = o.get_shape()
if shape._dims is not None: # pylint: disable=protected-access
# shape = [s.value for s in shape] TF 1.x
shape = [s for s in shape] # TF 2.x
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3
def _numpy_partition(arr, kth, **kwargs):
num_workers = min(cpu_count(), len(arr))
chunk_size = len(arr) // num_workers
extra = len(arr) % num_workers
start_idx = 0
batches = []
for i in range(num_workers):
size = chunk_size + (1 if i < extra else 0)
batches.append(arr[start_idx : start_idx + size])
start_idx += size
with ThreadPool(num_workers) as pool:
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
if __name__ == "__main__":
print(REQUIREMENTS)
main()

View file

@ -0,0 +1,630 @@
import argparse
import glob
import os
from tqdm import tqdm
from collections import namedtuple
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
from ldm.modules.evaluate.ssim import ssim
transform = transforms.Compose([transforms.ToTensor()])
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1)).view(
in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
)
return in_feat / (norm_factor.expand_as(in_feat) + eps)
def cos_sim(in0, in1):
in0_norm = normalize_tensor(in0)
in1_norm = normalize_tensor(in1)
N = in0.size()[0]
X = in0.size()[2]
Y = in0.size()[3]
return torch.mean(
torch.mean(
torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2
).view(N, 1, 1, Y),
dim=3,
).view(N)
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = models.squeezenet1_1(
pretrained=pretrained
).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple(
"SqueezeOutputs",
["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
)
out = vgg_outputs(
h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7
)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = models.alexnet(
pretrained=pretrained
).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple(
"AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
)
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs",
["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if num == 18:
self.net = models.resnet18(pretrained=pretrained)
elif num == 34:
self.net = models.resnet34(pretrained=pretrained)
elif num == 50:
self.net = models.resnet50(pretrained=pretrained)
elif num == 101:
self.net = models.resnet101(pretrained=pretrained)
elif num == 152:
self.net = models.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple(
"Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]
)
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
# Off-the-shelf deep network
class PNet(torch.nn.Module):
"""Pre-trained network with all channels equally weighted by default"""
def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
super(PNet, self).__init__()
self.use_gpu = use_gpu
self.pnet_type = pnet_type
self.pnet_rand = pnet_rand
self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
if self.pnet_type in ["vgg", "vgg16"]:
self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
elif self.pnet_type == "alex":
self.net = alexnet(
pretrained=not self.pnet_rand, requires_grad=False
)
elif self.pnet_type[:-2] == "resnet":
self.net = resnet(
pretrained=not self.pnet_rand,
requires_grad=False,
num=int(self.pnet_type[-2:]),
)
elif self.pnet_type == "squeeze":
self.net = squeezenet(
pretrained=not self.pnet_rand, requires_grad=False
)
self.L = self.net.N_slices
if use_gpu:
self.net.cuda()
self.shift = self.shift.cuda()
self.scale = self.scale.cuda()
def forward(self, in0, in1, retPerLayer=False):
in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
outs0 = self.net.forward(in0_sc)
outs1 = self.net.forward(in1_sc)
if retPerLayer:
all_scores = []
for (kk, out0) in enumerate(outs0):
cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
if kk == 0:
val = 1.0 * cur_score
else:
val = val + cur_score
if retPerLayer:
all_scores += [cur_score]
if retPerLayer:
return (val, all_scores)
else:
return val
# The SSIM metric
def ssim_metric(img1, img2, mask=None):
return ssim(img1, img2, mask=mask, size_average=False)
# The PSNR metric
def psnr(img1, img2, mask=None,reshape=False):
b = img1.size(0)
if not (mask is None):
b = img1.size(0)
mse_err = (img1 - img2).pow(2) * mask
if reshape:
mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
)
else:
mse_err = mse_err.view(b, -1).sum(dim=1) / (
3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
)
else:
if reshape:
mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
else:
mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)
psnr = 10 * (1 / mse_err).log10()
return psnr
# The perceptual similarity metric
def perceptual_sim(img1, img2, vgg16):
# First extract features
dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)
return dist
def load_img(img_name, size=None):
try:
img = Image.open(img_name)
if type(size) == int:
img = img.resize((size, size))
elif size is not None:
img = img.resize((size[1], size[0]))
img = transform(img).cuda()
img = img.unsqueeze(0)
except Exception as e:
print("Failed at loading %s " % img_name)
print(e)
img = torch.zeros(1, 3, 256, 256).cuda()
raise
return img
def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
folders = os.listdir(folder)
for i, f in tqdm(enumerate(sorted(folders))):
pred_imgs = glob.glob(folder + f + "/" + pred_img)
tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
assert len(tgt_imgs) == 1
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
for p_img in pred_imgs:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
perc_sim = min(perc_sim, t_perc_sim)
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
values_psnr += [psnr_sim]
if take_every_other:
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [
min(values_percsim[2 * i], values_percsim[2 * i + 1])
]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
return {
"Perceptual similarity": (avg_percsim, std_percsim),
"PSNR": (avg_psnr, std_psnr),
"SSIM": (avg_ssim, std_ssim),
}
def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list,
take_every_other,
simple_format=True):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
equal_count = 0
ambig_count = 0
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
pred_imgs = pred_imgs_list[i]
tgt_imgs = [tgt_img]
assert len(tgt_imgs) == 1
if type(pred_imgs) != list:
pred_imgs = [pred_imgs]
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
assert len(pred_imgs)>0
for p_img in pred_imgs:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
perc_sim = min(perc_sim, t_perc_sim)
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
if psnr_sim != np.float("inf"):
values_psnr += [psnr_sim]
else:
if torch.allclose(p_img, t_img):
equal_count += 1
print("{} equal src and wrp images.".format(equal_count))
else:
ambig_count += 1
print("{} ambiguous src and wrp images.".format(ambig_count))
if take_every_other:
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [
min(values_percsim[2 * i], values_percsim[2 * i + 1])
]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
if simple_format:
# just to make yaml formatting readable
return {
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
"PSNR": [float(avg_psnr), float(std_psnr)],
"SSIM": [float(avg_ssim), float(std_ssim)],
}
else:
return {
"Perceptual similarity": (avg_percsim, std_percsim),
"PSNR": (avg_psnr, std_psnr),
"SSIM": (avg_ssim, std_ssim),
}
def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list,
take_every_other, resize=False):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
individual_percsim = []
individual_ssim = []
individual_psnr = []
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
pred_imgs = pred_imgs_list[i]
tgt_imgs = [tgt_img]
assert len(tgt_imgs) == 1
if type(pred_imgs) != list:
assert False
pred_imgs = [pred_imgs]
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
sample_percsim = list()
sample_ssim = list()
sample_psnr = list()
for p_img in pred_imgs:
if resize:
t_img = load_img(tgt_imgs[0], size=(256,256))
else:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
sample_percsim.append(t_perc_sim)
perc_sim = min(perc_sim, t_perc_sim)
t_ssim = ssim_metric(p_img, t_img).item()
sample_ssim.append(t_ssim)
ssim_sim = max(ssim_sim, t_ssim)
t_psnr = psnr(p_img, t_img).item()
sample_psnr.append(t_psnr)
psnr_sim = max(psnr_sim, t_psnr)
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
values_psnr += [psnr_sim]
individual_percsim.append(sample_percsim)
individual_ssim.append(sample_ssim)
individual_psnr.append(sample_psnr)
if take_every_other:
assert False, "Do this later, after specifying topk to get proper results"
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [
min(values_percsim[2 * i], values_percsim[2 * i + 1])
]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
individual_percsim = np.array(individual_percsim)
individual_psnr = np.array(individual_psnr)
individual_ssim = np.array(individual_ssim)
return {
"avg_of_best": {
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
"PSNR": [float(avg_psnr), float(std_psnr)],
"SSIM": [float(avg_ssim), float(std_ssim)],
},
"individual": {
"PSIM": individual_percsim,
"PSNR": individual_psnr,
"SSIM": individual_ssim,
}
}
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument("--folder", type=str, default="")
args.add_argument("--pred_image", type=str, default="")
args.add_argument("--target_image", type=str, default="")
args.add_argument("--take_every_other", action="store_true", default=False)
args.add_argument("--output_file", type=str, default="")
opts = args.parse_args()
folder = opts.folder
pred_img = opts.pred_image
tgt_img = opts.target_image
results = compute_perceptual_similarity(
folder, pred_img, tgt_img, opts.take_every_other
)
f = open(opts.output_file, 'w')
for key in results:
print("%s for %s: \n" % (key, opts.folder))
print(
"\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])
)
f.write("%s for %s: \n" % (key, opts.folder))
f.write(
"\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])
)
f.close()

View file

@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Minimal Reference implementation for the Frechet Video Distance (FVD).
FVD is a metric for the quality of video generation models. It is inspired by
the FID (Frechet Inception Distance) used for images, but uses a different
embedding to be better suitable for videos.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
def preprocess(videos, target_resolution):
"""Runs some preprocessing on the videos for I3D model.
Args:
videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
preprocessed. We don't care about the specific dtype of the videos, it can
be anything that tf.image.resize_bilinear accepts. Values are expected to
be in the range 0-255.
target_resolution: (width, height): target video resolution
Returns:
videos: <float32>[batch_size, num_frames, height, width, depth]
"""
videos_shape = list(videos.shape)
all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
output_videos = tf.reshape(resized_videos, target_shape)
scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
return scaled_videos
def _is_in_graph(tensor_name):
"""Checks whether a given tensor does exists in the graph."""
try:
tf.get_default_graph().get_tensor_by_name(tensor_name)
except KeyError:
return False
return True
def create_id3_embedding(videos,warmup=False,batch_size=16):
"""Embeds the given videos using the Inflated 3D Convolution ne twork.
Downloads the graph of the I3D from tf.hub and adds it to the graph on the
first call.
Args:
videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
Expected range is [-1, 1].
Returns:
embedding: <float32>[batch_size, embedding_size]. embedding_size depends
on the model used.
Raises:
ValueError: when a provided embedding_layer is not supported.
"""
# batch_size = 16
module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
# Making sure that we import the graph separately for
# each different input video tensor.
module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
videos.name).replace(":", "_")
assert_ops = [
tf.Assert(
tf.reduce_max(videos) <= 1.001,
["max value in frame is > 1", videos]),
tf.Assert(
tf.reduce_min(videos) >= -1.001,
["min value in frame is < -1", videos]),
tf.assert_equal(
tf.shape(videos)[0],
batch_size, ["invalid frame batch size: ",
tf.shape(videos)],
summarize=6),
]
with tf.control_dependencies(assert_ops):
videos = tf.identity(videos)
module_scope = "%s_apply_default/" % module_name
# To check whether the module has already been loaded into the graph, we look
# for a given tensor name. If this tensor name exists, we assume the function
# has been called before and the graph was imported. Otherwise we import it.
# Note: in theory, the tensor could exist, but have wrong shapes.
# This will happen if create_id3_embedding is called with a frames_placehoder
# of wrong size/batch size, because even though that will throw a tf.Assert
# on graph-execution time, it will insert the tensor (with wrong shape) into
# the graph. This is why we need the following assert.
if warmup:
video_batch_size = int(videos.shape[0])
assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
if not _is_in_graph(tensor_name):
i3d_model = hub.Module(module_spec, name=module_name)
i3d_model(videos)
# gets the kinetics-i3d-400-logits layer
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
return tensor
def calculate_fvd(real_activations,
generated_activations):
"""Returns a list of ops that compute metrics as funcs of activations.
Args:
real_activations: <float32>[num_samples, embedding_size]
generated_activations: <float32>[num_samples, embedding_size]
Returns:
A scalar that contains the requested FVD.
"""
return tfgan.eval.frechet_classifier_distance_from_activations(
real_activations, generated_activations)

View file

@ -0,0 +1,124 @@
# MIT Licence
# Methods to predict the SSIM, taken from
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def gaussian(window_size, sigma):
gauss = torch.Tensor(
[
exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))
for x in range(window_size)
]
)
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(
_2D_window.expand(channel, 1, window_size, window_size).contiguous()
)
return window
def _ssim(
img1, img2, window, window_size, channel, mask=None, size_average=True
):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
- mu1_sq
)
sigma2_sq = (
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
- mu2_sq
)
sigma12 = (
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
- mu1_mu2
)
C1 = (0.01) ** 2
C2 = (0.03) ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
if not (mask is None):
b = mask.size(0)
ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
dim=1
).clamp(min=1)
return ssim_map
import pdb
pdb.set_trace
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2, mask=None):
(_, channel, _, _) = img1.size()
if (
channel == self.channel
and self.window.data.type() == img1.data.type()
):
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(
img1,
img2,
window,
self.window_size,
channel,
mask,
self.size_average,
)
def ssim(img1, img2, window_size=11, mask=None, size_average=True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, mask, size_average)

View file

@ -0,0 +1,294 @@
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
import os
import numpy as np
import io
import re
import requests
import html
import hashlib
import urllib
import urllib.request
import scipy.linalg
import multiprocessing as mp
import glob
from tqdm import tqdm
from typing import Any, List, Tuple, Union, Dict, Callable
from torchvision.io import read_video
import torch; torch.set_grad_enabled(False)
from einops import rearrange
from nitro.util import isvideo
def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float:
print('Calculate frechet distance...')
m = np.square(mu_sample - mu_ref).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member
fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))
return float(fid)
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
mu = feats.mean(axis=0) # [d]
sigma = np.cov(feats, rowvar=False) # [d, d]
return mu, sigma
def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
"""Download the given URL and return a binary-mode file object to access the data."""
assert num_attempts >= 1
# Doesn't look like an URL scheme so interpret it as a local filename.
if not re.match('^[a-z]+://', url):
return url if return_filename else open(url, "rb")
# Handle file URLs. This code handles unusual file:// patterns that
# arise on Windows:
#
# file:///c:/foo.txt
#
# which would translate to a local '/c:/foo.txt' filename that's
# invalid. Drop the forward slash for such pathnames.
#
# If you touch this code path, you should test it on both Linux and
# Windows.
#
# Some internet resources suggest using urllib.request.url2pathname() but
# but that converts forward slashes to backslashes and this causes
# its own set of problems.
if url.startswith('file://'):
filename = urllib.parse.urlparse(url).path
if re.match(r'^/[a-zA-Z]:', filename):
filename = filename[1:]
return filename if return_filename else open(filename, "rb")
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
# Download.
url_name = None
url_data = None
with requests.Session() as session:
if verbose:
print("Downloading %s ..." % url, end="", flush=True)
for attempts_left in reversed(range(num_attempts)):
try:
with session.get(url) as res:
res.raise_for_status()
if len(res.content) == 0:
raise IOError("No data received")
if len(res.content) < 8192:
content_str = res.content.decode("utf-8")
if "download_warning" in res.headers.get("Set-Cookie", ""):
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
if len(links) == 1:
url = requests.compat.urljoin(url, links[0])
raise IOError("Google Drive virus checker nag")
if "Google Drive - Quota exceeded" in content_str:
raise IOError("Google Drive download quota exceeded -- please try again later")
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
url_name = match[1] if match else url
url_data = res.content
if verbose:
print(" done")
break
except KeyboardInterrupt:
raise
except:
if not attempts_left:
if verbose:
print(" failed")
raise
if verbose:
print(".", end="", flush=True)
# Return data as file object.
assert not return_filename
return io.BytesIO(url_data)
def load_video(ip):
vid, *_ = read_video(ip)
vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8)
return vid
def get_data_from_str(input_str,nprc = None):
assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory'
vid_filelist = glob.glob(os.path.join(input_str,'*.mp4'))
print(f'Found {len(vid_filelist)} videos in dir {input_str}')
if nprc is None:
try:
nprc = mp.cpu_count()
except NotImplementedError:
print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading')
nprc = 1
pool = mp.Pool(processes=nprc)
vids = []
for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'):
vids.append(v)
vids = torch.stack(vids,dim=0).float()
return vids
def get_stats(stats):
assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}'
print(f'Using precomputed statistics under {stats}')
stats = np.load(stats)
stats = {key: stats[key] for key in stats.files}
return stats
@torch.no_grad()
def compute_fvd(ref_input, sample_input, bs=32,
ref_stats=None,
sample_stats=None,
nprc_load=None):
calc_stats = ref_stats is None or sample_stats is None
if calc_stats:
only_ref = sample_stats is not None
only_sample = ref_stats is not None
if isinstance(ref_input,str) and not only_sample:
ref_input = get_data_from_str(ref_input,nprc_load)
if isinstance(sample_input, str) and not only_ref:
sample_input = get_data_from_str(sample_input, nprc_load)
stats = compute_statistics(sample_input,ref_input,
device='cuda' if torch.cuda.is_available() else 'cpu',
bs=bs,
only_ref=only_ref,
only_sample=only_sample)
if only_ref:
stats.update(get_stats(sample_stats))
elif only_sample:
stats.update(get_stats(ref_stats))
else:
stats = get_stats(sample_stats)
stats.update(get_stats(ref_stats))
fvd = compute_frechet_distance(**stats)
return {'FVD' : fvd,}
@torch.no_grad()
def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict:
detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
with open_url(detector_url, verbose=False) as f:
detector = torch.jit.load(f).eval().to(device)
assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive'
ref_embed, sample_embed = [], []
info = f'Computing I3D activations for FVD score with batch size {bs}'
if only_ref:
if not isvideo(videos_real):
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
print(videos_real.shape)
if videos_real.shape[0] % bs == 0:
n_secs = videos_real.shape[0] // bs
else:
n_secs = videos_real.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
for ref_v in tqdm(videos_real, total=len(videos_real),desc=info):
feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
ref_embed.append(feats_ref)
elif only_sample:
if not isvideo(videos_fake):
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
print(videos_fake.shape)
if videos_fake.shape[0] % bs == 0:
n_secs = videos_fake.shape[0] // bs
else:
n_secs = videos_fake.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info):
feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
sample_embed.append(feats_sample)
else:
if not isvideo(videos_real):
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
if not isvideo(videos_fake):
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
if videos_fake.shape[0] % bs == 0:
n_secs = videos_fake.shape[0] // bs
else:
n_secs = videos_fake.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info):
# print(ref_v.shape)
# ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
# sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
sample_embed.append(feats_sample)
ref_embed.append(feats_ref)
out = dict()
if len(sample_embed) > 0:
sample_embed = np.concatenate(sample_embed,axis=0)
mu_sample, sigma_sample = compute_stats(sample_embed)
out.update({'mu_sample': mu_sample,
'sigma_sample': sigma_sample})
if len(ref_embed) > 0:
ref_embed = np.concatenate(ref_embed,axis=0)
mu_ref, sigma_ref = compute_stats(ref_embed)
out.update({'mu_ref': mu_ref,
'sigma_ref': sigma_ref})
return out

101
main.py
View file

@ -415,6 +415,107 @@ class CUDACallback(Callback):
pass
class SingleImageLogger(Callback):
"""does not save as grid but as single images"""
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
log_images_kwargs=None, log_always=False):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {
pl.loggers.TestTubeLogger: self._testtube,
}
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_always = log_always
@rank_zero_only
def _testtube(self, pl_module, images, batch_idx, split):
for k in images:
grid = torchvision.utils.make_grid(images[k])
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
tag = f"{split}/{k}"
pl_module.logger.experiment.add_image(
tag, grid,
global_step=pl_module.global_step)
@rank_zero_only
def log_local(self, save_dir, split, images,
global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
os.makedirs(root, exist_ok=True)
for k in images:
subroot = os.path.join(root, k)
os.makedirs(subroot, exist_ok=True)
base_count = len(glob.glob(os.path.join(subroot, "*.png")))
for img in images[k]:
if self.rescale:
img = (img + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
img = img.transpose(0, 1).transpose(1, 2).squeeze(-1)
img = img.numpy()
img = (img * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}_{:08}.png".format(
k,
global_step,
current_epoch,
batch_idx,
base_count)
path = os.path.join(subroot, filename)
Image.fromarray(img).save(path)
base_count += 1
def log_img(self, pl_module, batch, batch_idx, split="train", save_dir=None):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and
callable(pl_module.log_images) and
self.max_images > 0) or self.log_always:
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir if save_dir is None else save_dir, split, images,
pl_module.global_step, pl_module.current_epoch, batch_idx)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
check_idx > 0 or self.log_first_step):
try:
self.log_steps.pop(0)
except IndexError as e:
print(e)
return True
return False
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.

View file

@ -0,0 +1,96 @@
import argparse, os, sys, glob
import numpy as np
from torch_fidelity import calculate_metrics
import yaml
from ldm.modules.evaluate.evaluate_perceptualsim import compute_perceptual_similarity_from_list
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--logdir",
type=str,
nargs="?",
default="fidelity-evaluation",
)
parser.add_argument(
"--reconstructions",
type=str,
help="path to reconstructed images"
)
parser.add_argument(
"--inputs",
type=str,
help="path to input images"
)
parser.add_argument(
"--cache_root",
type=str,
help="optional, for pre-computed fidelity statistics",
nargs="?",
)
return parser
if __name__ == "__main__":
command = " ".join(sys.argv)
np.random.RandomState(42)
sys.path.append(os.getcwd())
parser = get_parser()
opt, unknown = parser.parse_known_args()
outdir = os.path.join(opt.logdir, "metrics")
print(outdir)
inppath = opt.inputs
recpath = opt.reconstructions
results = dict()
##### fid
fid_kwargs = {}
cache_root = None
if opt.cache_root and os.path.isdir(opt.cache_root):
print(f'Using cached Inception Features saved under "{cache_root}"')
fid_kwargs.update({
'cache_root': cache_root,
'input2_cache_name': 'input_data',
'cache': True
})
metrics_dict = calculate_metrics(input1=recpath, input2=inppath,
cuda=True, isc=True, fid=True, kid=True,
verbose=True, **fid_kwargs)
results["fidelity"] = metrics_dict
print(f'Metrics from fidelity: \n {results["fidelity"]}')
##### sim
print("Evaluating reconstruction similarity")
reconstructions = sorted(glob.glob(os.path.join(recpath, "*.png")))
print(f"num reconstructions found: {len(reconstructions)}")
inputs = sorted(glob.glob(os.path.join(inppath, "*.png")))
print(f"num inputs found: {len(inputs)}")
results["image-sim"] = compute_perceptual_similarity_from_list(
reconstructions, inputs, take_every_other=False)
print(f'Results sim: {results["image-sim"]}')
# info
results["info"] = {
"n_examples": len(reconstructions),
"command": command,
}
# write out
ipath, rpath = map(lambda x: os.path.splitext(x)[0].split(os.sep)[-1], (inppath, recpath))
resultsfn = f"results_{ipath}-{rpath}.yaml"
results_file = os.path.join(outdir, resultsfn)
with open(results_file, 'w') as f:
yaml.dump(results, f, default_flow_style=False)
print(results_file)
print("\ndone.")

267
scripts/logging_template.py Normal file
View file

@ -0,0 +1,267 @@
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
import streamlit as st
from streamlit import caching
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from tqdm import tqdm
import datetime
from ldm.util import instantiate_from_config
from main import DataModuleFromConfig, ImageLogger, SingleImageLogger
rescale = lambda x: (x + 1.) / 2.
class DummyLogger:
pass
def bchw_to_st(x):
return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
def run(model, dsets, callbacks, logdir, split="train",
batch_size=8, start_index=0, sample_batch=False, nowname="", use_full_data=False):
logdir = os.path.join(logdir, nowname)
os.makedirs(logdir, exist_ok=True)
dset = dsets.datasets[split]
print(f"Dataset size: {len(dset)}")
dloader = torch.utils.data.DataLoader(dset, batch_size=opt.batch_size, drop_last=False, shuffle=False)
if not use_full_data:
if sample_batch:
indices = np.random.choice(len(dset), batch_size)
else:
indices = list(range(start_index, start_index+batch_size))
print(f"Data indices: {list(indices)}")
example = default_collate([dset[i] for i in indices])
for cb in callbacks:
if isinstance(cb, ImageLogger):
print(f"logging with {cb.__class__.__name__}")
cb.log_img(model, example, 0, split=split, save_dir=logdir)
else:
for batch in tqdm(dloader, desc="Data"):
for cb in callbacks:
if isinstance(cb, SingleImageLogger):
cb.log_img(model, batch, 0, split=split, save_dir=logdir)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
"-c",
"--config",
nargs="?",
metavar="single_config.yaml",
help="path to single config. If specified, base configs will be ignored "
"(except for the last one if left unspecified).",
const=True,
default="",
)
parser.add_argument(
"-n",
"--n_iter",
type=int,
default=1,
help="how many times to run",
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="how many examples in the batch",
)
parser.add_argument(
"--split",
type=str,
default="validation",
help="evaluate on this split",
)
parser.add_argument(
"--logdir",
type=str,
default="eval_logs",
help="where to save the logs",
)
parser.add_argument(
"--state_key",
type=str,
default="state_dict",
choices=["state_dict", "model_ema", "model"],
help="where to access the model weights",
)
parser.add_argument(
"--full_data",
action='store_true',
help="evaluate on full dataset",
)
parser.add_argument(
"--ignore_callbacks",
action='store_true',
help="ignores all callbacks in the config and only uses main.SingleImageLogger",
)
return parser
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
model = instantiate_from_config(config)
print("loading model from state-dict...")
if sd is not None:
m, u = model.load_state_dict(sd)
if len(m) > 0: print(f"missing keys: \n {m}")
if len(u) > 0: print(f"unexpected keys: \n {u}")
print("loaded model.")
if gpu:
model.cuda()
if eval_mode:
model.eval()
return {"model": model}
def get_data(config):
# get data
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()
return data
def get_callbacks(lightning_config, ignore_callbacks=False):
callbacks_cfg = lightning_config.callbacks
callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
print(f"found and instantiated the following callback(s):")
for cb in callbacks:
print(f" > {cb.__class__.__name__}")
print()
if len(callbacks) == 0 or ignore_callbacks:
del callbacks
callbacks = list()
print("No callbacks found. Falling back to SingleImageLogger as a default")
try:
callbacks.append(SingleImageLogger(1, max_images=opt.batch_size, log_always=True,
log_images_kwargs=lightning_config.callbacks.image_logger.params.log_images_kwargs))
except:
print("No log_images_kwargs specified. Using SingleImageLogger with default values in log_images().")
callbacks.append(SingleImageLogger(1, max_images=opt.batch_size, log_always=True))
return callbacks
@st.cache(allow_output_mutation=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode):
# get data
dsets = get_data(config) # calls data.config ...
# now load the specified checkpoint
if ckpt:
pl_sd = torch.load(ckpt, map_location="cpu")
try:
global_step = pl_sd["global_step"]
except:
global_step = 0
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
#pl_sd["state_dict"],
pl_sd[opt.state_key],
gpu=gpu,
eval_mode=eval_mode)["model"]
return dsets, model, global_step
def exists(x):
return x is not None
if __name__ == "__main__":
sys.path.append(os.getcwd())
if not st._is_running_with_streamlit:
print("Not running with streamlit. Redefining st functions...")
st.info = print
st.write = print
seed_everything(42)
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
assert opt.resume
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
try:
idx = len(paths)-paths[::-1].index("logs")+1
except ValueError:
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
opt.base = base_configs+opt.base
if opt.config:
if type(opt.config) == str:
opt.base = [opt.config]
else:
opt.base = [opt.base[-1]]
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-lightning.yaml")))
lightning_configs = [OmegaConf.load(lcfg) for lcfg in lightning_configs]
lightning_config = OmegaConf.merge(*lightning_configs, cli)
print(f"ckpt-path: {ckpt}")
print(config)
print(lightning_config)
gpu = True
eval_mode = True
callbacks = get_callbacks(lightning_config.lightning, ignore_callbacks=opt.ignore_callbacks)
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
print(f"global step: {global_step}")
logdir = os.path.join(logdir, opt.logdir, f"{global_step:09}")
print(f"logging to {logdir}")
os.makedirs(logdir, exist_ok=True)
# go
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
for n in range(opt.n_iter):
nowname = now + "_iteration-" + f"{n:03}"
run(model, dsets, callbacks, logdir=logdir, batch_size=opt.batch_size, nowname=nowname,
split=opt.split, use_full_data=opt.full_data)