216 lines
No EOL
6.8 KiB
Python
216 lines
No EOL
6.8 KiB
Python
import os
|
|
import datetime
|
|
import json
|
|
from typing import List
|
|
from PIL import Image
|
|
from enum import Enum
|
|
import logging
|
|
import numpy as np
|
|
import dnnlib
|
|
import legacy
|
|
|
|
# Module-wide logger; Run uses it to report parent snapshots it cannot load.
logger = logging.getLogger(name='runs')
|
|
|
|
def jsonlines(filename):
    """Lazily parse a JSON Lines file, yielding one decoded value per line.

    Blank (whitespace-only) lines are skipped, so an extra empty line at the
    end of the file does not make ``json.loads`` fail.

    Args:
        filename: path of the ``.jsonl`` file to read.

    Yields:
        The value decoded by ``json.loads`` from each non-blank line.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(filename, 'r', encoding='utf-8') as fp:
        for line in fp:
            if not line.strip():
                continue
            yield json.loads(line)
|
|
|
|
class Snapshot:
    """One evaluated snapshot of a training run.

    Built from a single record of the run's metric jsonl file.
    ``metrics["snapshot_pkl"]`` names the pickle saved at that point, e.g.
    ``network-snapshot-000400.pkl``. The number embedded in that name is
    exposed as ``iteration`` (int) and ``iteration_str`` (zero-padded text).
    """

    # Snapshot pickles are named f"{_PKL_PREFIX}{iteration_str}{_PKL_SUFFIX}".
    _PKL_PREFIX = "network-snapshot-"
    _PKL_SUFFIX = ".pkl"

    def __init__(self, run, metrics):
        """
        Args:
            run: the owning Run. Its ``as_nr``, ``resumed_from``, ``directory``
                and ``resolution`` attributes are read by this class.
            metrics: one decoded metric record. It must contain ``snapshot_pkl``;
                ``results`` and ``timestamp`` are only read by ``fid``/``time``.

        Raises:
            ValueError: if ``snapshot_pkl`` is not ``network-snapshot-<number>.pkl``.
        """
        self.run = run
        self.metrics = metrics
        # Keep the zero-padded text: file names are rebuilt from it verbatim.
        self.iteration_str = self._parse_iteration_str(metrics["snapshot_pkl"])
        self.iteration = int(self.iteration_str)

    @classmethod
    def _parse_iteration_str(cls, snapshot_pkl):
        """Return the number part of a snapshot pickle file name, as text."""
        # basename() tolerates records that store a path, not just a file name.
        name = os.path.basename(snapshot_pkl)
        if not (name.startswith(cls._PKL_PREFIX) and name.endswith(cls._PKL_SUFFIX)):
            raise ValueError(f"Unexpected snapshot file name: {snapshot_pkl!r}")
        return name[len(cls._PKL_PREFIX):-len(cls._PKL_SUFFIX)]

    @property
    def id(self):
        """Unique id of the form ``<run.as_nr>_<iteration_str>``."""
        return f"{self.run.as_nr}_{self.iteration_str}"

    @property
    def fid(self):
        """Fréchet inception distance, as calculated during training"""
        return self.metrics['results']['fid50k_full']

    @property
    def cumulative_iteration(self):
        """Iteration nr, taking into account the snapshot the run.resumed_from"""
        if self.run.resumed_from is None:
            return self.iteration
        return self.run.resumed_from.iteration + self.iteration

    @property
    def time(self):
        """Naive local-time datetime of the record's ``timestamp`` (seconds truncated)."""
        return datetime.datetime.fromtimestamp(int(self.metrics['timestamp']))

    @property
    def pkl_path(self):
        """Path of this snapshot's network pickle inside the run directory."""
        return os.path.join(
            self.run.directory,
            f"{self._PKL_PREFIX}{self.iteration_str}{self._PKL_SUFFIX}",
        )

    def load_generator(self, device):
        """Load the ``G_ema`` network from this snapshot and move it to ``device``.

        NOTE(review): snapshots are pickle files, and loading them presumably
        unpickles, which can execute arbitrary code. Only load snapshots from
        trusted run directories.
        """
        with dnnlib.util.open_url(self.pkl_path) as f:
            return legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    def get_preview_img(self, cols=1, rows=1) -> "Image.Image":
        """Crop the top-left region of this snapshot's ``fakes<iteration>.png`` grid.

        The crop box is ``run.resolution * cols`` pixels wide and
        ``run.resolution * rows`` pixels tall, starting at (0, 0).
        """
        file = os.path.join(self.run.directory, f"fakes{self.iteration_str}.png")
        # crop() loads the pixels and returns an independent image, so the file
        # handle can be released now instead of whenever the GC gets to it.
        with Image.open(file) as img:
            return img.crop((0, 0, self.run.resolution * cols, self.run.resolution * rows))
