keep sampling
This commit is contained in:
parent a8dcade961
commit 0948a3f89c
1 changed file with 205 additions and 0 deletions

scripts/checker.py  205  Normal file
@@ -0,0 +1,205 @@
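# checker.py: watch a training run's checkpoint and re-sample a fixed set of
# prompts whenever it changes, saving image grids with and without the EMA weights.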
import os
import glob
import subprocess
import time
import fire

import numpy as np
from tqdm import tqdm
import torch
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from einops import rearrange
from torchvision.utils import make_grid
from PIL import Image
import contextlib


def load_model_from_config(config, ckpt, verbose=False):
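    """Load the checkpoint, instantiate the model from the config, and
    return it together with the stored global step."""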
    pl_sd = torch.load(ckpt, map_location="cpu")
    gs = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=True)
    model.cuda()
    model.eval()
    return model, gs


def read_prompts(path):
    with open(path, "r") as f:
        prompts = f.read().splitlines()
    return prompts


def split_in_batches(iterator, n):
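    """Yield lists of at most n consecutive elements from iterator."""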
    out = []
    for elem in iterator:
        out.append(elem)
        if len(out) == n:
            yield out
            out = []
    if len(out) > 0:
        yield out


class Sampler(object):
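    """Sample a fixed list of prompts from a restored checkpoint and save
    the results as image grids."""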
    def __init__(self, out_dir, ckpt_path, cfg_path, prompts_path, shape, seed=42):
        self.out_dir = out_dir
        self.ckpt_path = ckpt_path
        self.cfg_path = cfg_path
        self.prompts_path = prompts_path
        self.seed = seed

        self.batch_size = 1
        self.scale = 10
        self.shape = shape
        self.n_steps = 100
        self.nrow = 8


    @torch.inference_mode()
    def sample(self, model, prompts, ema=True):
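        """Sample every prompt batch with PLMS and classifier-free guidance;
        with ema=True, sampling runs under the model's EMA weights."""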
        seed = self.seed
        batch_size = self.batch_size
        scale = self.scale
        n_steps = self.n_steps

        shape = self.shape

        print("Sampling model.")
        print("ckpt_path", self.ckpt_path)
        print("cfg_path", self.cfg_path)
        print("prompts_path", self.prompts_path)
        print("out_dir", self.out_dir)
        print("seed", self.seed)
        print("batch_size", batch_size)
        print("scale", scale)
        print("n_steps", n_steps)
        print("shape", shape)

        prompts = list(split_in_batches(prompts, batch_size))

        sampler = PLMSSampler(model)
        all_samples = list()

        # swap in the EMA weights for the duration of sampling if requested
        ctxt = model.ema_scope if ema else contextlib.nullcontext

        with ctxt():
            for prompts_batch in tqdm(prompts, desc="prompts"):
                uc = None
                if scale != 1.0:
                    # empty-prompt conditioning for classifier-free guidance
                    uc = model.get_learned_conditioning(batch_size * [""])
                c = model.get_learned_conditioning(prompts_batch)

                # re-seed before every batch so each batch starts from the same RNG state
                seed_everything(seed)

                samples_latent, _ = sampler.sample(
                    S=n_steps,
                    conditioning=c,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc,
                    eta=0.0,
                    dynamic_threshold=None,
                )

                # decode latents to image space and map from [-1, 1] to [0, 1]
                samples = model.decode_first_stage(samples_latent)
                samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)

                all_samples.append(samples)

        all_samples = torch.cat(all_samples, 0)
        return all_samples


    @torch.inference_mode()
    def __call__(self):
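        """Restore the model and sample the prompts once with and once
        without EMA weights, saving one grid per pass."""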
        config = OmegaConf.load(self.cfg_path)
        model, global_step = load_model_from_config(config, self.ckpt_path)
        print(f"Restored model at global step {global_step}.")

        prompts = read_prompts(self.prompts_path)

        all_samples = self.sample(model, prompts, ema=True)
        self.save_as_grid("grid_with_wings", all_samples, global_step)
        all_samples = self.sample(model, prompts, ema=False)
        self.save_as_grid("grid_without_wings", all_samples, global_step)


    def save_as_grid(self, name, grid, global_step):
        grid = make_grid(grid, nrow=self.nrow)
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()

        os.makedirs(self.out_dir, exist_ok=True)
        filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
            name,
            global_step,
            0,
            0,
        )
        grid_path = os.path.join(self.out_dir, filename)
        Image.fromarray(grid.astype(np.uint8)).save(grid_path)
        print(f"---> {grid_path}")


class Checker(object):
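    """Poll a file's modification time and invoke the callback whenever it changes."""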
    def __init__(self, ckpt_path, callback, interval=60):
        self._cached_stamp = 0
        self.filename = ckpt_path
        self.callback = callback
        self.interval = interval

    def check(self):
        while True:
            stamp = os.stat(self.filename).st_mtime
            if stamp != self._cached_stamp:
                self._cached_stamp = stamp
                # file has changed, so do something...
                print(f"{self.__class__.__name__}: Detected a new file at "
                      f"{self.filename}, calling back.")
                self.callback()
            else:
                time.sleep(self.interval)


def run(prompts_path="scripts/prompts/prompts-with-wings.txt",
        watch_log_dir=None, out_dir=None, ckpt_path=None, cfg_path=None,
        H=256,
        W=None,
        C=4,
        F=8,
        interval=60):
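    """Derive any missing paths from watch_log_dir, build a Sampler, and start
    watching the checkpoint file."""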

    if out_dir is None:
        assert watch_log_dir is not None
        out_dir = os.path.join(watch_log_dir, "images/checker")

    if ckpt_path is None:
        assert watch_log_dir is not None
        ckpt_path = os.path.join(watch_log_dir, "checkpoints/last.ckpt")

    if cfg_path is None:
        assert watch_log_dir is not None
        configs = glob.glob(os.path.join(watch_log_dir, "configs/*-project.yaml"))
        cfg_path = sorted(configs)[-1]

    if W is None:
        assert H is not None
        W = H
    if H is None:
        assert W is not None
        H = W
    # latent-space sample shape: C channels, spatial size downscaled by factor F
    shape = [C, H//F, W//F]
    sampler = Sampler(out_dir, ckpt_path, cfg_path, prompts_path, shape=shape)

    checker = Checker(ckpt_path, sampler, interval=interval)
    checker.check()


if __name__ == "__main__":
    fire.Fire(run)
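
# Example invocation (hypothetical log directory; python-fire exposes run()'s
# keyword arguments as command-line flags):
#   python scripts/checker.py --watch_log_dir logs/2022-01-01T00-00-00_example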