stylegan3/runs.py

216 lines
6.8 KiB
Python
Raw Normal View History

2022-11-25 18:54:23 +00:00
import os
import datetime
import json
from typing import List
from PIL import Image
from enum import Enum
import logging
import numpy as np
import dnnlib
import legacy
# Module logger, registered under the fixed name "runs" (not __name__);
# Run uses it to report parent snapshots that cannot be resolved.
logger = logging.getLogger("runs")
def jsonlines(filename):
    """Lazily parse a JSON Lines file.

    Yields one decoded JSON value per non-blank line of *filename*, in file
    order. Blank or whitespace-only lines (for example an empty line left at
    the end of the file) are skipped rather than passed to ``json.loads``,
    which would raise ``json.JSONDecodeError`` on them.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filename, 'r', encoding='utf-8') as fp:
        for line in fp:
            if line.strip():
                yield json.loads(line)
class Snapshot():
    """One network snapshot of a training run, described by one metrics record.

    ``metrics`` is one decoded line of the run's ``metric-fid50k_full.jsonl``
    file. It must contain ``snapshot_pkl`` (the snapshot pickle's file name,
    ``network-snapshot-<iteration>.pkl``). ``results['fid50k_full']`` and
    ``timestamp`` are read later, by the ``fid`` and ``time`` properties.

    Raises:
        ValueError: if ``snapshot_pkl`` is not of the form
            ``network-snapshot-<digits>.pkl``.
    """

    # Snapshot pickles are named f"{_PKL_PREFIX}{iteration}{_PKL_SUFFIX}".
    _PKL_PREFIX = "network-snapshot-"
    _PKL_SUFFIX = ".pkl"

    def __init__(self, run, metrics):
        self.run = run
        self.metrics = metrics
        # Parse the file name once. The zero-padded string form is kept because
        # the file names on disk (pickle, preview png) are built from it.
        self.iteration_str = self._parse_iteration_str(metrics["snapshot_pkl"])
        self.iteration = int(self.iteration_str)

    @classmethod
    def _parse_iteration_str(cls, snapshot_pkl):
        """Return the zero-padded iteration part of a snapshot pickle name.

        Raises:
            ValueError: if the name does not start with ``network-snapshot-``
                and end with ``.pkl``.
        """
        name = os.path.basename(snapshot_pkl)
        if not (name.startswith(cls._PKL_PREFIX) and name.endswith(cls._PKL_SUFFIX)):
            raise ValueError(f"Unexpected snapshot file name: {snapshot_pkl!r}")
        return name[len(cls._PKL_PREFIX):-len(cls._PKL_SUFFIX)]

    @property
    def id(self):
        """Identifier of the form ``"<run nr>_<zero-padded iteration>"``."""
        return f"{self.run.as_nr}_{self.iteration_str}"

    @property
    def fid(self):
        """Fréchet inception distance, as calculated during training."""
        return self.metrics['results']['fid50k_full']

    @property
    def cumulative_iteration(self):
        """Iteration number, offset by the snapshot in ``run.resumed_from`` (if any).

        NOTE(review): only the parent's own ``iteration`` is added, so for a
        chain of resumed runs the grandparent's iterations are not counted.
        Confirm whether ``resumed_from.cumulative_iteration`` is intended
        (``Run.kimg_offset`` uses the same one-level convention).
        """
        if self.run.resumed_from is None:
            return self.iteration
        return self.run.resumed_from.iteration + self.iteration

    @property
    def time(self):
        """The record's ``timestamp`` as a naive local-time datetime.

        The fractional part of the timestamp is truncated.
        """
        return datetime.datetime.fromtimestamp(int(self.metrics['timestamp']))

    @property
    def pkl_path(self):
        """Path of this snapshot's pickle inside the run directory."""
        return os.path.join(
            self.run.directory,
            f"{self._PKL_PREFIX}{self.iteration_str}{self._PKL_SUFFIX}",
        )

    def load_generator(self, device):
        """Load the ``G_ema`` network from this snapshot's pickle onto *device*."""
        with dnnlib.util.open_url(self.pkl_path) as f:
            return legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    def get_preview_img(self, cols=1, rows=1) -> "Image.Image":
        """Return the top-left ``cols`` x ``rows`` tiles of the preview grid.

        Reads ``fakes<iteration>.png`` from the run directory and crops it to
        ``resolution * cols`` by ``resolution * rows`` pixels.
        """
        file = os.path.join(self.run.directory, f"fakes{self.iteration_str}.png")
        # Use a context manager so the file handle is released promptly.
        # crop() loads the pixels and returns an independent image, so the
        # result stays valid after the source image is closed.
        with Image.open(file) as img:
            return img.crop((0, 0, self.run.resolution * cols, self.run.resolution * rows))
