diff --git a/assets/samples/grid-0001.png b/assets/samples/grid-0001.png new file mode 100644 index 0000000..0acf3e9 Binary files /dev/null and b/assets/samples/grid-0001.png differ diff --git a/assets/samples/grid-0006.png b/assets/samples/grid-0006.png new file mode 100644 index 0000000..8ce3a1d Binary files /dev/null and b/assets/samples/grid-0006.png differ diff --git a/assets/samples/grid-0007.png b/assets/samples/grid-0007.png new file mode 100644 index 0000000..f3dc514 Binary files /dev/null and b/assets/samples/grid-0007.png differ diff --git a/assets/samples/grid-0008.png b/assets/samples/grid-0008.png new file mode 100644 index 0000000..ff52403 Binary files /dev/null and b/assets/samples/grid-0008.png differ diff --git a/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml b/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml new file mode 100644 index 0000000..d6420a9 --- /dev/null +++ b/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml @@ -0,0 +1,131 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.001 + linear_end: 0.015 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 16 + channels: 16 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.22765929 # magic number + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 16 # not really needed + in_channels: 16 + out_channels: 16 + model_channels: 320 # TODO: scale model here + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ckpt_path: "models/first_stage_models/kl-f16/model.ckpt" + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/" + batch_size: 50 # TODO: max out + num_workers: 4 + multinode: True + train: + shards: '{000000..231317}.tar -' + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 256 + + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: '{231318..231349}.tar -' + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 3.0 + unconditional_guidance_label: [""] + + trainer: + replace_sampler_ddp: False # TODO: check this + benchmark: True + val_check_interval: 5000000 # really sorry + num_sanity_val_steps: 0 + accumulate_grad_batches: 2 # TODO: want accumulate on? 
--> wait for final batch-size diff --git a/ldm/data/dummy.py b/ldm/data/dummy.py index be295a1..3b74a77 100644 --- a/ldm/data/dummy.py +++ b/ldm/data/dummy.py @@ -16,3 +16,19 @@ class DummyData(Dataset): letters = string.ascii_lowercase y = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) return {"jpg": x, "txt": y} + + +class DummyDataWithEmbeddings(Dataset): + def __init__(self, length, size, emb_size): + self.length = length + self.size = size + self.emb_size = emb_size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + y = np.random.randn(*self.emb_size).astype(np.float32) + return {"jpg": x, "txt": y} + diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index edf1eaf..7d6cb48 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -4,6 +4,7 @@ import torch import numpy as np from tqdm import tqdm from functools import partial +from einops import rearrange from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like @@ -73,8 +74,8 @@ class DDIMSampler(object): x_T=None, log_every_t=100, unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, **kwargs ): if conditioning is not None: @@ -106,6 +107,7 @@ class DDIMSampler(object): log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, ) return samples, intermediates @@ -115,7 +117,7 @@ class DDIMSampler(object): callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -150,7 +152,8 @@ class DDIMSampler(object): noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -164,7 +167,8 @@ class DDIMSampler(object): @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None): + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): b, *_, device = *x.shape, x.device if unconditional_conditioning is None or unconditional_guidance_scale == 1.: @@ -194,6 +198,33 @@ class DDIMSampler(object): pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + # renorm + pred_max = pred_x0.max() + pred_min = pred_x0.min() + pred_x0 = 
(pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1 + pred_x0 = 2*pred_x0 - 1. # -1 ... 1 + + s = torch.quantile( + rearrange(pred_x0, 'b ... -> b (...)').abs(), + dynamic_threshold, + dim=-1 + ) + s.clamp_(min=1.0) + s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) + + # clip by threshold + #pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max + + # temporary hack: numpy on cpu + pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() + pred_x0 = torch.tensor(pred_x0).to(self.model.device) + + # re.renorm + pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 + pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range + # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index a62b8e7..68260c3 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -12,6 +12,10 @@ class AbstractEncoder(nn.Module): def encode(self, *args, **kwargs): raise NotImplementedError +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x class ClassEmbedder(nn.Module): diff --git a/scripts/prompts/prompts-with-wings.txt b/scripts/prompts/prompts-with-wings.txt new file mode 100644 index 0000000..7da28b2 --- /dev/null +++ b/scripts/prompts/prompts-with-wings.txt @@ -0,0 +1,64 @@ +A portrait of Abraham Lincoln +A portrait of Barack Obama +A portrait of a nekomimi girl smiling +a portrait of isaac newton the alchemist +A portrait of Friedrich Nietzsche wearing an open double breasted suit with a bowtie +Portrait of a cyberpunk cyborg man wearing alternate reality goggles +Portrait of a woman screaming +A portrait of a man in a flight jacket leaning against a biplane +a cold landscape by Albert Bierstadt +the monument of the ancients by van gogh +the universal library +a vision of paradise. unreal engine +matte painting of cozy underground bunker wholefoods aisle, trending on artstation +illustration of wooly mammoths reclaiming the arctic, trending on artstation +a mountain range in the desert, Provia, Velvia +the gateway between dreams, trending on ArtStation +a cityscape at night +starry night by cyberpunk +A fantasy painting of a city in a deep valley by Ivan Aivazovsky +An oil painting of The New York City Skyline by Natalia Goncharova +a rainy city street in the style of cyberpunk noir, trending on ArtStation +an astral city in the style of cyberpunk noir art deco +The Golden Gate Bridge in the style of art deco +a city on a 70s science fiction novel cover +An oil painting of A Vase Of Flowers +still life oil painting of a smooth silver steel tungsten square cube box by Albrecht Dürer +An oil painting of a bookshelf crammed with books, trending on artstation +An N95 respirator mask in the style of art deco +a surreal and organic stone monument to a plutonium atom +oil painting of a candy dish of glass candies, mints, and other assorted sweets +illustration of a ford model-t in pristine condition, trending on artstation +illustration of DEC minicomputer console monitor retrocomputing teletype interdata PDP-11 univac, trending on artstation +The Rise Of Consciousness +The Human Utility Function +Revolution of the Souls +a good amphetamine spirit +Control The Soul +The Lunatic, The Lover, and The Poet +A Planet Ruled By Angels +the Tower of Babel by J.M.W. 
Turner +sketch of a 3D printer by Leonardo da Vinci +In The Style Of M.C. Escher +A cup of coffee by Picasso +The US Capitol Building in the style of Kandinsky +A Mysterious Orb by Andy Warhol +The everlasting zero, a glimpse of a million, by Salvador Dali +a painting of a haunted house with Halloween decorations by Giovanni Paolo Panini +a painting of drops of Venus by Vincent van Gogh +ascii art of a man riding a bicycle +cyberpunk noir art deco detective in space +a cyborg angel in the style of ukiyo-e +Hell in the style of pointillism +Moloch in the style of socialist realism +Metaphysics in the style of WPAP +advertisement for a psychedelic virtual reality headset, 16 bit sprite pixel art +a watercolor painting of a Christmas tree +control room monitors televisions screens computers hacker lab, concept art, matte painting, trending on artstation +a group of surgeons wait to cryonically suspend a patient +technological singularity cult by James Gurney +an autogyro flying car, trending on artstation +illustration of airship zepplins in the skies, trending on artstation +watercolor illustration of a martian colony geodesic dome aquaponics farming on the surface, trending on artstation +humanity is killed by AI, by James Gurney +the Vitruvian Man as a propaganda poster for transhumanism \ No newline at end of file diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 613de5e..ea94f86 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -4,6 +4,7 @@ import numpy as np from omegaconf import OmegaConf from PIL import Image from tqdm import tqdm, trange +from itertools import islice from einops import rearrange from torchvision.utils import make_grid @@ -12,6 +13,11 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") @@ -51,7 +57,7 @@ if __name__ == "__main__": parser.add_argument( "--ddim_steps", type=int, - default=200, + default=50, help="number of ddim sampling steps", ) @@ -91,8 +97,8 @@ if __name__ == "__main__": parser.add_argument( "--n_samples", type=int, - default=4, - help="how many samples to produce for the given prompt", + default=8, + help="how many samples to produce for each given prompt. 
A.k.a batch size", ) parser.add_argument( @@ -101,11 +107,35 @@ if __name__ == "__main__": default=5.0, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", ) + + parser.add_argument( + "--dyn", + type=float, + help="dynamic thresholding from Imagen, in latent space (TODO: try in pixel space with intermediate decode)", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="logs/f8-kl-clip-encoder-256x256-run1/configs/2022-06-01T22-11-40-project.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt", + help="path to checkpoint of model", + ) opt = parser.parse_args() - config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic - model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt") # TODO: check path + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) @@ -118,48 +148,62 @@ if __name__ == "__main__": os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir - prompt = opt.prompt + batch_size = opt.n_samples + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + - all_samples=list() with torch.no_grad(): with model.ema_scope(): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(opt.n_samples * [""]) for n in trange(opt.n_iter, desc="Sampling"): - c = model.get_learned_conditioning(opt.n_samples * [prompt]) - shape = [4, opt.H//8, opt.W//8] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta) + all_samples = list() + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + c = model.get_learned_conditioning(prompts) + shape = [4, opt.H//8, opt.W//8] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + dynamic_threshold=opt.dyn) - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) - for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png")) - base_count += 1 - all_samples.append(x_samples_ddim) + for x_sample in x_samples_ddim: + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=opt.n_samples) + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=opt.n_samples) - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png')) + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 - print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.") + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/scripts/vqgan_codebook_visualizer.py b/scripts/vqgan_codebook_visualizer.py new file mode 100644 index 0000000..24e4986 --- /dev/null +++ b/scripts/vqgan_codebook_visualizer.py @@ -0,0 +1,542 @@ +import argparse, os, sys, glob +import torch +import numpy as np +import streamlit as st +import matplotlib.pyplot as plt +from omegaconf import OmegaConf +from PIL import Image +from main import instantiate_from_config, DataModuleFromConfig +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate +from einops import rearrange +from torchvision.utils import make_grid + + +rescale = lambda x: (x + 1.) / 2. + + +def bchw_to_st(x): + return rescale(x.detach().cpu().numpy().transpose(0,2,3,1)) + +def chw_to_st(x): + return rescale(x.detach().cpu().numpy().transpose(1,2,0)) + + +def custom_get_input(batch, key): + inputs = batch[key].permute(0, 3, 1, 2) + return inputs + + +def vq_no_codebook_forward(model, x): + h = model.encoder(x) + h = model.quant_conv(h) + h = model.post_quant_conv(h) + xrec = model.decoder(h) + return xrec + + +def save_img(x, fname): + I = (x.clip(0, 1) * 255).astype(np.uint8) + Image.fromarray(I).save(fname) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-c", + "--config", + nargs="?", + metavar="single_config.yaml", + help="path to single config. 
If specified, base configs will be ignored " + "(except for the last one if left unspecified).", + const=True, + default="", + ) + parser.add_argument( + "--dataset_config", + type=str, + nargs="?", + default="", + help="path to dataset config" + ) + return parser + + +def load_model_from_config(config, sd, gpu=True, eval_mode=True): + model = instantiate_from_config(config) + if sd is not None: + model.load_state_dict(sd) + if gpu: + model.cuda() + if eval_mode: + model.eval() + return {"model": model} + + +def get_data(config): + # get data + data = instantiate_from_config(config.data) + data.prepare_data() + data.setup() + return data + + +@st.cache(allow_output_mutation=True) +def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False): + # get data + if delete_dataset_params: + st.info("Deleting dataset parameters.") + del config["data"]["params"]["train"]["params"] + del config["data"]["params"]["validation"]["params"] + + dsets = get_data(config) # calls data.config ... + + # now load the specified checkpoint + if ckpt: + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"], + gpu=gpu, + eval_mode=eval_mode)["model"] + return dsets, model, global_step + + +@torch.no_grad() +def get_image_embeddings(model, dset, used_codebook, used_indices): + import plotly.graph_objects as go + batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4) + start_index = st.number_input("Start index", value=0, + min_value=0, + max_value=len(dset) - batch_size) + if st.sidebar.button("Sample Batch"): + indices = np.random.choice(len(dset), batch_size) + else: + indices = list(range(start_index, start_index + batch_size)) + + st.write(f"Indices: {indices}") + batch = default_collate([dset[i] for i in indices]) + x = model.get_input(batch, "image") + x = x.to(model.device) + + # get reconstruction from non-quantized and quantized, compare + z_pre_quant = model.encode_to_prequant(x) + z_quant, emb_loss, info = model.quantize(z_pre_quant) + indices = info[2].detach().cpu().numpy() + #indices = rearrange(indices, '(b d) -> b d', b=batch_size) + unique_indices = np.unique(indices) + st.write(f"Unique indices in batch: {unique_indices.shape[0]}") + + x1 = used_codebook[:, 0].cpu().numpy() + x2 = used_codebook[:, 1].cpu().numpy() + x3 = used_codebook[:, 2].cpu().numpy() + + zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy() + zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy() + zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy() + + zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy() + zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy() + zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy() + + fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers', marker=dict(size=1.4, line=dict(width=1., + color="Blue")), + name="All Used Codebook Entries"), + + ]) + trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers', marker=dict(size=1., line=dict(width=1., color=indices)), + name="Pre-Quant Codes") + trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers', + marker=dict(size=2., line=dict(width=10., color=indices)), name="Quantized Codes") + + fig.add_trace(trace2) + fig.add_trace(trace3) + + fig.update_layout( + autosize=False, + width=1000, + height=1000, + ) + + 
x_rec_no_quant = model.decode(z_pre_quant) + x_rec_quant = model.decode(z_quant) + delta_x = x_rec_no_quant - x_rec_quant + + st.text("Fitting Gaussian...") + h, w = z_quant.shape[2], z_quant.shape[3] + from sklearn.mixture import GaussianMixture + gaussian = GaussianMixture(n_components=1) + gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy()) + samples, _ = gaussian.sample(n_samples=batch_size*h*w) + samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3) + samples = rearrange(samples, 'b h w c -> b c h w') + samples = torch.tensor(samples).to(z_quant) + samples, _, _ = model.quantize(samples) + x_sample = model.decode(samples) + + all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample]) # 5 b 3 H W + + all_img = rearrange(all_img, 'n b c h w -> b n c h w') + all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') + grid = make_grid(all_img, nrow=5) + + st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **") + st.image(chw_to_st(grid), clamp=True, output_format="PNG") + + st.write(fig) + # 2d projections + import matplotlib.pyplot as plt + pairs = [(1, 0), (2, 0), (2, 1)] + fig2, ax = plt.subplots(1, 3, figsize=(21, 7)) + for d in range(3): + d1, d2 = pairs[d] + #ax[d].scatter(used_codebook[:, d1].cpu().numpy(), + # used_codebook[:, d2].cpu().numpy(), + # label="All Used Codebook Entries", s=10.0, c=used_indices) + ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(), + rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(), + label="Quantized Codes", alpha=0.9, s=8.0, c=indices) + ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(), + rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(), + label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices) + ax[d].set_title(f"dim {d2} vs dim {d1}") + ax[d].legend() + st.write(fig2) + + plt.close() + + # from scipy.spatial import Voronoi, voronoi_plot_2d + # fig3 = plt.figure(figsize=(10,10)) + # points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy() + # vor = Voronoi(points) + # # plot + # voronoi_plot_2d(vor) + # # colorize + # for region in vor.regions: + # if not -1 in region: + # polygon = [vor.vertices[i] for i in region] + # plt.fill(*zip(*polygon)) + # plt.savefig("voronoi_test.png") + # st.write(fig3) + + +@torch.no_grad() +def get_used_indices(model, dset, batch_size=20): + dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False) + data = list() + info = st.empty() + for i, batch in enumerate(dloader): + x = model.get_input(batch, "image") + x = x.to(model.device) + zq, _, zi = model.encode(x) + indices = zi[2] + indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy() + data.append(indices) + + unique = np.unique(data) + info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}") + + unique = np.unique(data) + #np.save(outpath, unique) + st.write(f"end of data: found **{unique.size} unique indices.**") + print(f"end of data: found {unique.size} unique indices.") + return unique + + +def visualize3d(codebook, used_indices): + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + codebook = codebook.cpu().numpy() + selected_codebook = codebook[used_indices, :] + z_dim = codebook.shape[1] + assert z_dim == 3 + + pairs = [(1,0), (2,0), (2,1)] + fig, ax = plt.subplots(1,3, figsize=(15,5)) + for d in range(3): + d1, d2 = pairs[d] + 
ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2]) + ax[d].set_title(f"dim {d2} vs dim {d1}") + st.write(fig) + + # # plot 3D + # fig = plt.figure(1) + # ax = Axes3D(fig) + # ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries") + # ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries") + # plt.legend() + # #st.write(fig) + # st.pyplot(fig) + + # plot histogram of vector norms + fig = plt.figure(2, figsize=(6,5)) + norms = np.linalg.norm(selected_codebook, axis=1) + plt.hist(norms, bins=100, edgecolor="black", lw=1.1) + plt.title("Distribution of norms of used codebook entries") + st.write(fig) + + # plot 3D with plotly + import pandas as pd + import plotly.graph_objects as go + x = selected_codebook[:, 0] + y = selected_codebook[:, 1] + z = selected_codebook[:, 2] + fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2., line=dict(width=1., + color="Blue")) + ) + ] + ) + + fig.update_layout( + autosize=False, + width=1000, + height=1000, + ) + st.write(fig) + + +@torch.no_grad() +def get_fixed_points(model, dset): + n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25) + batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4) + start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) + clip_decoded = st.checkbox("Clip decoded image", False) + quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False) + factor = st.sidebar.number_input("image size", value=1., min_value=0.1) + if st.sidebar.button("Sample Batch"): + indices = np.random.choice(len(dset), batch_size) + else: + indices = list(range(start_index, start_index + batch_size)) + + st.write(f"Indices: {indices}") + batch = default_collate([dset[i] for i in indices]) + x = model.get_input(batch, "image") + x = x.to(model.device) + + progress = st.empty() + progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}") + + image_progress = st.empty() # TODO + + input = x + img_quant = x + img_noquant = x + delta_img = img_quant - img_noquant + + st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **") + + def display(input, img_quant, img_noquant, delta_img): + all_img = torch.stack([input, img_quant, img_noquant, delta_img]) # 4 b 3 H W + all_img = rearrange(all_img, 'n b c h w -> b n c h w') + all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') + grid = make_grid(all_img, nrow=4) + image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2])) + + display(input, img_quant, img_noquant, delta_img) + for n in range(n_iter): + # get reconstruction from non-quantized and quantized, compare via iteration + + # quantized_stream + z_pre_quant = model.encode_to_prequant(img_quant) + z_quant, emb_loss, info = model.quantize(z_pre_quant) + + # non_quantized stream + z_noquant = model.encode_to_prequant(img_noquant) + + img_quant = model.decode(z_quant) + img_noquant = model.decode(z_noquant) + if clip_decoded: + img_quant = torch.clamp(img_quant, -1., 1.) + img_noquant = torch.clamp(img_noquant, -1., 1.) + if quantize_decoded: + device = img_quant.device + img_quant = (2*torch.Tensor(((img_quant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device) + img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. 
- 1.).to(device) + delta_img = img_quant - img_noquant + display(input, img_quant, img_noquant, delta_img) + progress_cb(n + 1) + + +@torch.no_grad() +def get_fixed_points_kl_ae(model, dset): + n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25) + batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4) + start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) + clip_decoded = st.checkbox("Clip decoded image", False) + quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False) + sample_posterior = st.checkbox("Sample from encoder posterior", False) + factor = st.sidebar.number_input("image size", value=1., min_value=0.1) + if st.sidebar.button("Sample Batch"): + indices = np.random.choice(len(dset), batch_size) + else: + indices = list(range(start_index, start_index + batch_size)) + + st.write(f"Indices: {indices}") + batch = default_collate([dset[i] for i in indices]) + x = model.get_input(batch, "image") + x = x.to(model.device) + + progress = st.empty() + progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}") + + st.write("** Input | Rec. (no quant) | Delta(input, iter_rec) **") + image_progress = st.empty() + + input = x + img_noquant = x + delta_img = input - img_noquant + + def display(input, img_noquant, delta_img): + all_img = torch.stack([input, img_noquant, delta_img]) # 3 b 3 H W + all_img = rearrange(all_img, 'n b c h w -> b n c h w') + all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') + grid = make_grid(all_img, nrow=3) + image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2])) + + fig, ax = plt.subplots() + + distribution_progress = st.empty() + def display_latent_distribution(latent_z, alpha=1., title=""): + flatz = latent_z.reshape(-1).cpu().detach().numpy() + #fig, ax = plt.subplots() + ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black") + ax.set_title(title) + distribution_progress.pyplot(fig) + + display(input, img_noquant, delta_img) + for n in range(n_iter): + # get reconstructions + + posterior = model.encode(img_noquant) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + if n==0: + flatz_init = z.reshape(-1).cpu().detach().numpy() + std_init = flatz_init.std() + max_init, min_init = flatz_init.max(), flatz_init.min() + + display_latent_distribution(z, alpha=np.sqrt(1/(n+1)), + title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}") + + img_noquant = model.decode(z) + if clip_decoded: + img_noquant = torch.clamp(img_noquant, -1., 1.) + if quantize_decoded: + img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(model.device) + delta_img = img_noquant - input + display(input, img_noquant, delta_img) + progress_cb(n + 1) + + + +if __name__ == "__main__": + from ldm.models.autoencoder import AutoencoderKL + # VISUALIZE USED AND ALL INDICES of VQ-Model. 
VISUALIZE FIXED POINTS OF KL MODEL + sys.path.append(os.getcwd()) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + + ckpt = None + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + try: + idx = len(paths)-paths[::-1].index("logs")+1 + except ValueError: + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml"))) + opt.base = base_configs+opt.base + + if opt.config: + if type(opt.config) == str: + opt.base = [opt.config] + else: + opt.base = [opt.base[-1]] + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + if opt.dataset_config: + dcfg = OmegaConf.load(opt.dataset_config) + print("Replacing data config with:") + print(dcfg.pretty()) + dcfg = OmegaConf.to_container(dcfg) + config["data"] = dcfg["data"] + + st.sidebar.text(ckpt) + gs = st.sidebar.empty() + gs.text(f"Global step: ?") + st.sidebar.text("Options") + gpu = st.sidebar.checkbox("GPU", value=True) + eval_mode = st.sidebar.checkbox("Eval Mode", value=True) + show_config = st.sidebar.checkbox("Show Config", value=False) + if show_config: + st.info("Checkpoint: {}".format(ckpt)) + st.json(OmegaConf.to_container(config)) + + delete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.") + dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode, + delete_dataset_params=delete_dataset_parameters) + gs.text(f"Global step: {global_step}") + + split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1]) + dset = dsets.datasets[split] + + batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20) + num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5) + data_size = batch_size*num_batches + dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False)) + + if not isinstance(model, AutoencoderKL): + # VQ MODEL + codebook = model.quantize.embedding.weight.data + st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**") + st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**") + st.write("**Select ONE of the following options**") + + if st.checkbox("Show Codebook Statistics", False): + used_indices = get_used_indices(model, dset, batch_size=batch_size) + visualize3d(codebook, used_indices) + if st.checkbox("Show Batch Encodings", False): + used_indices = get_used_indices(model, dset, batch_size=batch_size) + get_image_embeddings(model, dset, codebook[used_indices, :], used_indices) + if st.checkbox("Show Fixed Points of Data", False): + get_fixed_points(model, dset) + + else: + st.info("Detected a KL model") + get_fixed_points_kl_ae(model, dset) \ No newline at end of file
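
Note on the pretraining config above: the scale_factor of 0.22765929 marked as a "magic number" is, as in other latent-diffusion configs, normally chosen so that the first-stage latents fed to the diffusion UNet have roughly unit standard deviation. A minimal sketch of how such a value can be estimated (illustrative only, not part of this patch; the helper name is hypothetical, first_stage_model is the KL autoencoder from first_stage_config):

import torch

@torch.no_grad()
def estimate_scale_factor(first_stage_model, images):
    # images: a representative batch in [-1, 1], NCHW.
    # AutoencoderKL.encode returns a diagonal Gaussian posterior; 1/std of its
    # sampled latents is a reasonable candidate for scale_factor.
    posterior = first_stage_model.encode(images)
    z = posterior.sample()
    return (1.0 / z.flatten().std()).item()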
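
The new DummyDataWithEmbeddings pairs naturally with the new IdentityEncoder: pointing cond_stage_config at IdentityEncoder lets precomputed (here random) embeddings under the "txt" key flow straight into the cross-attention, so the conditioning path can be smoke-tested without CLIP weights or S3 access. A minimal sketch (illustrative only; the sizes are assumptions chosen to match the 256x256 images and context_dim 768 used above):

from torch.utils.data import DataLoader
from ldm.data.dummy import DummyDataWithEmbeddings

# 16 dummy samples: HWC float "images" plus 77x768 stand-ins for CLIP token embeddings
ds = DummyDataWithEmbeddings(length=16, size=(256, 256, 3), emb_size=(77, 768))
loader = DataLoader(ds, batch_size=4)
batch = next(iter(loader))
print(batch["jpg"].shape, batch["txt"].shape)  # torch.Size([4, 256, 256, 3]) torch.Size([4, 77, 768])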
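
The dynamic_threshold branch added to DDIMSampler.p_sample_ddim round-trips through numpy on the CPU because clamping against per-sample tensor bounds (the commented-out pred_x0.clamp(-s, s)) requires a newer PyTorch. The pure-GPU min/max variant mentioned in the TODO could look roughly like this (illustrative only, not part of this patch); torch.maximum/torch.minimum broadcast s, already reshaped to (b, 1, ..., 1), over the remaining dimensions:

import torch

def threshold_and_rescale(pred_x0, s):
    # pred_x0: prediction renormalized to [-1, 1]; s: per-sample threshold of shape (b, 1, ..., 1).
    # Clip to [-s, s] and rescale by s without leaving the GPU.
    return torch.minimum(torch.maximum(pred_x0, -s), s) / s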
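
With the txt2img.py changes, a batched run over the bundled prompt file might be invoked as follows (config and checkpoint paths are placeholders; --dyn is the quantile used for dynamic thresholding):

python scripts/txt2img.py --from-file scripts/prompts/prompts-with-wings.txt --ddim_steps 50 --n_samples 8 --scale 5.0 --dyn 0.95 --config <path/to/project.yaml> --ckpt <path/to/last.ckpt>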