|
|
|
|
class Run:
    """A training-run output directory and the snapshots evaluated in it.

    Files used from the directory:
      * ``training_options.json``: required, parsed into ``training_options``.
      * ``metric-fid50k_full.jsonl``: optional, one Snapshot per record.
      * ``stats.jsonl``: streamed by :meth:`get_stats`.
    """

    def __init__(self, directory):
        """Load the run stored in ``directory``.

        Args:
            directory: path of the run directory. Its last path component is
                used as the run id.

        Raises:
            OSError: if ``training_options.json`` cannot be read.
            ValueError: if ``training_options.json`` is not valid JSON.
        """
        self.directory = directory
        # normpath drops a trailing separator, which would otherwise make
        # basename() (and therefore the id) an empty string.
        self.id = os.path.basename(os.path.normpath(directory))

        self.metric_path = os.path.join(self.directory, 'metric-fid50k_full.jsonl')
        self.options_path = os.path.join(self.directory, 'training_options.json')
        self.stats_path = os.path.join(self.directory, 'stats.jsonl')

        with open(self.options_path, encoding='utf-8') as fp:
            self.training_options = json.load(fp)

        # Snapshot this run was resumed from; None when training did not resume
        # or the parent snapshot could not be located.
        self.resumed_from = self._find_parent_snapshot()

        if os.path.exists(self.metric_path):
            self.snapshots = [Snapshot(self, record) for record in jsonlines(self.metric_path)]
        else:
            self.snapshots = []

    def _find_parent_snapshot(self):
        """Return the Snapshot that ``training_options['resume_pkl']`` points to.

        This is best effort. The parent is found by loading the run directory
        that contains the resume pickle and matching absolute pickle paths.
        If that directory cannot be loaded as a Run, or no snapshot matches,
        a warning is logged and None is returned.
        """
        resume_pkl = self.training_options.get('resume_pkl')
        if not resume_pkl:
            return None
        try:
            parent_run = Run(os.path.dirname(resume_pkl))
        except Exception:
            # Deliberately broad: a missing or malformed parent must not prevent
            # loading this run. The traceback is kept at debug level.
            logger.warning("Could not load parent snapshot %s", resume_pkl)
            logger.debug("Failed to load parent run of %s", self.directory, exc_info=True)
            return None
        target = os.path.abspath(resume_pkl)
        parent = next(
            (s for s in parent_run.snapshots if os.path.abspath(s.pkl_path) == target),
            None,
        )
        if parent is None:
            logger.warning("Could not load parent snapshot %s", resume_pkl)
        return parent

    @property
    def as_nr(self):
        """First five characters of the run id (presumably the zero-padded run number)."""
        return self.id[:5]

    @property
    def duration(self):
        """Time between the first and last snapshot.

        Raises:
            IndexError: if the run has no snapshots.
        """
        return self.snapshots[-1].time - self.snapshots[0].time

    @property
    def kimg_offset(self):
        """Iteration of the parent snapshot this run resumed from, else 0."""
        if self.resumed_from is None:
            return 0
        return self.resumed_from.iteration

    def get_stats(self):
        """Stream the records of ``stats.jsonl``, one decoded object per line.

        Each stat has `num` (nr. of datapoints), `mean` (mean of points) and
        `std` (std dev).
        """
        yield from jsonlines(self.stats_path)

    def is_empty(self):
        """True when no snapshot has been evaluated for this run."""
        return not self.snapshots

    @property
    def dataset_id(self):
        """Last path component of the training set path (e.g. a folder or zip name)."""
        # normpath normalises separators and drops trailing ones before basename().
        return os.path.basename(os.path.normpath(self.training_options["training_set_kwargs"]["path"]))

    def dataset_is_conditional(self):
        """Whether the training set was used with labels."""
        return bool(self.training_options["training_set_kwargs"]["use_labels"])

    @property
    def resolution(self):
        """Training resolution from ``training_set_kwargs``."""
        return self.training_options["training_set_kwargs"]["resolution"]

    @property
    def r1_gamma(self):
        """R1 regularisation weight from ``loss_kwargs``."""
        return self.training_options["loss_kwargs"]["r1_gamma"]

    def get_summary(self):
        """Return a flat dict describing the run and its latest snapshot.

        Raises:
            IndexError: if the run has no snapshots.
        """
        last = self.snapshots[-1]
        return {
            "nr": self.as_nr,
            "dataset": self.dataset_id,
            "conditional": self.dataset_is_conditional(),
            "resolution": self.resolution,
            "gamma": self.r1_gamma,
            "duration": self.duration,
            "iterations": last.iteration,
            "last_fid": last.fid,
        }
