import datetime
import json
import logging
import os
from enum import Enum
from typing import Iterator, List, Optional

import numpy as np
from PIL import Image

import dnnlib
import legacy

logger = logging.getLogger('runs')

# Snapshot pickles are named "network-snapshot-<iteration>.pkl"; the metric file
# stores that file name in its "snapshot_pkl" field.
_SNAPSHOT_PREFIX = "network-snapshot-"
_SNAPSHOT_SUFFIX = ".pkl"


def _last_path_component(path: str) -> str:
    """Return the final component of `path`, ignoring trailing separators."""
    return os.path.basename(os.path.normpath(path))


def jsonlines(filename) -> Iterator:
    """Lazily parse a JSON Lines file, yielding one decoded value per line.

    Blank (whitespace-only) lines are skipped so that a trailing empty line
    does not raise a JSONDecodeError.
    """
    with open(filename, 'r', encoding='utf-8') as fp:
        for line in fp:
            if line.strip():
                yield json.loads(line)


class Snapshot:
    """One network snapshot of a Run, described by a line of the run's metric file."""

    def __init__(self, run: 'Run', metrics: dict):
        self.run = run
        self.metrics = metrics
        # Iteration as it appears in file names (kept as a string to preserve
        # zero padding), e.g. "network-snapshot-000200.pkl" -> "000200".
        self.iteration_str = metrics["snapshot_pkl"][len(_SNAPSHOT_PREFIX):-len(_SNAPSHOT_SUFFIX)]
        self.iteration = int(self.iteration_str)

    @property
    def id(self) -> str:
        """Identifier of the form "<run nr>_<iteration string>"."""
        return f"{self.run.as_nr}_{self.iteration_str}"

    @property
    def fid(self):
        """Fréchet inception distance, as calculated during training"""
        return self.metrics['results']['fid50k_full']

    @property
    def cumulative_iteration(self) -> int:
        """Iteration nr counted from the start of the first run in a resume chain.

        This code assumes a resumed run restarts its own counter at 0, so the
        cumulative iteration of the snapshot the run resumed from is added.
        That value is itself cumulative, so chains of resumed runs are handled.
        """
        return self.run.kimg_offset + self.iteration

    @property
    def time(self) -> datetime.datetime:
        """Timestamp of the metric entry, as a naive local datetime."""
        return datetime.datetime.fromtimestamp(int(self.metrics['timestamp']))

    @property
    def pkl_path(self) -> str:
        """Path of this snapshot's network pickle inside the run directory."""
        return os.path.join(self.run.directory, f"{_SNAPSHOT_PREFIX}{self.iteration_str}{_SNAPSHOT_SUFFIX}")

    def load_generator(self, device):
        """Load the EMA generator ('G_ema') from the snapshot pickle onto `device`.

        Note: this unpickles the file, so only use it on trusted snapshots.
        """
        with dnnlib.util.open_url(self.pkl_path) as f:
            return legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    def get_preview_img(self, cols=1, rows=1) -> Image.Image:
        """Return the top-left `cols` x `rows` tiles of the fakes preview grid."""
        file = os.path.join(self.run.directory, f"fakes{self.iteration_str}.png")
        with Image.open(file) as img:
            # crop() loads the pixel data, so the returned image stays valid
            # after the source file handle is closed.
            return img.crop((0, 0, self.run.resolution * cols, self.run.resolution * rows))


