Merge remote-tracking branch 'github/main' into main

root 2022-06-08 21:04:40 +00:00
commit 5d40de00fb
11 changed files with 872 additions and 40 deletions

Four binary image files added (previews not shown): 7.8 MiB, 7.6 MiB, 7.9 MiB, 7.8 MiB.

View file

@@ -0,0 +1,131 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 16
    channels: 16
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 16 # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320 # TODO: scale model here
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 50 # TODO: max out
    num_workers: 4
    multinode: True
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    replace_sampler_ddp: False # TODO: check this
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2 # TODO: want accumulate on? --> wait for final batch-size
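
A minimal sketch (not part of the commit) of how a config in this format is typically consumed in this repo; the config path and the learning-rate handling are illustrative assumptions, and instantiate_from_config is the repo's usual factory helper:

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/example-txt2img.yaml")  # placeholder path, not a file in this commit
    model = instantiate_from_config(config.model)            # builds LatentDiffusion and its nested sub-modules
    model.learning_rate = config.model.base_learning_rate    # main.py additionally scales this by batch size and GPU count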

View file

@@ -16,3 +16,19 @@ class DummyData(Dataset):
         letters = string.ascii_lowercase
         y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
         return {"jpg": x, "txt": y}
+
+
+class DummyDataWithEmbeddings(Dataset):
+    def __init__(self, length, size, emb_size):
+        self.length = length
+        self.size = size
+        self.emb_size = emb_size
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        x = np.random.randn(*self.size)
+        y = np.random.randn(*self.emb_size).astype(np.float32)
+        return {"jpg": x, "txt": y}

View file

@@ -4,6 +4,7 @@ import torch
 import numpy as np
 from tqdm import tqdm
 from functools import partial
+from einops import rearrange
 
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 
@@ -73,8 +74,8 @@ class DDIMSampler(object):
                x_T=None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
                **kwargs
                ):
         if conditioning is not None:
@@ -106,6 +107,7 @@ class DDIMSampler(object):
                                                     log_every_t=log_every_t,
                                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                                     unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
                                                     )
         return samples, intermediates
 
@@ -115,7 +117,7 @@ class DDIMSampler(object):
                       callback=None, timesteps=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -150,7 +152,8 @@ class DDIMSampler(object):
                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
-                                      unconditional_conditioning=unconditional_conditioning)
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
             img, pred_x0 = outs
             if callback: callback(i)
             if img_callback: img_callback(pred_x0, i)
@@ -164,7 +167,8 @@ class DDIMSampler(object):
     @torch.no_grad()
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
         b, *_, device = *x.shape, x.device
 
         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
@@ -194,6 +198,33 @@ class DDIMSampler(object):
         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         if quantize_denoised:
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            # renorm
+            pred_max = pred_x0.max()
+            pred_min = pred_x0.min()
+            pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min)  # 0 ... 1
+            pred_x0 = 2*pred_x0 - 1.  # -1 ... 1
+
+            s = torch.quantile(
+                rearrange(pred_x0, 'b ... -> b (...)').abs(),
+                dynamic_threshold,
+                dim=-1
+            )
+            s.clamp_(min=1.0)
+            s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
+
+            # clip by threshold
+            # pred_x0 = pred_x0.clamp(-s, s) / s  # needs newer pytorch  # TODO bring back to pure-gpu with min/max
+
+            # temporary hack: numpy on cpu
+            pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
+            pred_x0 = torch.tensor(pred_x0).to(self.model.device)
+
+            # re.renorm
+            pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1
+            pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min  # orig range
+
         # direction pointing to x_t
         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
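
The commented-out clamp above leaves a TODO to keep the thresholding on the GPU; a minimal sketch of that variant (an assumption, not part of this commit), using elementwise torch.minimum/torch.maximum, which broadcast against the per-sample threshold s exactly like the numpy fallback:

    # drop-in replacement for the numpy block above (assumes s was computed as in the diff)
    pred_x0 = torch.maximum(torch.minimum(pred_x0, s), -s) / s
    # on PyTorch >= 1.9, torch.clamp(pred_x0, -s, s) / s accepts tensor bounds directly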

View file

@@ -12,6 +12,10 @@ class AbstractEncoder(nn.Module):
     def encode(self, *args, **kwargs):
         raise NotImplementedError
 
 
+class IdentityEncoder(AbstractEncoder):
+    def encode(self, x):
+        return x
+
 
 class ClassEmbedder(nn.Module):
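
A small sketch (not part of the diff) of what the new IdentityEncoder does: it hands precomputed conditioning through unchanged, e.g. when the dataloader already supplies embeddings instead of raw text; the embedding shape is an illustrative assumption.

    import torch
    from ldm.modules.encoders.modules import IdentityEncoder

    enc = IdentityEncoder()
    emb = torch.randn(4, 77, 768)   # assumed shape of precomputed text embeddings
    assert enc.encode(emb) is emb   # returned as-is, no parameters involved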

View file

@@ -0,0 +1,64 @@
A portrait of Abraham Lincoln
A portrait of Barack Obama
A portrait of a nekomimi girl smiling
a portrait of isaac newton the alchemist
A portrait of Friedrich Nietzsche wearing an open double breasted suit with a bowtie
Portrait of a cyberpunk cyborg man wearing alternate reality goggles
Portrait of a woman screaming
A portrait of a man in a flight jacket leaning against a biplane
a cold landscape by Albert Bierstadt
the monument of the ancients by van gogh
the universal library
a vision of paradise. unreal engine
matte painting of cozy underground bunker wholefoods aisle, trending on artstation
illustration of wooly mammoths reclaiming the arctic, trending on artstation
a mountain range in the desert, Provia, Velvia
the gateway between dreams, trending on ArtStation
a cityscape at night
starry night by cyberpunk
A fantasy painting of a city in a deep valley by Ivan Aivazovsky
An oil painting of The New York City Skyline by Natalia Goncharova
a rainy city street in the style of cyberpunk noir, trending on ArtStation
an astral city in the style of cyberpunk noir art deco
The Golden Gate Bridge in the style of art deco
a city on a 70s science fiction novel cover
An oil painting of A Vase Of Flowers
still life oil painting of a smooth silver steel tungsten square cube box by Albrecht Dürer
An oil painting of a bookshelf crammed with books, trending on artstation
An N95 respirator mask in the style of art deco
a surreal and organic stone monument to a plutonium atom
oil painting of a candy dish of glass candies, mints, and other assorted sweets
illustration of a ford model-t in pristine condition, trending on artstation
illustration of DEC minicomputer console monitor retrocomputing teletype interdata PDP-11 univac, trending on artstation
The Rise Of Consciousness
The Human Utility Function
Revolution of the Souls
a good amphetamine spirit
Control The Soul
The Lunatic, The Lover, and The Poet
A Planet Ruled By Angels
the Tower of Babel by J.M.W. Turner
sketch of a 3D printer by Leonardo da Vinci
In The Style Of M.C. Escher
A cup of coffee by Picasso
The US Capitol Building in the style of Kandinsky
A Mysterious Orb by Andy Warhol
The everlasting zero, a glimpse of a million, by Salvador Dali
a painting of a haunted house with Halloween decorations by Giovanni Paolo Panini
a painting of drops of Venus by Vincent van Gogh
ascii art of a man riding a bicycle
cyberpunk noir art deco detective in space
a cyborg angel in the style of ukiyo-e
Hell in the style of pointillism
Moloch in the style of socialist realism
Metaphysics in the style of WPAP
advertisement for a psychedelic virtual reality headset, 16 bit sprite pixel art
a watercolor painting of a Christmas tree
control room monitors televisions screens computers hacker lab, concept art, matte painting, trending on artstation
a group of surgeons wait to cryonically suspend a patient
technological singularity cult by James Gurney
an autogyro flying car, trending on artstation
illustration of airship zepplins in the skies, trending on artstation
watercolor illustration of a martian colony geodesic dome aquaponics farming on the surface, trending on artstation
humanity is killed by AI, by James Gurney
the Vitruvian Man as a propaganda poster for transhumanism

View file

@@ -4,6 +4,7 @@ import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
+from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
 
@@ -12,6 +13,11 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 
 
+def chunk(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@ -51,7 +57,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--ddim_steps", "--ddim_steps",
type=int, type=int,
default=200, default=50,
help="number of ddim sampling steps", help="number of ddim sampling steps",
) )
@@ -91,8 +97,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--n_samples",
         type=int,
-        default=4,
-        help="how many samples to produce for the given prompt",
+        default=8,
+        help="how many samples to produce for each given prompt. A.k.a batch size",
     )
 
     parser.add_argument(
@@ -101,11 +107,35 @@ if __name__ == "__main__":
         default=5.0,
         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
     )
+    parser.add_argument(
+        "--dyn",
+        type=float,
+        help="dynamic thresholding from Imagen, in latent space (TODO: try in pixel space with intermediate decode)",
+    )
+
+    parser.add_argument(
+        "--from-file",
+        type=str,
+        help="if specified, load prompts from this file",
+    )
+
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="logs/f8-kl-clip-encoder-256x256-run1/configs/2022-06-01T22-11-40-project.yaml",
+        help="path to config which constructs model",
+    )
+
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
+        help="path to checkpoint of model",
+    )
+
     opt = parser.parse_args()
 
-    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")  # TODO: Optionally download from same location as ckpt and chnage this logic
-    model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")  # TODO: check path
+    config = OmegaConf.load(f"{opt.config}")
+    model = load_model_from_config(config, f"{opt.ckpt}")
 
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
@@ -118,48 +148,62 @@ if __name__ == "__main__":
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir
 
-    prompt = opt.prompt
+    batch_size = opt.n_samples
+    if not opt.from_file:
+        prompt = opt.prompt
+        assert prompt is not None
+        data = [batch_size * [prompt]]
+
+    else:
+        print(f"reading prompts from {opt.from_file}")
+        with open(opt.from_file, "r") as f:
+            data = f.read().splitlines()
+            data = list(chunk(data, batch_size))
 
     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
+    grid_count = len(os.listdir(outpath)) - 1
 
-    all_samples=list()
     with torch.no_grad():
         with model.ema_scope():
-            uc = None
-            if opt.scale != 1.0:
-                uc = model.get_learned_conditioning(opt.n_samples * [""])
             for n in trange(opt.n_iter, desc="Sampling"):
-                c = model.get_learned_conditioning(opt.n_samples * [prompt])
-                shape = [4, opt.H//8, opt.W//8]
-                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                 conditioning=c,
-                                                 batch_size=opt.n_samples,
-                                                 shape=shape,
-                                                 verbose=False,
-                                                 unconditional_guidance_scale=opt.scale,
-                                                 unconditional_conditioning=uc,
-                                                 eta=opt.ddim_eta)
+                all_samples = list()
+                for prompts in tqdm(data, desc="data"):
+                    uc = None
+                    if opt.scale != 1.0:
+                        uc = model.get_learned_conditioning(batch_size * [""])
+                    c = model.get_learned_conditioning(prompts)
+                    shape = [4, opt.H//8, opt.W//8]
+                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                     conditioning=c,
+                                                     batch_size=opt.n_samples,
+                                                     shape=shape,
+                                                     verbose=False,
+                                                     unconditional_guidance_scale=opt.scale,
+                                                     unconditional_conditioning=uc,
+                                                     eta=opt.ddim_eta,
+                                                     dynamic_threshold=opt.dyn)
 
-                x_samples_ddim = model.decode_first_stage(samples_ddim)
-                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+                    x_samples_ddim = model.decode_first_stage(samples_ddim)
+                    x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
 
-                for x_sample in x_samples_ddim:
-                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                    Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
-                    base_count += 1
-                all_samples.append(x_samples_ddim)
+                    for x_sample in x_samples_ddim:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
+                        base_count += 1
+                    all_samples.append(x_samples_ddim)
 
-    # additionally, save as grid
-    grid = torch.stack(all_samples, 0)
-    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-    grid = make_grid(grid, nrow=opt.n_samples)
+                # additionally, save as grid
+                grid = torch.stack(all_samples, 0)
+                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                grid = make_grid(grid, nrow=opt.n_samples)
 
-    # to image
-    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
+                # to image
+                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                grid_count += 1
 
-    print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.")
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")

View file

@@ -0,0 +1,542 @@
import argparse, os, sys, glob
import torch
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from einops import rearrange
from torchvision.utils import make_grid
rescale = lambda x: (x + 1.) / 2.
def bchw_to_st(x):
    return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))


def chw_to_st(x):
    return rescale(x.detach().cpu().numpy().transpose(1,2,0))


def custom_get_input(batch, key):
    inputs = batch[key].permute(0, 3, 1, 2)
    return inputs


def vq_no_codebook_forward(model, x):
    h = model.encoder(x)
    h = model.quant_conv(h)
    h = model.post_quant_conv(h)
    xrec = model.decoder(h)
    return xrec


def save_img(x, fname):
    I = (x.clip(0, 1) * 255).astype(np.uint8)
    Image.fromarray(I).save(fname)
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to single config. If specified, base configs will be ignored "
             "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--dataset_config",
        type=str,
        nargs="?",
        default="",
        help="path to dataset config"
    )
    return parser
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd)
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def get_data(config):
    # get data
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    return data
@st.cache(allow_output_mutation=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False):
    # get data
    if delete_dataset_params:
        st.info("Deleting dataset parameters.")
        del config["data"]["params"]["train"]["params"]
        del config["data"]["params"]["validation"]["params"]
    dsets = get_data(config)   # calls data.config ...

    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return dsets, model, global_step
@torch.no_grad()
def get_image_embeddings(model, dset, used_codebook, used_indices):
    import plotly.graph_objects as go

    batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4)
    start_index = st.number_input("Start index", value=0,
                                  min_value=0,
                                  max_value=len(dset) - batch_size)
    if st.sidebar.button("Sample Batch"):
        indices = np.random.choice(len(dset), batch_size)
    else:
        indices = list(range(start_index, start_index + batch_size))

    st.write(f"Indices: {indices}")
    batch = default_collate([dset[i] for i in indices])
    x = model.get_input(batch, "image")
    x = x.to(model.device)

    # get reconstruction from non-quantized and quantized, compare
    z_pre_quant = model.encode_to_prequant(x)
    z_quant, emb_loss, info = model.quantize(z_pre_quant)
    indices = info[2].detach().cpu().numpy()
    #indices = rearrange(indices, '(b d) -> b d', b=batch_size)
    unique_indices = np.unique(indices)
    st.write(f"Unique indices in batch: {unique_indices.shape[0]}")

    x1 = used_codebook[:, 0].cpu().numpy()
    x2 = used_codebook[:, 1].cpu().numpy()
    x3 = used_codebook[:, 2].cpu().numpy()

    zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
    zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
    zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()

    zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
    zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
    zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()

    fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers',
                                       marker=dict(size=1.4, line=dict(width=1., color="Blue")),
                                       name="All Used Codebook Entries"),
                          ])
    trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers',
                          marker=dict(size=1., line=dict(width=1., color=indices)),
                          name="Pre-Quant Codes")
    trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers',
                          marker=dict(size=2., line=dict(width=10., color=indices)), name="Quantized Codes")
    fig.add_trace(trace2)
    fig.add_trace(trace3)
    fig.update_layout(
        autosize=False,
        width=1000,
        height=1000,
    )

    x_rec_no_quant = model.decode(z_pre_quant)
    x_rec_quant = model.decode(z_quant)
    delta_x = x_rec_no_quant - x_rec_quant

    st.text("Fitting Gaussian...")
    h, w = z_quant.shape[2], z_quant.shape[3]
    from sklearn.mixture import GaussianMixture
    gaussian = GaussianMixture(n_components=1)
    gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy())
    samples, _ = gaussian.sample(n_samples=batch_size*h*w)
    samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3)
    samples = rearrange(samples, 'b h w c -> b c h w')
    samples = torch.tensor(samples).to(z_quant)
    samples, _, _ = model.quantize(samples)
    x_sample = model.decode(samples)

    all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample])  # 5 b 3 H W
    all_img = rearrange(all_img, 'n b c h w -> b n c h w')
    all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
    grid = make_grid(all_img, nrow=5)
    st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **")
    st.image(chw_to_st(grid), clamp=True, output_format="PNG")

    st.write(fig)

    # 2d projections
    import matplotlib.pyplot as plt
    pairs = [(1, 0), (2, 0), (2, 1)]
    fig2, ax = plt.subplots(1, 3, figsize=(21, 7))
    for d in range(3):
        d1, d2 = pairs[d]
        #ax[d].scatter(used_codebook[:, d1].cpu().numpy(),
        #              used_codebook[:, d2].cpu().numpy(),
        #              label="All Used Codebook Entries", s=10.0, c=used_indices)
        ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
                      rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
                      label="Quantized Codes", alpha=0.9, s=8.0, c=indices)
        ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
                      rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
                      label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices)
        ax[d].set_title(f"dim {d2} vs dim {d1}")
        ax[d].legend()
    st.write(fig2)
    plt.close()

    # from scipy.spatial import Voronoi, voronoi_plot_2d
    # fig3 = plt.figure(figsize=(10,10))
    # points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy()
    # vor = Voronoi(points)
    # # plot
    # voronoi_plot_2d(vor)
    # # colorize
    # for region in vor.regions:
    #     if not -1 in region:
    #         polygon = [vor.vertices[i] for i in region]
    #         plt.fill(*zip(*polygon))
    # plt.savefig("voronoi_test.png")
    # st.write(fig3)
@torch.no_grad()
def get_used_indices(model, dset, batch_size=20):
    dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False)
    data = list()
    info = st.empty()
    for i, batch in enumerate(dloader):
        x = model.get_input(batch, "image")
        x = x.to(model.device)
        zq, _, zi = model.encode(x)
        indices = zi[2]
        indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy()
        data.append(indices)
        unique = np.unique(data)
        info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}")

    unique = np.unique(data)
    #np.save(outpath, unique)
    st.write(f"end of data: found **{unique.size} unique indices.**")
    print(f"end of data: found {unique.size} unique indices.")
    return unique
def visualize3d(codebook, used_indices):
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    codebook = codebook.cpu().numpy()
    selected_codebook = codebook[used_indices, :]
    z_dim = codebook.shape[1]
    assert z_dim == 3

    pairs = [(1,0), (2,0), (2,1)]
    fig, ax = plt.subplots(1,3, figsize=(15,5))
    for d in range(3):
        d1, d2 = pairs[d]
        ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2])
        ax[d].set_title(f"dim {d2} vs dim {d1}")
    st.write(fig)

    # # plot 3D
    # fig = plt.figure(1)
    # ax = Axes3D(fig)
    # ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries")
    # ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries")
    # plt.legend()
    # #st.write(fig)
    # st.pyplot(fig)

    # plot histogram of vector norms
    fig = plt.figure(2, figsize=(6,5))
    norms = np.linalg.norm(selected_codebook, axis=1)
    plt.hist(norms, bins=100, edgecolor="black", lw=1.1)
    plt.title("Distribution of norms of used codebook entries")
    st.write(fig)

    # plot 3D with plotly
    import pandas as pd
    import plotly.graph_objects as go
    x = selected_codebook[:, 0]
    y = selected_codebook[:, 1]
    z = selected_codebook[:, 2]
    fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers',
                                       marker=dict(size=2., line=dict(width=1., color="Blue"))
                                       )
                          ]
                    )
    fig.update_layout(
        autosize=False,
        width=1000,
        height=1000,
    )
    st.write(fig)
@torch.no_grad()
def get_fixed_points(model, dset):
    n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
    batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
    start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
    clip_decoded = st.checkbox("Clip decoded image", False)
    quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
    factor = st.sidebar.number_input("image size", value=1., min_value=0.1)

    if st.sidebar.button("Sample Batch"):
        indices = np.random.choice(len(dset), batch_size)
    else:
        indices = list(range(start_index, start_index + batch_size))

    st.write(f"Indices: {indices}")
    batch = default_collate([dset[i] for i in indices])
    x = model.get_input(batch, "image")
    x = x.to(model.device)

    progress = st.empty()
    progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
    image_progress = st.empty()  # TODO

    input = x
    img_quant = x
    img_noquant = x
    delta_img = img_quant - img_noquant
    st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **")

    def display(input, img_quant, img_noquant, delta_img):
        all_img = torch.stack([input, img_quant, img_noquant, delta_img])  # 4 b 3 H W
        all_img = rearrange(all_img, 'n b c h w -> b n c h w')
        all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
        grid = make_grid(all_img, nrow=4)
        image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2]))

    display(input, img_quant, img_noquant, delta_img)
    for n in range(n_iter):
        # get reconstruction from non-quantized and quantized, compare via iteration
        # quantized_stream
        z_pre_quant = model.encode_to_prequant(img_quant)
        z_quant, emb_loss, info = model.quantize(z_pre_quant)
        # non_quantized stream
        z_noquant = model.encode_to_prequant(img_noquant)

        img_quant = model.decode(z_quant)
        img_noquant = model.decode(z_noquant)
        if clip_decoded:
            img_quant = torch.clamp(img_quant, -1., 1.)
            img_noquant = torch.clamp(img_noquant, -1., 1.)
        if quantize_decoded:
            device = img_quant.device
            img_quant = (2*torch.Tensor(((img_quant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device)
            img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device)

        delta_img = img_quant - img_noquant
        display(input, img_quant, img_noquant, delta_img)
        progress_cb(n + 1)
@torch.no_grad()
def get_fixed_points_kl_ae(model, dset):
    n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
    batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
    start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
    clip_decoded = st.checkbox("Clip decoded image", False)
    quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
    sample_posterior = st.checkbox("Sample from encoder posterior", False)
    factor = st.sidebar.number_input("image size", value=1., min_value=0.1)

    if st.sidebar.button("Sample Batch"):
        indices = np.random.choice(len(dset), batch_size)
    else:
        indices = list(range(start_index, start_index + batch_size))

    st.write(f"Indices: {indices}")
    batch = default_collate([dset[i] for i in indices])
    x = model.get_input(batch, "image")
    x = x.to(model.device)

    progress = st.empty()
    progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
    st.write("** Input | Rec. (no quant) | Delta(input, iter_rec) **")
    image_progress = st.empty()

    input = x
    img_noquant = x
    delta_img = input - img_noquant

    def display(input, img_noquant, delta_img):
        all_img = torch.stack([input, img_noquant, delta_img])  # 3 b 3 H W
        all_img = rearrange(all_img, 'n b c h w -> b n c h w')
        all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
        grid = make_grid(all_img, nrow=3)
        image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2]))

    fig, ax = plt.subplots()
    distribution_progress = st.empty()

    def display_latent_distribution(latent_z, alpha=1., title=""):
        flatz = latent_z.reshape(-1).cpu().detach().numpy()
        #fig, ax = plt.subplots()
        ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black")
        ax.set_title(title)
        distribution_progress.pyplot(fig)

    display(input, img_noquant, delta_img)
    for n in range(n_iter):
        # get reconstructions
        posterior = model.encode(img_noquant)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()

        if n==0:
            flatz_init = z.reshape(-1).cpu().detach().numpy()
            std_init = flatz_init.std()
            max_init, min_init = flatz_init.max(), flatz_init.min()

        display_latent_distribution(z, alpha=np.sqrt(1/(n+1)),
                                    title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}")
        img_noquant = model.decode(z)
        if clip_decoded:
            img_noquant = torch.clamp(img_noquant, -1., 1.)
        if quantize_decoded:
            img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(model.device)

        delta_img = img_noquant - input
        display(input, img_noquant, delta_img)
        progress_cb(n + 1)
if __name__ == "__main__":
from ldm.models.autoencoder import AutoencoderKL
# VISUALIZE USED AND ALL INDICES of VQ-Model. VISUALIZE FIXED POINTS OF KL MODEL
sys.path.append(os.getcwd())
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
try:
idx = len(paths)-paths[::-1].index("logs")+1
except ValueError:
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
opt.base = base_configs+opt.base
if opt.config:
if type(opt.config) == str:
opt.base = [opt.config]
else:
opt.base = [opt.base[-1]]
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
if opt.dataset_config:
dcfg = OmegaConf.load(opt.dataset_config)
print("Replacing data config with:")
print(dcfg.pretty())
dcfg = OmegaConf.to_container(dcfg)
config["data"] = dcfg["data"]
st.sidebar.text(ckpt)
gs = st.sidebar.empty()
gs.text(f"Global step: ?")
st.sidebar.text("Options")
gpu = st.sidebar.checkbox("GPU", value=True)
eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
show_config = st.sidebar.checkbox("Show Config", value=False)
if show_config:
st.info("Checkpoint: {}".format(ckpt))
st.json(OmegaConf.to_container(config))
delelete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.")
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode,
delete_dataset_params=delelete_dataset_parameters)
gs.text(f"Global step: {global_step}")
split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1])
dset = dsets.datasets[split]
batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20)
num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5)
data_size = batch_size*num_batches
dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False))
if not isinstance(model, AutoencoderKL):
# VQ MODEL
codebook = model.quantize.embedding.weight.data
st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**")
st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**")
st.write("**Select ONE of the following options**")
if st.checkbox("Show Codebook Statistics", False):
used_indices = get_used_indices(model, dset, batch_size=batch_size)
visualize3d(codebook, used_indices)
if st.checkbox("Show Batch Encodings", False):
used_indices = get_used_indices(model, dset, batch_size=batch_size)
get_image_embeddings(model, dset, codebook[used_indices, :], used_indices)
if st.checkbox("Show Fixed Points of Data", False):
get_fixed_points(model, dset)
else:
st.info("Detected a KL model")
get_fixed_points_kl_ae(model, dset)
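
A small sketch (not part of the diff) of the parse_known_args + OmegaConf pattern used above: tokens the parser does not recognize are treated as dot-list overrides and merged on top of the loaded configs; the keys and values here are illustrative.

    from omegaconf import OmegaConf

    base = OmegaConf.create({"data": {"params": {"batch_size": 50}}})
    cli = OmegaConf.from_dotlist(["data.params.batch_size=4"])
    merged = OmegaConf.merge(base, cli)
    print(merged.data.params.batch_size)  # 4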