|
|
|
|
def get_runs_in_dir(dir_path, include_empty=False) -> List["Run"]:
    """Load every training run found directly inside ``dir_path``.

    Only sub-directories that contain a ``training_options.json`` file are
    treated as runs. Stray files and unrelated folders are skipped instead of
    making ``Run`` fail on them.

    Args:
        dir_path: directory whose immediate children are run directories.
        include_empty: also return runs that have no evaluated snapshot.

    Returns:
        The runs, ordered by directory name.
    """
    runs = []
    for run_dir in sorted(os.listdir(dir_path)):
        run_path = os.path.join(dir_path, run_dir)
        # Same marker file that Run reads; if run_path is a plain file this
        # check is simply False.
        if not os.path.isfile(os.path.join(run_path, 'training_options.json')):
            continue
        run = Run(run_path)
        if include_empty or not run.is_empty():
            runs.append(run)
    return runs
|
|
|
|
class StreetType(Enum):
    """Kind of street a projection belongs to.

    Projection.from_path picks the member whose value starts the second
    '-'-separated component of a projection folder name.
    """

    RUE = 'Rue'  # e.g. "...-Rue-..."
    AVENUE = 'Avenue'  # e.g. "...-Avenue-..."
    BOULEVARD = 'Boulevard'  # e.g. "...-Boulevard-..."
|
|
|
|
class Projection:
    """A latent projection result stored in its own folder.

    The folder holds ``proj.png``, ``target.png`` and ``projected_w.npz``.
    Its name encodes the location as ``<arrondissement>-<street type>...``,
    e.g. ``05-Rue-de-Rivoli``; see :meth:`from_path`.
    """

    # TODO: add snapshot and dataset
    def __init__(self, path, identifier, arrondisement: int, street_type: "StreetType"):
        self.path = path
        self.id = identifier
        # NOTE: the attribute keeps its original spelling ("arrondisement");
        # renaming it would break existing callers.
        self.arrondisement = arrondisement
        self.street_type = street_type

    @property
    def img_path(self):
        """Path of the projected (generated) image."""
        return os.path.join(self.path, 'proj.png')

    @property
    def target_img_path(self):
        """Path of the target image the projection was fitted to."""
        return os.path.join(self.path, 'target.png')

    @property
    def w_path(self):
        """Path of the ``.npz`` archive holding the projected latent."""
        return os.path.join(self.path, 'projected_w.npz')

    def load_w(self):
        """Return the array stored under key ``'w'`` in ``projected_w.npz``."""
        # The array is read while the archive is open; np.load keeps its default
        # allow_pickle=False, so object arrays are refused.
        with np.load(self.w_path) as data:
            return data['w']

    @classmethod
    def from_path(cls, path):
        """Build a Projection from a folder whose name encodes its location.

        The folder name is split on '-'. The first part must be an integer (the
        arrondissement). The second must start with a StreetType value.

        Raises:
            ValueError: if the folder name does not follow that pattern.
                ValueError subclasses Exception, so callers catching Exception
                keep working.
        """
        # normpath + basename handles trailing separators and OS-specific
        # separators (the previous '/'-split did not).
        dirname = os.path.basename(os.path.normpath(path))
        parts = dirname.split('-')
        if len(parts) < 2:
            raise ValueError(f"Unable to determine street type for {path}")
        try:
            arrondisement = int(parts[0])
        except ValueError as err:
            raise ValueError(f"Unable to determine arrondissement for {path}") from err
        street_type = next((t for t in StreetType if parts[1].startswith(t.value)), None)
        if street_type is None:
            raise ValueError(f"Unable to determine street type for {path}")
        return cls(path, dirname, arrondisement, street_type)
|
|
|
|
|
|
|
|
def get_projections_in_dir(projection_folder) -> List["Projection"]:
    """Load every projection stored directly inside ``projection_folder``.

    A sub-folder counts as a projection when it contains ``projected_w.npz``.
    Results are ordered by folder name, because ``os.listdir`` returns entries
    in arbitrary order. This matches how get_runs_in_dir orders runs.

    Errors raised by Projection.from_path for a folder name it cannot parse
    are propagated.
    """
    projection_paths = [
        os.path.join(projection_folder, name)
        for name in sorted(os.listdir(projection_folder))
        if os.path.exists(os.path.join(projection_folder, name, "projected_w.npz"))
    ]
    return [Projection.from_path(p) for p in projection_paths]