class Run:
    """A training run directory containing training_options.json and metric files."""

    def __init__(self, directory):
        self.directory = directory
        self.id = os.path.basename(directory)
        self.metric_path = os.path.join(self.directory, 'metric-fid50k_full.jsonl')
        self.options_path = os.path.join(self.directory, 'training_options.json')
        self.stats_path = os.path.join(self.directory, 'stats.jsonl')

        with open(self.options_path, encoding='utf-8') as fp:
            self.training_options = json.load(fp)

        # Snapshot this run was resumed from, or None if not resumed / not found.
        self.resumed_from: Optional[Snapshot] = self._find_parent_snapshot()

        if os.path.exists(self.metric_path):
            self.snapshots = [Snapshot(self, entry) for entry in jsonlines(self.metric_path)]
        else:
            self.snapshots = []

    def _find_parent_snapshot(self) -> Optional[Snapshot]:
        """Best-effort lookup of the snapshot named by 'resume_pkl'; None on failure."""
        resume_pkl = self.training_options.get('resume_pkl')
        if not resume_pkl:
            return None
        resume_path = os.path.abspath(resume_pkl)
        try:
            parent = Run(os.path.dirname(resume_pkl))
        except (OSError, ValueError, KeyError):
            # e.g. parent directory moved, resume_pkl is a URL, or a malformed metric file.
            logger.warning("Could not load parent snapshot")
            logger.debug("Failed to load parent run for %s", self.directory, exc_info=True)
            return None
        for snapshot in parent.snapshots:
            if os.path.abspath(snapshot.pkl_path) == resume_path:
                return snapshot
        logger.warning("Could not load parent snapshot")
        logger.debug("No snapshot in %s matches %s", parent.directory, resume_pkl)
        return None

    @property
    def as_nr(self) -> str:
        """First five characters of the run directory name (the run number)."""
        return self.id[:5]

    @property
    def duration(self) -> datetime.timedelta:
        """Time between first and last snapshot; raises IndexError for an empty run."""
        return self.snapshots[-1].time - self.snapshots[0].time

    @property
    def kimg_offset(self) -> int:
        """Cumulative iteration at which this run started (0 if not resumed)."""
        if self.resumed_from is None:
            return 0
        return self.resumed_from.cumulative_iteration

    def get_stats(self):
        """fetch stats from stats.jsonl file
        Each stats has `num` (nr. of datapoints), `mean` (mean of points), `std` (std dev)
        yields each line
        """
        yield from jsonlines(self.stats_path)

    def is_empty(self) -> bool:
        """True when the run has no recorded snapshots."""
        return not self.snapshots

    @property
    def dataset_id(self) -> str:
        """Last path component of the training set path."""
        return _last_path_component(self.training_options["training_set_kwargs"]["path"])

    def dataset_is_conditional(self) -> bool:
        return bool(self.training_options["training_set_kwargs"]["use_labels"])

    @property
    def resolution(self):
        return self.training_options["training_set_kwargs"]["resolution"]

    @property
    def r1_gamma(self):
        return self.training_options["loss_kwargs"]["r1_gamma"]

    def get_summary(self) -> dict:
        """Summary dict of the run; raises IndexError for an empty run."""
        last = self.snapshots[-1]
        return {
            "nr": self.as_nr,
            "dataset": self.dataset_id,
            "conditional": self.dataset_is_conditional(),
            "resolution": self.resolution,
            "gamma": self.r1_gamma,
            "duration": self.duration,
            "iterations": last.iteration,
            "last_fid": last.fid,
        }


def get_runs_in_dir(dir_path, include_empty=False) -> List[Run]:
    """Load every run directory directly inside `dir_path`, sorted by name.

    Non-directory entries are skipped. Runs without snapshots are skipped
    unless `include_empty` is True.
    """
    runs = []
    for run_dir in sorted(os.listdir(dir_path)):
        run_path = os.path.join(dir_path, run_dir)
        if not os.path.isdir(run_path):
            # Stray files (e.g. .DS_Store, logs) cannot be runs.
            continue
        run = Run(run_path)
        if include_empty or not run.is_empty():
            runs.append(run)
    return runs


class StreetType(Enum):
    RUE = 'Rue'
    AVENUE = 'Avenue'
    BOULEVARD = 'Boulevard'


class Projection:
    """A projection result directory named "<arrondissement>-<StreetType...>"."""

    # TODO: add snapshot and dataset
    def __init__(self, path, identifier, arrondisement: int, street_type: StreetType):
        self.path = path
        self.id = identifier
        self.arrondisement = arrondisement
        self.street_type = street_type

    @property
    def img_path(self) -> str:
        return os.path.join(self.path, 'proj.png')

    @property
    def target_img_path(self) -> str:
        return os.path.join(self.path, 'target.png')

    @property
    def w_path(self) -> str:
        return os.path.join(self.path, 'projected_w.npz')

    def load_w(self):
        """Return the 'w' array stored in projected_w.npz."""
        with np.load(self.w_path) as data:
            return data['w']

    @classmethod
    def from_path(cls, path):
        """Build a Projection from its directory path.

        Raises ValueError when the arrondissement is not an int or the street
        type cannot be determined from the directory name.
        """
        dirname = _last_path_component(path)
        parts = dirname.split('-')
        arrondisement = int(parts[0])
        street_part = parts[1] if len(parts) > 1 else ''
        street_type = next((t for t in StreetType if street_part.startswith(t.value)), None)
        if street_type is None:
            raise ValueError(f"Unable to determine street type for {path}")
        return cls(path, dirname, arrondisement, street_type)


def get_projections_in_dir(projection_folder) -> List[Projection]:
    """Load every sub-directory of `projection_folder` that contains projected_w.npz, sorted by name."""
    projection_paths = [
        os.path.join(projection_folder, p)
        for p in sorted(os.listdir(projection_folder))
        if os.path.exists(os.path.join(projection_folder, p, "projected_w.npz"))
    ]
    return [Projection.from_path(p) for p in projection_paths]