# stylegan3/runs.py
import os
import datetime
import json
from typing import List, Optional
from PIL import Image
from enum import Enum
import logging
import numpy as np
import dnnlib
import legacy
from dataset_tool import open_dataset
logger = logging.getLogger('runs')
def jsonlines(filename):
# quick n dirty way to load jsonlines file
with open(filename, 'r') as fp:
for line in fp:
yield json.loads(line)
class Snapshot():
def __init__(self, run, metrics):
self.run = run
self.metrics = metrics
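        # snapshot_pkl looks like "network-snapshot-000123.pkl"; strip the
        # 17-character "network-snapshot-" prefix and the ".pkl" suffix to get
        # the zero-padded kimg count of this snapshot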
self.iteration = int(metrics["snapshot_pkl"][17:-4])
self.iteration_str = metrics["snapshot_pkl"][17:-4]
@property
def id(self):
return f"{self.run.as_nr}_{self.iteration_str}"
@property
def fid(self):
"""Fréchet inception distance, as calculated during training"""
return self.metrics['results']['fid50k_full']
@property
def cumulative_iteration(self):
"""Iteration nr, taking into account the snapshot the run.resumed_from"""
if self.run.resumed_from is None:
return self.iteration
return self.run.resumed_from.iteration + self.iteration
@property
def time(self):
return datetime.datetime.fromtimestamp(int(self.metrics['timestamp']))
@property
def pkl_path(self):
return os.path.join(self.run.directory, f"network-snapshot-{self.iteration_str}.pkl")
def load_generator(self, device):
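        """Load the exponential-moving-average generator (G_ema) from this snapshot's pickle onto the given device."""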
with dnnlib.util.open_url(self.pkl_path) as f:
return legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
    def get_preview_img(self, cols=1, rows=1) -> Image.Image:
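        """Crop a cols x rows grid of samples from the training preview image fakes<iteration>.png."""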
file = os.path.join(self.run.directory, f"fakes{self.iteration_str}.png")
img = Image.open(file)
return img.crop((0,0, self.run.resolution * cols, self.run.resolution * rows))
class Run():
def __init__(self, directory):
self.directory = directory
self.id = os.path.basename(directory)
self.metric_path = os.path.join(self.directory, 'metric-fid50k_full.jsonl')
self.options_path = os.path.join(self.directory, 'training_options.json')
self.stats_path = os.path.join(self.directory, 'stats.jsonl')
        with open(self.options_path) as fp:
self.training_options = json.load(fp)
self.resumed_from = None
if 'resume_pkl' in self.training_options:
resume_from_dir = os.path.dirname(self.training_options['resume_pkl'])
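            # find the snapshot in the parent run whose pickle matches resume_pkl,
            # so iteration counts can be chained across resumed runs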
try:
self.resumed_from = [
s for s in
Run(resume_from_dir).snapshots
if os.path.abspath(s.pkl_path) == os.path.abspath(self.training_options['resume_pkl'])
][0]
            except Exception:
                logger.warning("Could not load parent snapshot", exc_info=True)
if os.path.exists(self.metric_path):
self.snapshots = [Snapshot(self, l) for l in jsonlines(self.metric_path)]
else:
self.snapshots = []
@property
def as_nr(self):
return self.id[:5]
@property
def duration(self):
return self.snapshots[-1].time - self.snapshots[0].time
@property
def kimg_offset(self):
if not self.resumed_from:
return 0
return self.resumed_from.iteration
def get_stats(self):
"""fetch stats from stats.jsonl file
Each stats has `num` (nr. of datapoints),
`mean` (mean of points), `std` (std dev)
yields each line
"""
yield from jsonlines(self.stats_path)
def is_empty(self):
return len(self.snapshots) < 1
# def get_fids(self) -> dict:
# return {:l['results']['fid50k_full'] for l in jsonlines(self.metric_path)}
# @property
# def fakes(self):
# return sorted([f for f in os.listdir(rundir) if f.startswith('fake')])
@property
def dataset_id(self):
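        # last non-empty component of the training-set path, i.e. the dataset directory name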
return list(filter(None, self.training_options["training_set_kwargs"]["path"].split(os.path.sep)))[-1]
def dataset_is_conditional(self):
return bool(self.training_options["training_set_kwargs"]["use_labels"])
def dataset_iterator(self, max_images: Optional[int] = None):
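        """Iterate over the training images via dataset_tool.open_dataset, optionally capped at max_images."""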
max_images, iterator = open_dataset(
self.training_options["training_set_kwargs"]["path"],
max_images=max_images
)
return iterator
@property
def resolution(self):
return self.training_options["training_set_kwargs"]["resolution"]
@property
def r1_gamma(self):
return self.training_options["loss_kwargs"]["r1_gamma"]
def get_summary(self):
return {
# "name": self.id,
"nr": self.as_nr,
"dataset": self.dataset_id,
"conditional": self.dataset_is_conditional(),
"resolution": self.resolution,
"gamma": self.r1_gamma,
"duration": self.duration,
# "finished": self.snapshots[-1].time,
"iterations": self.snapshots[-1].iteration,
"last_fid": self.snapshots[-1].fid
}
def get_runs_in_dir(dir_path, include_empty = False) -> List[Run]:
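    """Load a Run for every training-run directory in dir_path, sorted by name.

    Runs without any FID snapshots are skipped unless include_empty is True.
    """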
run_dirs = sorted(os.listdir(dir_path))
runs = []
for run_dir in run_dirs:
run = Run(os.path.join(dir_path, run_dir))
if include_empty or not run.is_empty():
runs.append(run)
return runs
class StreetType(Enum):
RUE = 'Rue'
AVENUE = 'Avenue'
BOULEVARD = 'Boulevard'
class Projection():
# TODO: add snapshot and dataset
def __init__(self, path, identifier, arrondisement: int, street_type: StreetType):
self.path = path
self.id = identifier
self.arrondisement = arrondisement
self.street_type = street_type
@property
def img_path(self):
return os.path.join(self.path, 'proj.png')
@property
def target_img_path(self):
return os.path.join(self.path, 'target.png')
@property
def w_path(self):
return os.path.join(self.path, 'projected_w.npz')
def load_w(self):
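        """Load the projected `w` latent array from projected_w.npz."""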
with np.load(self.w_path) as data:
return data['w']
@classmethod
def from_path(cls, path):
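        """Build a Projection from a directory whose name encodes '<arrondissement>-<street type>...'."""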
dirname = list(filter(None, path.split('/')))[-1]
parts = dirname.split('-')
arrondisement = int(parts[0])
street_type = None
for t in StreetType:
if parts[1].startswith(t.value):
street_type = t
break
if street_type is None:
            raise ValueError(f"Unable to determine street type for {path}")
return cls(path, dirname, arrondisement, street_type)
def get_projections_in_dir(projection_folder) -> List[Projection]:
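    """Return a Projection for every subdirectory of projection_folder that contains a projected_w.npz."""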
    projection_paths = [
        os.path.join(projection_folder, p)
        for p in os.listdir(projection_folder)
        if os.path.exists(os.path.join(projection_folder, p, "projected_w.npz"))
    ]
return [Projection.from_path(p) for p in projection_paths]
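

# Example usage (a sketch, not part of the original module). The "training-runs" and
# "projections" directory names below are assumptions; point them at your own output folders.
if __name__ == "__main__":
    # summarise every non-empty training run
    for run in get_runs_in_dir("training-runs"):
        summary = run.get_summary()
        print(f"{summary['nr']}: dataset={summary['dataset']} fid={summary['last_fid']:.2f}")

    # list all projections and the shape of their projected latents
    for proj in get_projections_in_dir("projections"):
        w = proj.load_w()
        print(proj.id, proj.street_type.value, w.shape)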