267 lines
8.4 KiB
Python
267 lines
8.4 KiB
Python
import argparse, os, sys, glob
|
|
import torch
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
import streamlit as st
|
|
from streamlit import caching
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.dataloader import default_collate
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning import seed_everything
|
|
from pytorch_lightning.callbacks import Callback
|
|
from pytorch_lightning.utilities.distributed import rank_zero_only
|
|
from tqdm import tqdm
|
|
import datetime
|
|
|
|
from ldm.util import instantiate_from_config
|
|
from main import DataModuleFromConfig, ImageLogger, SingleImageLogger
|
|
|
|
def rescale(x):
    """Affinely map ``x`` with ``(x + 1) / 2`` (so -1 -> 0 and 1 -> 1).

    Works element-wise on anything supporting ``+`` and ``/`` with floats,
    e.g. Python numbers or NumPy arrays.
    """
    shifted = x + 1.
    return shifted / 2.
|
|
|
|
class DummyLogger:
    """Empty placeholder class; it defines no attributes or methods of its own."""
|
|
|
|
def bchw_to_st(x):
    """Convert a batched tensor to a channels-last NumPy array passed through ``rescale``.

    The ``(0, 2, 3, 1)`` permutation requires a 4-D input. Going by the
    function name, the input is presumably ordered (batch, channels, height,
    width), which makes the output (batch, height, width, channels).
    """
    # Detach from the autograd graph and move to host memory before NumPy conversion.
    host_array = x.detach().cpu().numpy()
    channels_last = host_array.transpose(0, 2, 3, 1)
    return rescale(channels_last)
|
|
|
|
|
|
def run(model, dsets, callbacks, logdir, split="train",
        batch_size=8, start_index=0, sample_batch=False, nowname="", use_full_data=False):
    """Drive the image-logging callbacks for one evaluation pass over ``split``.

    Two modes are supported:

    * ``use_full_data=False``: build one batch by hand and call ``log_img`` on
      every callback that is an ``ImageLogger`` instance.
    * ``use_full_data=True``: iterate the whole split with a DataLoader and call
      ``log_img`` on every callback that is a ``SingleImageLogger`` instance.

    Args:
        model: passed unchanged to each callback's ``log_img``.
        dsets: object exposing a ``datasets`` mapping from split name to dataset.
        callbacks: iterable of callback objects (filtered by type, see above).
        logdir: base output directory; output goes to ``os.path.join(logdir, nowname)``,
            which is created if needed.
        split: key into ``dsets.datasets``.
        batch_size: examples per batch in both modes.
        start_index: first index of the contiguous batch (single-batch mode,
            ``sample_batch=False``). The range is clamped to the dataset end.
        sample_batch: if True, draw ``batch_size`` random indices with
            replacement (``np.random.choice`` default) instead of a contiguous range.
        nowname: sub-directory name appended to ``logdir``.
        use_full_data: select the full-dataset mode.

    Raises:
        ValueError: if the single-batch selection would contain no examples.
    """
    logdir = os.path.join(logdir, nowname)
    os.makedirs(logdir, exist_ok=True)

    dset = dsets.datasets[split]
    print(f"Dataset size: {len(dset)}")

    if use_full_data:
        # Use this function's batch_size argument, not the global CLI options,
        # so callers control the batch size; the loader is only built here,
        # where it is actually used.
        dloader = torch.utils.data.DataLoader(dset, batch_size=batch_size,
                                              drop_last=False, shuffle=False)
        for batch in tqdm(dloader, desc="Data"):
            for cb in callbacks:
                if isinstance(cb, SingleImageLogger):
                    cb.log_img(model, batch, 0, split=split, save_dir=logdir)
        return

    if sample_batch:
        indices = np.random.choice(len(dset), batch_size)
    else:
        # Clamp so a start_index near the end yields a shorter batch rather than
        # an IndexError from dset[i].
        stop = min(start_index + batch_size, len(dset))
        indices = list(range(start_index, stop))
    if len(indices) == 0:
        raise ValueError(
            f"No examples selected (start_index={start_index}, batch_size={batch_size}, "
            f"dataset size={len(dset)})")
    print(f"Data indices: {list(indices)}")

    example = default_collate([dset[i] for i in indices])
    for cb in callbacks:
        if isinstance(cb, ImageLogger):
            print(f"logging with {cb.__class__.__name__}")
            cb.log_img(model, example, 0, split=split, save_dir=logdir)