class Run():
    """A training run directory and the snapshots recorded in it.

    ``training_options.json`` is read eagerly. If
    ``metric-fid50k_full.jsonl`` exists, one :class:`Snapshot` is built per
    line. If the options contain a ``resume_pkl``, the run directory holding
    that pickle is loaded (recursively) so ``resumed_from`` can point at the
    parent snapshot.
    """

    def __init__(self, directory):
        self.directory = directory
        self.id = os.path.basename(directory)
        self.metric_path = os.path.join(self.directory, 'metric-fid50k_full.jsonl')
        self.options_path = os.path.join(self.directory, 'training_options.json')
        self.stats_path = os.path.join(self.directory, 'stats.jsonl')
        with open(self.options_path, encoding='utf-8') as fp:
            self.training_options = json.load(fp)

        # Parent snapshot, or None when the run started fresh or the parent
        # could not be resolved.
        self.resumed_from = None
        resume_pkl = self.training_options.get('resume_pkl')
        if resume_pkl:
            self.resumed_from = self._find_parent_snapshot(resume_pkl)

        if os.path.exists(self.metric_path):
            self.snapshots = [Snapshot(self, l) for l in jsonlines(self.metric_path)]
        else:
            self.snapshots = []

    @staticmethod
    def _find_parent_snapshot(resume_pkl):
        """Return the Snapshot whose pickle path equals *resume_pkl*, or None.

        Best-effort: any failure is logged instead of propagated. This covers
        an unloadable parent directory, or no snapshot in it matching the
        pickle.
        """
        try:
            target = os.path.abspath(resume_pkl)
            parent = Run(os.path.dirname(resume_pkl))
            # next() without a default raises StopIteration when nothing
            # matches; that is reported the same way as a load failure.
            return next(
                s for s in parent.snapshots
                if os.path.abspath(s.pkl_path) == target
            )
        except Exception:
            logger.warning("Could not load parent snapshot")
            # Keep the underlying cause available at DEBUG level.
            logger.debug("Failed to resolve resume_pkl %r", resume_pkl, exc_info=True)
            return None

    @property
    def as_nr(self):
        """First five characters of the run directory name, used as the run number."""
        return self.id[:5]

    @property
    def duration(self):
        """Time between the first and the last snapshot.

        Raises IndexError when the run has no snapshots.
        """
        return self.snapshots[-1].time - self.snapshots[0].time

    @property
    def kimg_offset(self):
        """Iteration of the parent snapshot this run resumed from, or 0."""
        if not self.resumed_from:
            return 0
        return self.resumed_from.iteration

    def get_stats(self):
        """Yield each decoded line of the run's ``stats.jsonl`` file.

        Each stats entry has `num` (nr. of datapoints), `mean` (mean of
        points) and `std` (std dev).
        """
        yield from jsonlines(self.stats_path)

    def is_empty(self):
        """True when no snapshot has been recorded for this run."""
        return not self.snapshots

    @property
    def dataset_id(self):
        """Last non-empty path component of the training set path."""
        return list(filter(None, self.training_options["training_set_kwargs"]["path"].split(os.path.sep)))[-1]

    def dataset_is_conditional(self):
        """Whether the training set was configured with ``use_labels``."""
        return bool(self.training_options["training_set_kwargs"]["use_labels"])

    @property
    def resolution(self):
        """Training set resolution from the run's options."""
        return self.training_options["training_set_kwargs"]["resolution"]

    @property
    def r1_gamma(self):
        """R1 regularization weight (``loss_kwargs.r1_gamma``) from the run's options."""
        return self.training_options["loss_kwargs"]["r1_gamma"]

    def get_summary(self):
        """Return a flat dict describing the run, based on its last snapshot.

        Raises IndexError when the run has no snapshots.
        """
        return {
            # "name": self.id,
            "nr": self.as_nr,
            "dataset": self.dataset_id,
            "conditional": self.dataset_is_conditional(),
            "resolution": self.resolution,
            "gamma": self.r1_gamma,
            "duration": self.duration,
            # "finished": self.snapshots[-1].time,
            "iterations": self.snapshots[-1].iteration,
            "last_fid": self.snapshots[-1].fid
        }
