import os
import datetime
import json
from typing import List, Optional
from PIL import Image
from enum import Enum
import logging
import numpy as np

import dnnlib
import legacy
from dataset_tool import open_dataset

logger = logging.getLogger('runs')


def jsonlines(filename):
    # Quick-and-dirty way to load a jsonlines file: yields one parsed JSON object per line.
    with open(filename, 'r') as fp:
        for line in fp:
            yield json.loads(line)


class Snapshot():
    def __init__(self, run, metrics):
        self.run = run
        self.metrics = metrics
        # snapshot_pkl is of the form 'network-snapshot-XXXXXX.pkl'; slice out the iteration part.
        self.iteration = int(metrics["snapshot_pkl"][17:-4])
        self.iteration_str = metrics["snapshot_pkl"][17:-4]

    @property
    def id(self):
        return f"{self.run.as_nr}_{self.iteration_str}"

    @property
    def fid(self):
        """Fréchet inception distance, as calculated during training"""
        return self.metrics['results']['fid50k_full']

    @property
    def cumulative_iteration(self):
        """Iteration nr, taking into account the snapshot this run resumed from (run.resumed_from)"""
        if self.run.resumed_from is None:
            return self.iteration
        return self.run.resumed_from.iteration + self.iteration

    @property
    def time(self):
        return datetime.datetime.fromtimestamp(int(self.metrics['timestamp']))

    @property
    def pkl_path(self):
        return os.path.join(self.run.directory, f"network-snapshot-{self.iteration_str}.pkl")

    def load_generator(self, device):
        with dnnlib.util.open_url(self.pkl_path) as f:
            return legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    def get_preview_img(self, cols=1, rows=1) -> Image.Image:
        file = os.path.join(self.run.directory, f"fakes{self.iteration_str}.png")
        img = Image.open(file)
        return img.crop((0, 0, self.run.resolution * cols, self.run.resolution * rows))


class Run():
    def __init__(self, directory):
        self.directory = directory
        self.id = os.path.basename(directory)
        self.metric_path = os.path.join(self.directory, 'metric-fid50k_full.jsonl')
        self.options_path = os.path.join(self.directory, 'training_options.json')
        self.stats_path = os.path.join(self.directory, 'stats.jsonl')

        with open(self.options_path) as fp:
            self.training_options = json.load(fp)

        self.resumed_from = None
        if 'resume_pkl' in self.training_options:
            resume_from_dir = os.path.dirname(self.training_options['resume_pkl'])
            try:
                self.resumed_from = [
                    s for s in Run(resume_from_dir).snapshots
                    if os.path.abspath(s.pkl_path) == os.path.abspath(self.training_options['resume_pkl'])
                ][0]
            except Exception:
                logger.warning("Could not load parent snapshot")
                logger.debug("Failed to resolve resume_pkl", exc_info=True)

        if os.path.exists(self.metric_path):
            self.snapshots = [Snapshot(self, l) for l in jsonlines(self.metric_path)]
        else:
            self.snapshots = []

    @property
    def as_nr(self):
        return self.id[:5]

    @property
    def duration(self):
        return self.snapshots[-1].time - self.snapshots[0].time

    @property
    def kimg_offset(self):
        if not self.resumed_from:
            return 0
        return self.resumed_from.iteration

    def get_stats(self):
        """Fetch stats from the stats.jsonl file.

        Each stat has `num` (nr. of datapoints), `mean` (mean of points) and
        `std` (standard deviation). Yields each line.
        """
        yield from jsonlines(self.stats_path)

    def is_empty(self):
        return len(self.snapshots) < 1

    # def get_fids(self) -> dict:
    #     return {:l['results']['fid50k_full'] for l in jsonlines(self.metric_path)}

    # @property
    # def fakes(self):
    #     return sorted([f for f in os.listdir(rundir) if f.startswith('fake')])

    @property
    def dataset_id(self):
        return list(filter(None, self.training_options["training_set_kwargs"]["path"].split(os.path.sep)))[-1]

    def dataset_is_conditional(self):
        return bool(self.training_options["training_set_kwargs"]["use_labels"])

    def dataset_iterator(self, max_images: Optional[int] = None):
        max_images, iterator = open_dataset(
            self.training_options["training_set_kwargs"]["path"],
            max_images=max_images
        )
        return iterator

    @property
    def resolution(self):
        return self.training_options["training_set_kwargs"]["resolution"]

    @property
    def r1_gamma(self):
        return self.training_options["loss_kwargs"]["r1_gamma"]

    def get_summary(self):
        return {
            # "name": self.id,
            "nr": self.as_nr,
            "dataset": self.dataset_id,
            "conditional": self.dataset_is_conditional(),
            "resolution": self.resolution,
            "gamma": self.r1_gamma,
            "duration": self.duration,
            # "finished": self.snapshots[-1].time,
            "iterations": self.snapshots[-1].iteration,
            "last_fid": self.snapshots[-1].fid
        }


def get_runs_in_dir(dir_path, include_empty=False) -> List[Run]:
    run_dirs = sorted(os.listdir(dir_path))
    runs = []
    for run_dir in run_dirs:
        run = Run(os.path.join(dir_path, run_dir))
        if include_empty or not run.is_empty():
            runs.append(run)
    return runs


class StreetType(Enum):
    RUE = 'Rue'
    AVENUE = 'Avenue'
    BOULEVARD = 'Boulevard'


class Projection():
    # TODO: add snapshot and dataset
    def __init__(self, path, identifier, arrondisement: int, street_type: StreetType):
        self.path = path
        self.id = identifier
        self.arrondisement = arrondisement
        self.street_type = street_type

    @property
    def img_path(self):
        return os.path.join(self.path, 'proj.png')

    @property
    def target_img_path(self):
        return os.path.join(self.path, 'target.png')

    @property
    def w_path(self):
        return os.path.join(self.path, 'projected_w.npz')

    def load_w(self):
        with np.load(self.w_path) as data:
            return data['w']

    @classmethod
    def from_path(cls, path):
        # Directory names are of the form '<arrondisement>-<StreetType>...'.
        dirname = list(filter(None, path.split('/')))[-1]
        parts = dirname.split('-')
        arrondisement = int(parts[0])

        street_type = None
        for t in StreetType:
            if parts[1].startswith(t.value):
                street_type = t
                break
        if street_type is None:
            raise Exception(f"Unable to determine street type for {path}")

        return cls(path, dirname, arrondisement, street_type)


def get_projections_in_dir(projection_folder) -> List[Projection]:
    projection_paths = [
        os.path.join(projection_folder, p)
        for p in os.listdir(projection_folder)
        if os.path.exists(os.path.join(projection_folder, p, "projected_w.npz"))
    ]
    return [Projection.from_path(p) for p in projection_paths]
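

# A minimal usage sketch of the helpers above. The directory names
# 'training-runs' and 'projections' are hypothetical placeholders and depend
# on your own setup; replace them with your actual results folders.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # Summarise every non-empty run found in the (assumed) results directory.
    for run in get_runs_in_dir('training-runs'):
        print(run.get_summary())

    # List the projections that have a projected_w.npz file.
    for projection in get_projections_in_dir('projections'):
        print(projection.id, projection.arrondisement, projection.street_type)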