|
|
|
|
|
|
def get_parser():
    """Build the argparse parser for this evaluation script.

    Options declared here:
      -r/--resume       log directory or checkpoint file to load
      -b/--base         extra base config paths (merged left to right)
      -c/--config       single config path; a bare ``-c`` yields ``True``
      -n/--n_iter       number of runs
      --batch_size      examples per batch
      --split           dataset split to evaluate
      --logdir          output sub-directory name
      --state_key       checkpoint entry holding the weights
      --full_data       evaluate on the whole split
      --ignore_callbacks  use only the fallback SingleImageLogger

    Returns:
        A configured ``argparse.ArgumentParser``. Options it does not declare
        are left for ``parse_known_args`` to return as extras.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Checkpoint and config selection.
    add("-r", "--resume", type=str, nargs="?",
        help="load from logdir or checkpoint in logdir")
    add("-b", "--base", nargs="*", metavar="base_config.yaml",
        help=("paths to base configs. Loaded from left-to-right. "
              "Parameters can be overwritten or added with command-line options of the form `--key value`."),
        default=[])
    add("-c", "--config", nargs="?", metavar="single_config.yaml",
        help=("path to single config. If specified, base configs will be ignored "
              "(except for the last one if left unspecified)."),
        const=True, default="")

    # Run parameters.
    add("-n", "--n_iter", type=int, default=1, help="how many times to run")
    add("--batch_size", type=int, default=4, help="how many examples in the batch")
    add("--split", type=str, default="validation", help="evaluate on this split")
    add("--logdir", type=str, default="eval_logs", help="where to save the logs")
    add("--state_key", type=str, default="state_dict",
        choices=["state_dict", "model_ema", "model"],
        help="where to access the model weights")

    # Boolean switches.
    add("--full_data", action='store_true', help="evaluate on full dataset")
    add("--ignore_callbacks", action='store_true',
        help="ignores all callbacks in the config and only uses main.SingleImageLogger")

    return parser
|
|
|
|
|
|
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and optionally load a state dict into it.

    Args:
        config: config node handed to ``instantiate_from_config``.
        sd: state dict to load, or None to keep the freshly instantiated weights.
        gpu: if True, move the model to the GPU with ``model.cuda()``.
        eval_mode: if True, call ``model.eval()``.

    Returns:
        dict with a single key ``"model"`` holding the model.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        # Only announce loading when there is actually something to load;
        # previously this message was printed even for sd=None.
        print("loading model from state-dict...")
        # NOTE(review): load_state_dict defaults to strict=True, which raises on
        # missing/unexpected keys, so the reports below are effectively never
        # non-empty. Confirm whether strict=False was intended.
        missing, unexpected = model.load_state_dict(sd)
        if missing:
            print(f"missing keys: \n {missing}")
        if unexpected:
            print(f"unexpected keys: \n {unexpected}")
        print("loaded model.")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
|
|
|
|
|
|
def get_data(config):
    """Instantiate the data object described by ``config.data`` and prepare it.

    Calls ``prepare_data()`` and then ``setup()`` on the instantiated object,
    then returns it.
    """
    data_module = instantiate_from_config(config.data)
    data_module.prepare_data()
    data_module.setup()
    return data_module
|
|
|
|
|
|
def get_callbacks(lightning_config, ignore_callbacks=False, max_images=None):
    """Instantiate the callbacks configured under ``lightning_config.callbacks``.

    When no callbacks are configured, or ``ignore_callbacks`` is True, a single
    ``SingleImageLogger`` is returned instead. That logger uses
    ``callbacks.image_logger.params.log_images_kwargs`` from the config when
    the key is present.

    Args:
        lightning_config: config node whose optional ``callbacks`` entry maps
            names to configs accepted by ``instantiate_from_config``.
        ignore_callbacks: if True, the configured callbacks are neither
            instantiated nor returned.
        max_images: ``max_images`` for the fallback logger. When None, falls
            back to the module-level ``opt.batch_size`` parsed in ``__main__``
            (the original behaviour).

    Returns:
        list of callback objects.
    """
    # Tolerate a config without a `callbacks` section instead of failing on it.
    callbacks_cfg = lightning_config.get("callbacks") or {}

    callbacks = []
    if not ignore_callbacks:
        # Only instantiate when the callbacks will be used; construction may
        # have side effects.
        callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
        print("found and instantiated the following callback(s):")
        for cb in callbacks:
            print(f" > {cb.__class__.__name__}")
        print()
    if callbacks:
        return callbacks

    if ignore_callbacks:
        print("Ignoring configured callbacks. Falling back to SingleImageLogger as a default")
    else:
        print("No callbacks found. Falling back to SingleImageLogger as a default")

    if max_images is None:
        max_images = opt.batch_size

    # select() returns None for a missing path instead of raising, which
    # replaces the previous bare `except:` around attribute access.
    log_images_kwargs = OmegaConf.select(
        lightning_config, "callbacks.image_logger.params.log_images_kwargs")
    if log_images_kwargs is None:
        print("No log_images_kwargs specified. Using SingleImageLogger with default values in log_images().")
        return [SingleImageLogger(1, max_images=max_images, log_always=True)]
    return [SingleImageLogger(1, max_images=max_images, log_always=True,
                              log_images_kwargs=log_images_kwargs)]
|
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode, state_key=None):
    """Build the data object and the model, loading weights from ``ckpt`` if given.

    Wrapped in ``st.cache(allow_output_mutation=True)``.

    Args:
        config: full config; ``config.data`` is used by ``get_data`` and
            ``config.model`` by ``load_model_from_config``.
        ckpt: path passed to ``torch.load``, or a falsy value to skip loading weights.
        gpu: forwarded to ``load_model_from_config``.
        eval_mode: forwarded to ``load_model_from_config``.
        state_key: checkpoint entry holding the weights. When None, falls back
            to the module-level ``opt.state_key`` (the original behaviour).

    Returns:
        ``(dsets, model, global_step)``. ``global_step`` is the checkpoint's
        ``"global_step"`` entry (0 if absent), or None when no ckpt is given.

    Raises:
        KeyError: if a checkpoint is given but lacks ``state_key``.
    """
    if state_key is None:
        state_key = opt.state_key

    dsets = get_data(config)

    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd.get("global_step", 0)
        state_dict = pl_sd[state_key]
    else:
        # No checkpoint means no weights to load, regardless of state_key.
        # Previously a non-default state_key raised KeyError here.
        global_step = None
        state_dict = None

    model = load_model_from_config(config.model,
                                   state_dict,
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return dsets, model, global_step
|
|
|
|
|
|
def exists(x):
    """Return True for any value except None.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    missing = x is None
    return not missing