def get_runs_in_dir(dir_path, include_empty = False) -> "List[Run]":
    """Load every training run found directly inside *dir_path*.

    Only entries that contain a ``training_options.json`` file are treated as
    runs. Stray files or unrelated folders next to the runs are therefore
    skipped instead of aborting the scan.

    Args:
        dir_path: directory holding one sub-directory per run.
        include_empty: when False (default), runs without any recorded
            snapshot are left out.

    Returns:
        The loaded runs, in sorted directory-name order.
    """
    runs = []
    for name in sorted(os.listdir(dir_path)):
        run_dir = os.path.join(dir_path, name)
        # isfile() is also False when run_dir is itself a file, not a directory.
        if not os.path.isfile(os.path.join(run_dir, 'training_options.json')):
            continue
        run = Run(run_dir)
        if include_empty or not run.is_empty():
            runs.append(run)
    return runs
class StreetType(Enum):
    """Street category attached to a Projection.

    Each value is the prefix that ``Projection.from_path`` matches against the
    second dash-separated part of a projection directory name.
    """

    RUE = "Rue"
    AVENUE = "Avenue"
    BOULEVARD = "Boulevard"
class Projection():
    """A projected latent ``w`` stored in its own directory.

    The directory name encodes metadata as
    ``<arrondissement>-<street type>...`` (see :meth:`from_path`). Paths to
    the files inside the directory (``proj.png``, ``target.png``,
    ``projected_w.npz``) are exposed as properties.
    """

    # TODO: add snapshot and dataset
    def __init__(self, path, identifier, arrondisement: int, street_type: "StreetType"):
        self.path = path
        self.id = identifier
        # The spelling "arrondisement" is kept for compatibility with callers.
        self.arrondisement = arrondisement
        self.street_type = street_type

    def __repr__(self):
        return (
            f"{type(self).__name__}({self.path!r}, {self.id!r}, "
            f"{self.arrondisement!r}, {self.street_type!r})"
        )

    @property
    def img_path(self):
        """Path of ``proj.png`` inside the projection directory."""
        return os.path.join(self.path, 'proj.png')

    @property
    def target_img_path(self):
        """Path of ``target.png`` inside the projection directory."""
        return os.path.join(self.path, 'target.png')

    @property
    def w_path(self):
        """Path of ``projected_w.npz`` inside the projection directory."""
        return os.path.join(self.path, 'projected_w.npz')

    def load_w(self):
        """Load and return the ``w`` array stored in ``projected_w.npz``."""
        with np.load(self.w_path) as data:
            return data['w']

    @classmethod
    def from_path(cls, path):
        """Build a Projection from a projection directory path.

        The last '/'-separated component of *path* is used as the identifier.
        It must look like ``<arrondissement>-<street>...``, where
        ``<arrondissement>`` is an integer and ``<street>`` starts with one
        of the StreetType values. Trailing slashes are ignored.

        Raises:
            ValueError: if *path* has no name component, the arrondissement
                is not an integer, or no street part matches a StreetType.
        """
        components = [c for c in path.split('/') if c]
        if not components:
            raise ValueError(f"Unable to determine projection name for {path!r}")
        dirname = components[-1]
        parts = dirname.split('-')
        arrondisement = int(parts[0])
        # A name without a '-' has no street part at all.
        if len(parts) < 2:
            raise ValueError(f"Unable to determine street type for {path}")
        street_type = next(
            (t for t in StreetType if parts[1].startswith(t.value)), None
        )
        if street_type is None:
            raise ValueError(f"Unable to determine street type for {path}")
        return cls(path, dirname, arrondisement, street_type)
# for StreetType.
# street_type =
def get_projections_in_dir(projection_folder) -> "List[Projection]":
    """Load every projection stored directly inside *projection_folder*.

    An entry counts as a projection when it contains ``projected_w.npz``.
    Entries are visited in sorted name order, so the result is deterministic
    (``os.listdir`` order is arbitrary) and matches ``get_runs_in_dir``.
    Errors raised by ``Projection.from_path`` for unparseable directory
    names propagate to the caller.
    """
    projection_paths = [
        os.path.join(projection_folder, name)
        for name in sorted(os.listdir(projection_folder))
        if os.path.exists(os.path.join(projection_folder, name, "projected_w.npz"))
    ]
    return [Projection.from_path(p) for p in projection_paths]