Merge remote-tracking branch 'github/main' into main
This commit is contained in:
commit
5d40de00fb
11 changed files with 872 additions and 40 deletions
BIN
assets/samples/grid-0001.png
Normal file
BIN
assets/samples/grid-0001.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.8 MiB |
BIN
assets/samples/grid-0006.png
Normal file
BIN
assets/samples/grid-0006.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.6 MiB |
BIN
assets/samples/grid-0007.png
Normal file
BIN
assets/samples/grid-0007.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.9 MiB |
BIN
assets/samples/grid-0008.png
Normal file
BIN
assets/samples/grid-0008.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.8 MiB |
|
@ -0,0 +1,131 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.001
|
||||||
|
linear_end: 0.015
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 16
|
||||||
|
channels: 16
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.22765929 # magic number
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 16 # not really needed
|
||||||
|
in_channels: 16
|
||||||
|
out_channels: 16
|
||||||
|
model_channels: 320 # TODO: scale model here
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: ldm.data.laion.WebDataModuleFromConfig
|
||||||
|
params:
|
||||||
|
tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
|
||||||
|
batch_size: 50 # TODO: max out
|
||||||
|
num_workers: 4
|
||||||
|
multinode: True
|
||||||
|
train:
|
||||||
|
shards: '{000000..231317}.tar -'
|
||||||
|
shuffle: 10000
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
# NOTE use enough shards to avoid empty validation loops in workers
|
||||||
|
validation:
|
||||||
|
shards: '{231318..231349}.tar -'
|
||||||
|
shuffle: 0
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.CenterCrop
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: False
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 3.0
|
||||||
|
unconditional_guidance_label: [""]
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
replace_sampler_ddp: False # TODO: check this
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 5000000 # really sorry
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 2 # TODO: want accumulate on? --> wait for final batch-size
|
|
@ -16,3 +16,19 @@ class DummyData(Dataset):
|
||||||
letters = string.ascii_lowercase
|
letters = string.ascii_lowercase
|
||||||
y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
|
y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
|
||||||
return {"jpg": x, "txt": y}
|
return {"jpg": x, "txt": y}
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataWithEmbeddings(Dataset):
|
||||||
|
def __init__(self, length, size, emb_size):
|
||||||
|
self.length = length
|
||||||
|
self.size = size
|
||||||
|
self.emb_size = emb_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
x = np.random.randn(*self.size)
|
||||||
|
y = np.random.randn(*self.emb_size).astype(np.float32)
|
||||||
|
return {"jpg": x, "txt": y}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||||
|
|
||||||
|
@ -73,8 +74,8 @@ class DDIMSampler(object):
|
||||||
x_T=None,
|
x_T=None,
|
||||||
log_every_t=100,
|
log_every_t=100,
|
||||||
unconditional_guidance_scale=1.,
|
unconditional_guidance_scale=1.,
|
||||||
unconditional_conditioning=None,
|
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
dynamic_threshold=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
|
@ -106,6 +107,7 @@ class DDIMSampler(object):
|
||||||
log_every_t=log_every_t,
|
log_every_t=log_every_t,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold,
|
||||||
)
|
)
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
|
@ -115,7 +117,7 @@ class DDIMSampler(object):
|
||||||
callback=None, timesteps=None, quantize_denoised=False,
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None):
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
|
@ -150,7 +152,8 @@ class DDIMSampler(object):
|
||||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
corrector_kwargs=corrector_kwargs,
|
corrector_kwargs=corrector_kwargs,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning)
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold)
|
||||||
img, pred_x0 = outs
|
img, pred_x0 = outs
|
||||||
if callback: callback(i)
|
if callback: callback(i)
|
||||||
if img_callback: img_callback(pred_x0, i)
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
@ -164,7 +167,8 @@ class DDIMSampler(object):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||||
|
dynamic_threshold=None):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
@ -194,6 +198,33 @@ class DDIMSampler(object):
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
if quantize_denoised:
|
if quantize_denoised:
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
|
||||||
|
if dynamic_threshold is not None:
|
||||||
|
# renorm
|
||||||
|
pred_max = pred_x0.max()
|
||||||
|
pred_min = pred_x0.min()
|
||||||
|
pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1
|
||||||
|
pred_x0 = 2*pred_x0 - 1. # -1 ... 1
|
||||||
|
|
||||||
|
s = torch.quantile(
|
||||||
|
rearrange(pred_x0, 'b ... -> b (...)').abs(),
|
||||||
|
dynamic_threshold,
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
s.clamp_(min=1.0)
|
||||||
|
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
|
||||||
|
|
||||||
|
# clip by threshold
|
||||||
|
#pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
||||||
|
|
||||||
|
# temporary hack: numpy on cpu
|
||||||
|
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
|
||||||
|
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
|
||||||
|
|
||||||
|
# re.renorm
|
||||||
|
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
|
||||||
|
pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range
|
||||||
|
|
||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
|
|
@ -12,6 +12,10 @@ class AbstractEncoder(nn.Module):
|
||||||
def encode(self, *args, **kwargs):
|
def encode(self, *args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class IdentityEncoder(AbstractEncoder):
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ClassEmbedder(nn.Module):
|
class ClassEmbedder(nn.Module):
|
||||||
|
|
64
scripts/prompts/prompts-with-wings.txt
Normal file
64
scripts/prompts/prompts-with-wings.txt
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
A portrait of Abraham Lincoln
|
||||||
|
A portrait of Barack Obama
|
||||||
|
A portrait of a nekomimi girl smiling
|
||||||
|
a portrait of isaac newton the alchemist
|
||||||
|
A portrait of Friedrich Nietzsche wearing an open double breasted suit with a bowtie
|
||||||
|
Portrait of a cyberpunk cyborg man wearing alternate reality goggles
|
||||||
|
Portrait of a woman screaming
|
||||||
|
A portrait of a man in a flight jacket leaning against a biplane
|
||||||
|
a cold landscape by Albert Bierstadt
|
||||||
|
the monument of the ancients by van gogh
|
||||||
|
the universal library
|
||||||
|
a vision of paradise. unreal engine
|
||||||
|
matte painting of cozy underground bunker wholefoods aisle, trending on artstation
|
||||||
|
illustration of wooly mammoths reclaiming the arctic, trending on artstation
|
||||||
|
a mountain range in the desert, Provia, Velvia
|
||||||
|
the gateway between dreams, trending on ArtStation
|
||||||
|
a cityscape at night
|
||||||
|
starry night by cyberpunk
|
||||||
|
A fantasy painting of a city in a deep valley by Ivan Aivazovsky
|
||||||
|
An oil painting of The New York City Skyline by Natalia Goncharova
|
||||||
|
a rainy city street in the style of cyberpunk noir, trending on ArtStation
|
||||||
|
an astral city in the style of cyberpunk noir art deco
|
||||||
|
The Golden Gate Bridge in the style of art deco
|
||||||
|
a city on a 70s science fiction novel cover
|
||||||
|
An oil painting of A Vase Of Flowers
|
||||||
|
still life oil painting of a smooth silver steel tungsten square cube box by Albrecht Dürer
|
||||||
|
An oil painting of a bookshelf crammed with books, trending on artstation
|
||||||
|
An N95 respirator mask in the style of art deco
|
||||||
|
a surreal and organic stone monument to a plutonium atom
|
||||||
|
oil painting of a candy dish of glass candies, mints, and other assorted sweets
|
||||||
|
illustration of a ford model-t in pristine condition, trending on artstation
|
||||||
|
illustration of DEC minicomputer console monitor retrocomputing teletype interdata PDP-11 univac, trending on artstation
|
||||||
|
The Rise Of Consciousness
|
||||||
|
The Human Utility Function
|
||||||
|
Revolution of the Souls
|
||||||
|
a good amphetamine spirit
|
||||||
|
Control The Soul
|
||||||
|
The Lunatic, The Lover, and The Poet
|
||||||
|
A Planet Ruled By Angels
|
||||||
|
the Tower of Babel by J.M.W. Turner
|
||||||
|
sketch of a 3D printer by Leonardo da Vinci
|
||||||
|
In The Style Of M.C. Escher
|
||||||
|
A cup of coffee by Picasso
|
||||||
|
The US Capitol Building in the style of Kandinsky
|
||||||
|
A Mysterious Orb by Andy Warhol
|
||||||
|
The everlasting zero, a glimpse of a million, by Salvador Dali
|
||||||
|
a painting of a haunted house with Halloween decorations by Giovanni Paolo Panini
|
||||||
|
a painting of drops of Venus by Vincent van Gogh
|
||||||
|
ascii art of a man riding a bicycle
|
||||||
|
cyberpunk noir art deco detective in space
|
||||||
|
a cyborg angel in the style of ukiyo-e
|
||||||
|
Hell in the style of pointillism
|
||||||
|
Moloch in the style of socialist realism
|
||||||
|
Metaphysics in the style of WPAP
|
||||||
|
advertisement for a psychedelic virtual reality headset, 16 bit sprite pixel art
|
||||||
|
a watercolor painting of a Christmas tree
|
||||||
|
control room monitors televisions screens computers hacker lab, concept art, matte painting, trending on artstation
|
||||||
|
a group of surgeons wait to cryonically suspend a patient
|
||||||
|
technological singularity cult by James Gurney
|
||||||
|
an autogyro flying car, trending on artstation
|
||||||
|
illustration of airship zepplins in the skies, trending on artstation
|
||||||
|
watercolor illustration of a martian colony geodesic dome aquaponics farming on the surface, trending on artstation
|
||||||
|
humanity is killed by AI, by James Gurney
|
||||||
|
the Vitruvian Man as a propaganda poster for transhumanism
|
|
@ -4,6 +4,7 @@ import numpy as np
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
|
|
||||||
|
@ -12,6 +13,11 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(it, size):
|
||||||
|
it = iter(it)
|
||||||
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
print(f"Loading model from {ckpt}")
|
print(f"Loading model from {ckpt}")
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
@ -51,7 +57,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ddim_steps",
|
"--ddim_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=200,
|
default=50,
|
||||||
help="number of ddim sampling steps",
|
help="number of ddim sampling steps",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -91,8 +97,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_samples",
|
"--n_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=8,
|
||||||
help="how many samples to produce for the given prompt",
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -101,11 +107,35 @@ if __name__ == "__main__":
|
||||||
default=5.0,
|
default=5.0,
|
||||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dyn",
|
||||||
|
type=float,
|
||||||
|
help="dynamic thresholding from Imagen, in latent space (TODO: try in pixel space with intermediate decode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from-file",
|
||||||
|
type=str,
|
||||||
|
help="if specified, load prompts from this file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="logs/f8-kl-clip-encoder-256x256-run1/configs/2022-06-01T22-11-40-project.yaml",
|
||||||
|
help="path to config which constructs model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt",
|
||||||
|
type=str,
|
||||||
|
default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
|
||||||
|
help="path to checkpoint of model",
|
||||||
|
)
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt") # TODO: check path
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
@ -118,21 +148,33 @@ if __name__ == "__main__":
|
||||||
os.makedirs(opt.outdir, exist_ok=True)
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
outpath = opt.outdir
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
batch_size = opt.n_samples
|
||||||
|
if not opt.from_file:
|
||||||
prompt = opt.prompt
|
prompt = opt.prompt
|
||||||
|
assert prompt is not None
|
||||||
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"reading prompts from {opt.from_file}")
|
||||||
|
with open(opt.from_file, "r") as f:
|
||||||
|
data = f.read().splitlines()
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
|
||||||
sample_path = os.path.join(outpath, "samples")
|
sample_path = os.path.join(outpath, "samples")
|
||||||
os.makedirs(sample_path, exist_ok=True)
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
base_count = len(os.listdir(sample_path))
|
base_count = len(os.listdir(sample_path))
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
|
||||||
all_samples=list()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
all_samples = list()
|
||||||
|
for prompts in tqdm(data, desc="data"):
|
||||||
uc = None
|
uc = None
|
||||||
if opt.scale != 1.0:
|
if opt.scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(opt.n_samples * [""])
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
c = model.get_learned_conditioning(prompts)
|
||||||
c = model.get_learned_conditioning(opt.n_samples * [prompt])
|
|
||||||
shape = [4, opt.H//8, opt.W//8]
|
shape = [4, opt.H//8, opt.W//8]
|
||||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
|
@ -141,14 +183,15 @@ if __name__ == "__main__":
|
||||||
verbose=False,
|
verbose=False,
|
||||||
unconditional_guidance_scale=opt.scale,
|
unconditional_guidance_scale=opt.scale,
|
||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
eta=opt.ddim_eta)
|
eta=opt.ddim_eta,
|
||||||
|
dynamic_threshold=opt.dyn)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
for x_sample in x_samples_ddim:
|
for x_sample in x_samples_ddim:
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
|
Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
|
||||||
base_count += 1
|
base_count += 1
|
||||||
all_samples.append(x_samples_ddim)
|
all_samples.append(x_samples_ddim)
|
||||||
|
|
||||||
|
@ -160,6 +203,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# to image
|
# to image
|
||||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
|
grid_count += 1
|
||||||
|
|
||||||
print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.")
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
||||||
|
|
542
scripts/vqgan_codebook_visualizer.py
Normal file
542
scripts/vqgan_codebook_visualizer.py
Normal file
|
@ -0,0 +1,542 @@
|
||||||
|
import argparse, os, sys, glob
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import streamlit as st
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from main import instantiate_from_config, DataModuleFromConfig
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.dataloader import default_collate
|
||||||
|
from einops import rearrange
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
|
||||||
|
|
||||||
|
rescale = lambda x: (x + 1.) / 2.
|
||||||
|
|
||||||
|
|
||||||
|
def bchw_to_st(x):
|
||||||
|
return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
|
||||||
|
|
||||||
|
def chw_to_st(x):
|
||||||
|
return rescale(x.detach().cpu().numpy().transpose(1,2,0))
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_input(batch, key):
|
||||||
|
inputs = batch[key].permute(0, 3, 1, 2)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
def vq_no_codebook_forward(model, x):
|
||||||
|
h = model.encoder(x)
|
||||||
|
h = model.quant_conv(h)
|
||||||
|
h = model.post_quant_conv(h)
|
||||||
|
xrec = model.decoder(h)
|
||||||
|
return xrec
|
||||||
|
|
||||||
|
|
||||||
|
def save_img(x, fname):
|
||||||
|
I = (x.clip(0, 1) * 255).astype(np.uint8)
|
||||||
|
Image.fromarray(I).save(fname)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"-r",
|
||||||
|
"--resume",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="load from logdir or checkpoint in logdir",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-b",
|
||||||
|
"--base",
|
||||||
|
nargs="*",
|
||||||
|
metavar="base_config.yaml",
|
||||||
|
help="paths to base configs. Loaded from left-to-right. "
|
||||||
|
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
||||||
|
default=list(),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-c",
|
||||||
|
"--config",
|
||||||
|
nargs="?",
|
||||||
|
metavar="single_config.yaml",
|
||||||
|
help="path to single config. If specified, base configs will be ignored "
|
||||||
|
"(except for the last one if left unspecified).",
|
||||||
|
const=True,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_config",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default="",
|
||||||
|
help="path to dataset config"
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
||||||
|
model = instantiate_from_config(config)
|
||||||
|
if sd is not None:
|
||||||
|
model.load_state_dict(sd)
|
||||||
|
if gpu:
|
||||||
|
model.cuda()
|
||||||
|
if eval_mode:
|
||||||
|
model.eval()
|
||||||
|
return {"model": model}
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(config):
|
||||||
|
# get data
|
||||||
|
data = instantiate_from_config(config.data)
|
||||||
|
data.prepare_data()
|
||||||
|
data.setup()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache(allow_output_mutation=True)
|
||||||
|
def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False):
|
||||||
|
# get data
|
||||||
|
if delete_dataset_params:
|
||||||
|
st.info("Deleting dataset parameters.")
|
||||||
|
del config["data"]["params"]["train"]["params"]
|
||||||
|
del config["data"]["params"]["validation"]["params"]
|
||||||
|
|
||||||
|
dsets = get_data(config) # calls data.config ...
|
||||||
|
|
||||||
|
# now load the specified checkpoint
|
||||||
|
if ckpt:
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
global_step = pl_sd["global_step"]
|
||||||
|
else:
|
||||||
|
pl_sd = {"state_dict": None}
|
||||||
|
global_step = None
|
||||||
|
model = load_model_from_config(config.model,
|
||||||
|
pl_sd["state_dict"],
|
||||||
|
gpu=gpu,
|
||||||
|
eval_mode=eval_mode)["model"]
|
||||||
|
return dsets, model, global_step
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_image_embeddings(model, dset, used_codebook, used_indices):
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4)
|
||||||
|
start_index = st.number_input("Start index", value=0,
|
||||||
|
min_value=0,
|
||||||
|
max_value=len(dset) - batch_size)
|
||||||
|
if st.sidebar.button("Sample Batch"):
|
||||||
|
indices = np.random.choice(len(dset), batch_size)
|
||||||
|
else:
|
||||||
|
indices = list(range(start_index, start_index + batch_size))
|
||||||
|
|
||||||
|
st.write(f"Indices: {indices}")
|
||||||
|
batch = default_collate([dset[i] for i in indices])
|
||||||
|
x = model.get_input(batch, "image")
|
||||||
|
x = x.to(model.device)
|
||||||
|
|
||||||
|
# get reconstruction from non-quantized and quantized, compare
|
||||||
|
z_pre_quant = model.encode_to_prequant(x)
|
||||||
|
z_quant, emb_loss, info = model.quantize(z_pre_quant)
|
||||||
|
indices = info[2].detach().cpu().numpy()
|
||||||
|
#indices = rearrange(indices, '(b d) -> b d', b=batch_size)
|
||||||
|
unique_indices = np.unique(indices)
|
||||||
|
st.write(f"Unique indices in batch: {unique_indices.shape[0]}")
|
||||||
|
|
||||||
|
x1 = used_codebook[:, 0].cpu().numpy()
|
||||||
|
x2 = used_codebook[:, 1].cpu().numpy()
|
||||||
|
x3 = used_codebook[:, 2].cpu().numpy()
|
||||||
|
|
||||||
|
zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
|
||||||
|
zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
|
||||||
|
zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()
|
||||||
|
|
||||||
|
zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
|
||||||
|
zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
|
||||||
|
zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()
|
||||||
|
|
||||||
|
fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers', marker=dict(size=1.4, line=dict(width=1.,
|
||||||
|
color="Blue")),
|
||||||
|
name="All Used Codebook Entries"),
|
||||||
|
|
||||||
|
])
|
||||||
|
trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers', marker=dict(size=1., line=dict(width=1., color=indices)),
|
||||||
|
name="Pre-Quant Codes")
|
||||||
|
trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers',
|
||||||
|
marker=dict(size=2., line=dict(width=10., color=indices)), name="Quantized Codes")
|
||||||
|
|
||||||
|
fig.add_trace(trace2)
|
||||||
|
fig.add_trace(trace3)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
autosize=False,
|
||||||
|
width=1000,
|
||||||
|
height=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_rec_no_quant = model.decode(z_pre_quant)
|
||||||
|
x_rec_quant = model.decode(z_quant)
|
||||||
|
delta_x = x_rec_no_quant - x_rec_quant
|
||||||
|
|
||||||
|
st.text("Fitting Gaussian...")
|
||||||
|
h, w = z_quant.shape[2], z_quant.shape[3]
|
||||||
|
from sklearn.mixture import GaussianMixture
|
||||||
|
gaussian = GaussianMixture(n_components=1)
|
||||||
|
gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy())
|
||||||
|
samples, _ = gaussian.sample(n_samples=batch_size*h*w)
|
||||||
|
samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3)
|
||||||
|
samples = rearrange(samples, 'b h w c -> b c h w')
|
||||||
|
samples = torch.tensor(samples).to(z_quant)
|
||||||
|
samples, _, _ = model.quantize(samples)
|
||||||
|
x_sample = model.decode(samples)
|
||||||
|
|
||||||
|
all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample]) # 5 b 3 H W
|
||||||
|
|
||||||
|
all_img = rearrange(all_img, 'n b c h w -> b n c h w')
|
||||||
|
all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
|
||||||
|
grid = make_grid(all_img, nrow=5)
|
||||||
|
|
||||||
|
st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **")
|
||||||
|
st.image(chw_to_st(grid), clamp=True, output_format="PNG")
|
||||||
|
|
||||||
|
st.write(fig)
|
||||||
|
# 2d projections
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
pairs = [(1, 0), (2, 0), (2, 1)]
|
||||||
|
fig2, ax = plt.subplots(1, 3, figsize=(21, 7))
|
||||||
|
for d in range(3):
|
||||||
|
d1, d2 = pairs[d]
|
||||||
|
#ax[d].scatter(used_codebook[:, d1].cpu().numpy(),
|
||||||
|
# used_codebook[:, d2].cpu().numpy(),
|
||||||
|
# label="All Used Codebook Entries", s=10.0, c=used_indices)
|
||||||
|
ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
|
||||||
|
rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
|
||||||
|
label="Quantized Codes", alpha=0.9, s=8.0, c=indices)
|
||||||
|
ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
|
||||||
|
rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
|
||||||
|
label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices)
|
||||||
|
ax[d].set_title(f"dim {d2} vs dim {d1}")
|
||||||
|
ax[d].legend()
|
||||||
|
st.write(fig2)
|
||||||
|
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
# from scipy.spatial import Voronoi, voronoi_plot_2d
|
||||||
|
# fig3 = plt.figure(figsize=(10,10))
|
||||||
|
# points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy()
|
||||||
|
# vor = Voronoi(points)
|
||||||
|
# # plot
|
||||||
|
# voronoi_plot_2d(vor)
|
||||||
|
# # colorize
|
||||||
|
# for region in vor.regions:
|
||||||
|
# if not -1 in region:
|
||||||
|
# polygon = [vor.vertices[i] for i in region]
|
||||||
|
# plt.fill(*zip(*polygon))
|
||||||
|
# plt.savefig("voronoi_test.png")
|
||||||
|
# st.write(fig3)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_used_indices(model, dset, batch_size=20):
|
||||||
|
dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False)
|
||||||
|
data = list()
|
||||||
|
info = st.empty()
|
||||||
|
for i, batch in enumerate(dloader):
|
||||||
|
x = model.get_input(batch, "image")
|
||||||
|
x = x.to(model.device)
|
||||||
|
zq, _, zi = model.encode(x)
|
||||||
|
indices = zi[2]
|
||||||
|
indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy()
|
||||||
|
data.append(indices)
|
||||||
|
|
||||||
|
unique = np.unique(data)
|
||||||
|
info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}")
|
||||||
|
|
||||||
|
unique = np.unique(data)
|
||||||
|
#np.save(outpath, unique)
|
||||||
|
st.write(f"end of data: found **{unique.size} unique indices.**")
|
||||||
|
print(f"end of data: found {unique.size} unique indices.")
|
||||||
|
return unique
|
||||||
|
|
||||||
|
|
||||||
|
def visualize3d(codebook, used_indices):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
|
||||||
|
codebook = codebook.cpu().numpy()
|
||||||
|
selected_codebook = codebook[used_indices, :]
|
||||||
|
z_dim = codebook.shape[1]
|
||||||
|
assert z_dim == 3
|
||||||
|
|
||||||
|
pairs = [(1,0), (2,0), (2,1)]
|
||||||
|
fig, ax = plt.subplots(1,3, figsize=(15,5))
|
||||||
|
for d in range(3):
|
||||||
|
d1, d2 = pairs[d]
|
||||||
|
ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2])
|
||||||
|
ax[d].set_title(f"dim {d2} vs dim {d1}")
|
||||||
|
st.write(fig)
|
||||||
|
|
||||||
|
# # plot 3D
|
||||||
|
# fig = plt.figure(1)
|
||||||
|
# ax = Axes3D(fig)
|
||||||
|
# ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries")
|
||||||
|
# ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries")
|
||||||
|
# plt.legend()
|
||||||
|
# #st.write(fig)
|
||||||
|
# st.pyplot(fig)
|
||||||
|
|
||||||
|
# plot histogram of vector norms
|
||||||
|
fig = plt.figure(2, figsize=(6,5))
|
||||||
|
norms = np.linalg.norm(selected_codebook, axis=1)
|
||||||
|
plt.hist(norms, bins=100, edgecolor="black", lw=1.1)
|
||||||
|
plt.title("Distribution of norms of used codebook entries")
|
||||||
|
st.write(fig)
|
||||||
|
|
||||||
|
# plot 3D with plotly
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
x = selected_codebook[:, 0]
|
||||||
|
y = selected_codebook[:, 1]
|
||||||
|
z = selected_codebook[:, 2]
|
||||||
|
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2., line=dict(width=1.,
|
||||||
|
color="Blue"))
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
autosize=False,
|
||||||
|
width=1000,
|
||||||
|
height=1000,
|
||||||
|
)
|
||||||
|
st.write(fig)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_fixed_points(model, dset):
|
||||||
|
n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
|
||||||
|
batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
|
||||||
|
start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
|
||||||
|
clip_decoded = st.checkbox("Clip decoded image", False)
|
||||||
|
quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
|
||||||
|
factor = st.sidebar.number_input("image size", value=1., min_value=0.1)
|
||||||
|
if st.sidebar.button("Sample Batch"):
|
||||||
|
indices = np.random.choice(len(dset), batch_size)
|
||||||
|
else:
|
||||||
|
indices = list(range(start_index, start_index + batch_size))
|
||||||
|
|
||||||
|
st.write(f"Indices: {indices}")
|
||||||
|
batch = default_collate([dset[i] for i in indices])
|
||||||
|
x = model.get_input(batch, "image")
|
||||||
|
x = x.to(model.device)
|
||||||
|
|
||||||
|
progress = st.empty()
|
||||||
|
progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
|
||||||
|
|
||||||
|
image_progress = st.empty() # TODO
|
||||||
|
|
||||||
|
input = x
|
||||||
|
img_quant = x
|
||||||
|
img_noquant = x
|
||||||
|
delta_img = img_quant - img_noquant
|
||||||
|
|
||||||
|
st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **")
|
||||||
|
|
||||||
|
def display(input, img_quant, img_noquant, delta_img):
|
||||||
|
all_img = torch.stack([input, img_quant, img_noquant, delta_img]) # 4 b 3 H W
|
||||||
|
all_img = rearrange(all_img, 'n b c h w -> b n c h w')
|
||||||
|
all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
|
||||||
|
grid = make_grid(all_img, nrow=4)
|
||||||
|
image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2]))
|
||||||
|
|
||||||
|
display(input, img_quant, img_noquant, delta_img)
|
||||||
|
for n in range(n_iter):
|
||||||
|
# get reconstruction from non-quantized and quantized, compare via iteration
|
||||||
|
|
||||||
|
# quantized_stream
|
||||||
|
z_pre_quant = model.encode_to_prequant(img_quant)
|
||||||
|
z_quant, emb_loss, info = model.quantize(z_pre_quant)
|
||||||
|
|
||||||
|
# non_quantized stream
|
||||||
|
z_noquant = model.encode_to_prequant(img_noquant)
|
||||||
|
|
||||||
|
img_quant = model.decode(z_quant)
|
||||||
|
img_noquant = model.decode(z_noquant)
|
||||||
|
if clip_decoded:
|
||||||
|
img_quant = torch.clamp(img_quant, -1., 1.)
|
||||||
|
img_noquant = torch.clamp(img_noquant, -1., 1.)
|
||||||
|
if quantize_decoded:
|
||||||
|
device = img_quant.device
|
||||||
|
img_quant = (2*torch.Tensor(((img_quant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device)
|
||||||
|
img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device)
|
||||||
|
delta_img = img_quant - img_noquant
|
||||||
|
display(input, img_quant, img_noquant, delta_img)
|
||||||
|
progress_cb(n + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_fixed_points_kl_ae(model, dset):
|
||||||
|
n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
|
||||||
|
batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
|
||||||
|
start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
|
||||||
|
clip_decoded = st.checkbox("Clip decoded image", False)
|
||||||
|
quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
|
||||||
|
sample_posterior = st.checkbox("Sample from encoder posterior", False)
|
||||||
|
factor = st.sidebar.number_input("image size", value=1., min_value=0.1)
|
||||||
|
if st.sidebar.button("Sample Batch"):
|
||||||
|
indices = np.random.choice(len(dset), batch_size)
|
||||||
|
else:
|
||||||
|
indices = list(range(start_index, start_index + batch_size))
|
||||||
|
|
||||||
|
st.write(f"Indices: {indices}")
|
||||||
|
batch = default_collate([dset[i] for i in indices])
|
||||||
|
x = model.get_input(batch, "image")
|
||||||
|
x = x.to(model.device)
|
||||||
|
|
||||||
|
progress = st.empty()
|
||||||
|
progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
|
||||||
|
|
||||||
|
st.write("** Input | Rec. (no quant) | Delta(input, iter_rec) **")
|
||||||
|
image_progress = st.empty()
|
||||||
|
|
||||||
|
input = x
|
||||||
|
img_noquant = x
|
||||||
|
delta_img = input - img_noquant
|
||||||
|
|
||||||
|
def display(input, img_noquant, delta_img):
|
||||||
|
all_img = torch.stack([input, img_noquant, delta_img]) # 3 b 3 H W
|
||||||
|
all_img = rearrange(all_img, 'n b c h w -> b n c h w')
|
||||||
|
all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
|
||||||
|
grid = make_grid(all_img, nrow=3)
|
||||||
|
image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2]))
|
||||||
|
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
distribution_progress = st.empty()
|
||||||
|
def display_latent_distribution(latent_z, alpha=1., title=""):
|
||||||
|
flatz = latent_z.reshape(-1).cpu().detach().numpy()
|
||||||
|
#fig, ax = plt.subplots()
|
||||||
|
ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black")
|
||||||
|
ax.set_title(title)
|
||||||
|
distribution_progress.pyplot(fig)
|
||||||
|
|
||||||
|
display(input, img_noquant, delta_img)
|
||||||
|
for n in range(n_iter):
|
||||||
|
# get reconstructions
|
||||||
|
|
||||||
|
posterior = model.encode(img_noquant)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample()
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
|
||||||
|
if n==0:
|
||||||
|
flatz_init = z.reshape(-1).cpu().detach().numpy()
|
||||||
|
std_init = flatz_init.std()
|
||||||
|
max_init, min_init = flatz_init.max(), flatz_init.min()
|
||||||
|
|
||||||
|
display_latent_distribution(z, alpha=np.sqrt(1/(n+1)),
|
||||||
|
title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}")
|
||||||
|
|
||||||
|
img_noquant = model.decode(z)
|
||||||
|
if clip_decoded:
|
||||||
|
img_noquant = torch.clamp(img_noquant, -1., 1.)
|
||||||
|
if quantize_decoded:
|
||||||
|
img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(model.device)
|
||||||
|
delta_img = img_noquant - input
|
||||||
|
display(input, img_noquant, delta_img)
|
||||||
|
progress_cb(n + 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from ldm.models.autoencoder import AutoencoderKL
|
||||||
|
# VISUALIZE USED AND ALL INDICES of VQ-Model. VISUALIZE FIXED POINTS OF KL MODEL
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
|
parser = get_parser()
|
||||||
|
opt, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
|
ckpt = None
|
||||||
|
if opt.resume:
|
||||||
|
if not os.path.exists(opt.resume):
|
||||||
|
raise ValueError("Cannot find {}".format(opt.resume))
|
||||||
|
if os.path.isfile(opt.resume):
|
||||||
|
paths = opt.resume.split("/")
|
||||||
|
try:
|
||||||
|
idx = len(paths)-paths[::-1].index("logs")+1
|
||||||
|
except ValueError:
|
||||||
|
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
||||||
|
logdir = "/".join(paths[:idx])
|
||||||
|
ckpt = opt.resume
|
||||||
|
else:
|
||||||
|
assert os.path.isdir(opt.resume), opt.resume
|
||||||
|
logdir = opt.resume.rstrip("/")
|
||||||
|
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
||||||
|
|
||||||
|
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
||||||
|
opt.base = base_configs+opt.base
|
||||||
|
|
||||||
|
if opt.config:
|
||||||
|
if type(opt.config) == str:
|
||||||
|
opt.base = [opt.config]
|
||||||
|
else:
|
||||||
|
opt.base = [opt.base[-1]]
|
||||||
|
|
||||||
|
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
||||||
|
cli = OmegaConf.from_dotlist(unknown)
|
||||||
|
config = OmegaConf.merge(*configs, cli)
|
||||||
|
|
||||||
|
if opt.dataset_config:
|
||||||
|
dcfg = OmegaConf.load(opt.dataset_config)
|
||||||
|
print("Replacing data config with:")
|
||||||
|
print(dcfg.pretty())
|
||||||
|
dcfg = OmegaConf.to_container(dcfg)
|
||||||
|
config["data"] = dcfg["data"]
|
||||||
|
|
||||||
|
st.sidebar.text(ckpt)
|
||||||
|
gs = st.sidebar.empty()
|
||||||
|
gs.text(f"Global step: ?")
|
||||||
|
st.sidebar.text("Options")
|
||||||
|
gpu = st.sidebar.checkbox("GPU", value=True)
|
||||||
|
eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
|
||||||
|
show_config = st.sidebar.checkbox("Show Config", value=False)
|
||||||
|
if show_config:
|
||||||
|
st.info("Checkpoint: {}".format(ckpt))
|
||||||
|
st.json(OmegaConf.to_container(config))
|
||||||
|
|
||||||
|
delelete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.")
|
||||||
|
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode,
|
||||||
|
delete_dataset_params=delelete_dataset_parameters)
|
||||||
|
gs.text(f"Global step: {global_step}")
|
||||||
|
|
||||||
|
split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1])
|
||||||
|
dset = dsets.datasets[split]
|
||||||
|
|
||||||
|
batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20)
|
||||||
|
num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5)
|
||||||
|
data_size = batch_size*num_batches
|
||||||
|
dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False))
|
||||||
|
|
||||||
|
if not isinstance(model, AutoencoderKL):
|
||||||
|
# VQ MODEL
|
||||||
|
codebook = model.quantize.embedding.weight.data
|
||||||
|
st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**")
|
||||||
|
st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**")
|
||||||
|
st.write("**Select ONE of the following options**")
|
||||||
|
|
||||||
|
if st.checkbox("Show Codebook Statistics", False):
|
||||||
|
used_indices = get_used_indices(model, dset, batch_size=batch_size)
|
||||||
|
visualize3d(codebook, used_indices)
|
||||||
|
if st.checkbox("Show Batch Encodings", False):
|
||||||
|
used_indices = get_used_indices(model, dset, batch_size=batch_size)
|
||||||
|
get_image_embeddings(model, dset, codebook[used_indices, :], used_indices)
|
||||||
|
if st.checkbox("Show Fixed Points of Data", False):
|
||||||
|
get_fixed_points(model, dset)
|
||||||
|
|
||||||
|
else:
|
||||||
|
st.info("Detected a KL model")
|
||||||
|
get_fixed_points_kl_ae(model, dset)
|
Loading…
Reference in a new issue