|
|
|
|
|
|
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    if not st._is_running_with_streamlit:
        # Outside `streamlit run`, send st.info/st.write output to stdout.
        print("Not running with streamlit. Redefining st functions...")
        st.info = print
        st.write = print

    seed_everything(42)

    parser = get_parser()
    # Unrecognised CLI args are kept and applied as OmegaConf dot-list overrides below.
    opt, unknown = parser.parse_known_args()

    ckpt = None
    # Validate explicitly: `assert` is stripped when Python runs with -O.
    if not opt.resume:
        parser.error("-r/--resume is required (a logdir or a checkpoint inside one)")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if os.path.isfile(opt.resume):
        paths = opt.resume.split("/")
        try:
            # Keep everything up to and including the directory right after "logs".
            idx = len(paths) - paths[::-1].index("logs") + 1
        except ValueError:
            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
        logdir = "/".join(paths[:idx])
        ckpt = opt.resume
    else:
        if not os.path.isdir(opt.resume):
            raise ValueError(opt.resume)
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
    opt.base = base_configs + opt.base

    if opt.config:
        # A bare `-c` yields True (argparse const), meaning "use the last base config".
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    lightning_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-lightning.yaml")))
    lightning_configs = [OmegaConf.load(lcfg) for lcfg in lightning_configs]
    lightning_config = OmegaConf.merge(*lightning_configs, cli)

    print(f"ckpt-path: {ckpt}")
    print(config)
    print(lightning_config)

    gpu = True
    eval_mode = True

    callbacks = get_callbacks(lightning_config.lightning, ignore_callbacks=opt.ignore_callbacks)

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")

    # global_step is None when no checkpoint was loaded; formatting None with
    # `:09` would raise TypeError.
    step_dir = f"{global_step:09}" if exists(global_step) else "no_ckpt"
    logdir = os.path.join(logdir, opt.logdir, step_dir)
    print(f"logging to {logdir}")
    os.makedirs(logdir, exist_ok=True)

    # go
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    for n in range(opt.n_iter):
        nowname = f"{now}_iteration-{n:03}"
        run(model, dsets, callbacks, logdir=logdir, batch_size=opt.batch_size, nowname=nowname,
            split=opt.split, use_full_data=opt.full_data)
|