# stable-diffusion-finetune/scripts/vqgan_codebook_visualizer.py
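"""Streamlit tool for inspecting VQGAN codebooks and autoencoder fixed points.

Loads a VQ or KL autoencoder from a training logdir, then (for VQ models)
visualizes which codebook entries the data actually uses and compares
quantized vs. non-quantized reconstructions, or (for KL models) iterates
encode/decode to probe fixed-point behaviour. The scatter plots assume a
3-dimensional latent space (asserted below).
"""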

import argparse, os, sys, glob
import torch
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from einops import rearrange
from torchvision.utils import make_grid
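
# Convention used throughout this repo: images live in [-1, 1] in model
# space, while Streamlit expects [0, 1] float arrays, hence the helpers below.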
rescale = lambda x: (x + 1.) / 2.


def bchw_to_st(x):
    # (b, c, h, w) in [-1, 1] -> (b, h, w, c) numpy array in [0, 1]
    return rescale(x.detach().cpu().numpy().transpose(0, 2, 3, 1))


def chw_to_st(x):
    # (c, h, w) in [-1, 1] -> (h, w, c) numpy array in [0, 1]
    return rescale(x.detach().cpu().numpy().transpose(1, 2, 0))


def custom_get_input(batch, key):
    inputs = batch[key].permute(0, 3, 1, 2)
    return inputs


def vq_no_codebook_forward(model, x):
    # encode and decode while skipping the vector-quantization step
    h = model.encoder(x)
    h = model.quant_conv(h)
    h = model.post_quant_conv(h)
    xrec = model.decoder(h)
    return xrec


def save_img(x, fname):
    I = (x.clip(0, 1) * 255).astype(np.uint8)
    Image.fromarray(I).save(fname)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to single config. If specified, base configs will be ignored "
        "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--dataset_config",
        type=str,
        nargs="?",
        default="",
        help="path to dataset config",
    )
    return parser


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd)
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def get_data(config):
    # get data
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    return data
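
# Cached so that Streamlit re-runs (triggered by every widget interaction)
# do not reload the checkpoint and dataset from scratch; allow_output_mutation
# skips hashing of the returned mutable model object (legacy st.cache API).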
@st.cache(allow_output_mutation=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False):
    # get data
    if delete_dataset_params:
        st.info("Deleting dataset parameters.")
        del config["data"]["params"]["train"]["params"]
        del config["data"]["params"]["validation"]["params"]
    dsets = get_data(config)  # calls data.config ...
    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return dsets, model, global_step
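
# Plots three point clouds in the (assumed 3-d) latent space: all codebook
# entries used by the dataset, the encoder outputs before quantization, and
# the codebook entries they snap to after quantization. The quantizer's
# nearest-neighbour lookup is, as a rough sketch (not the model's actual code):
#   dist = ((z.unsqueeze(1) - codebook.unsqueeze(0)) ** 2).sum(-1)  # (n, num_entries)
#   z_q = codebook[dist.argmin(dim=1)]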
@torch.no_grad()
def get_image_embeddings(model, dset, used_codebook, used_indices):
    import plotly.graph_objects as go
    batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4)
    start_index = st.number_input("Start index", value=0,
                                  min_value=0,
                                  max_value=len(dset) - batch_size)
    if st.sidebar.button("Sample Batch"):
        indices = np.random.choice(len(dset), batch_size)
    else:
        indices = list(range(start_index, start_index + batch_size))
    st.write(f"Indices: {indices}")
    batch = default_collate([dset[i] for i in indices])
    x = model.get_input(batch, "image")
    x = x.to(model.device)
    # get reconstruction from non-quantized and quantized, compare
    z_pre_quant = model.encode_to_prequant(x)
    z_quant, emb_loss, info = model.quantize(z_pre_quant)
    # note: from here on, `indices` holds codebook indices, not dataset indices
    indices = info[2].detach().cpu().numpy()
    #indices = rearrange(indices, '(b d) -> b d', b=batch_size)
    unique_indices = np.unique(indices)
    st.write(f"Unique indices in batch: {unique_indices.shape[0]}")
    x1 = used_codebook[:, 0].cpu().numpy()
    x2 = used_codebook[:, 1].cpu().numpy()
    x3 = used_codebook[:, 2].cpu().numpy()
    zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
    zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
    zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()
    zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
    zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
    zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()
    fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers',
                                       marker=dict(size=1.4, line=dict(width=1., color="Blue")),
                                       name="All Used Codebook Entries"),
                          ])
    trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers',
                          marker=dict(size=1., line=dict(width=1., color=indices)),
                          name="Pre-Quant Codes")
    trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers',
                          marker=dict(size=2., line=dict(width=10., color=indices)),
                          name="Quantized Codes")
    fig.add_trace(trace2)
    fig.add_trace(trace3)
    fig.update_layout(
        autosize=False,
        width=1000,
        height=1000,
    )
    x_rec_no_quant = model.decode(z_pre_quant)
    x_rec_quant = model.decode(z_quant)
    delta_x = x_rec_no_quant - x_rec_quant
    st.text("Fitting Gaussian...")
    h, w = z_quant.shape[2], z_quant.shape[3]
    from sklearn.mixture import GaussianMixture
    gaussian = GaussianMixture(n_components=1)
    gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy())
    samples, _ = gaussian.sample(n_samples=batch_size * h * w)
    samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3)  # assumes z_dim == 3
    samples = rearrange(samples, 'b h w c -> b c h w')
    samples = torch.tensor(samples).to(z_quant)
    samples, _, _ = model.quantize(samples)
    x_sample = model.decode(samples)
    all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample])  # 5 b 3 H W
    all_img = rearrange(all_img, 'n b c h w -> b n c h w')
    all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
    grid = make_grid(all_img, nrow=5)
    st.write("**Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) | Sample (fitted Gaussian)**")
    st.image(chw_to_st(grid), clamp=True, output_format="PNG")
    st.write(fig)
    # 2d projections
    import matplotlib.pyplot as plt
    pairs = [(1, 0), (2, 0), (2, 1)]
    fig2, ax = plt.subplots(1, 3, figsize=(21, 7))
    for d in range(3):
        d1, d2 = pairs[d]
        #ax[d].scatter(used_codebook[:, d1].cpu().numpy(),
        #              used_codebook[:, d2].cpu().numpy(),
        #              label="All Used Codebook Entries", s=10.0, c=used_indices)
        ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
                      rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
                      label="Quantized Codes", alpha=0.9, s=8.0, c=indices)
        ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
                      rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
                      label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices)
        ax[d].set_title(f"dim {d2} vs dim {d1}")
        ax[d].legend()
    st.write(fig2)
    plt.close()
    # from scipy.spatial import Voronoi, voronoi_plot_2d
    # fig3 = plt.figure(figsize=(10,10))
    # points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy()
    # vor = Voronoi(points)
    # # plot
    # voronoi_plot_2d(vor)
    # # colorize
    # for region in vor.regions:
    #     if not -1 in region:
    #         polygon = [vor.vertices[i] for i in region]
    #         plt.fill(*zip(*polygon))
    # plt.savefig("voronoi_test.png")
    # st.write(fig3)
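
# Codebook-usage scan: encode the (sub)dataset batch by batch and collect
# every index the quantizer emits. A large gap between the number of
# codebook entries and the number of used indices suggests that much of
# the codebook goes unused.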
@torch.no_grad()
def get_used_indices(model, dset, batch_size=20):
    dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False)
    data = list()
    info = st.empty()
    for i, batch in enumerate(dloader):
        x = model.get_input(batch, "image")
        x = x.to(model.device)
        zq, _, zi = model.encode(x)
        indices = zi[2]
        indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy()
        data.append(indices)
        # concatenate before np.unique: with drop_last=False the final batch
        # may be smaller, making the list of arrays ragged
        unique = np.unique(np.concatenate(data))
        info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}")
    unique = np.unique(np.concatenate(data))
    #np.save(outpath, unique)
    st.write(f"end of data: found **{unique.size} unique indices.**")
    print(f"end of data: found {unique.size} unique indices.")
    return unique
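
# Plots the used codebook entries as pairwise 2-d projections, a histogram
# of their vector norms, and an interactive 3-d plotly scatter. Requires a
# 3-dimensional codebook.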
def visualize3d(codebook, used_indices):
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    codebook = codebook.cpu().numpy()
    selected_codebook = codebook[used_indices, :]
    z_dim = codebook.shape[1]
    assert z_dim == 3, "this visualization assumes a 3-d codebook"
    pairs = [(1, 0), (2, 0), (2, 1)]
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    for d in range(3):
        d1, d2 = pairs[d]
        ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2])
        ax[d].set_title(f"dim {d2} vs dim {d1}")
    st.write(fig)
    # # plot 3D
    # fig = plt.figure(1)
    # ax = Axes3D(fig)
    # ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries")
    # ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries")
    # plt.legend()
    # #st.write(fig)
    # st.pyplot(fig)
    # plot histogram of vector norms
    fig = plt.figure(2, figsize=(6, 5))
    norms = np.linalg.norm(selected_codebook, axis=1)
    plt.hist(norms, bins=100, edgecolor="black", lw=1.1)
    plt.title("Distribution of norms of used codebook entries")
    st.write(fig)
    # plot 3D with plotly
    import plotly.graph_objects as go
    x = selected_codebook[:, 0]
    y = selected_codebook[:, 1]
    z = selected_codebook[:, 2]
    fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers',
                                       marker=dict(size=2., line=dict(width=1., color="Blue")))])
    fig.update_layout(
        autosize=False,
        width=1000,
        height=1000,
    )
    st.write(fig)
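
# Fixed-point analysis for VQ models: repeatedly feed the model its own
# reconstruction, once through the quantized path and once bypassing the
# quantizer. If encode/decode has a fixed point, the images stop changing;
# the delta column shows how the two streams diverge.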
@torch.no_grad()
def get_fixed_points(model, dset):
    n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
    batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
    start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
    clip_decoded = st.checkbox("Clip decoded image", False)
    quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
    factor = st.sidebar.number_input("image size", value=1., min_value=0.1)
    if st.sidebar.button("Sample Batch"):
        indices = np.random.choice(len(dset), batch_size)
    else:
        indices = list(range(start_index, start_index + batch_size))
    st.write(f"Indices: {indices}")
    batch = default_collate([dset[i] for i in indices])
    x = model.get_input(batch, "image")
    x = x.to(model.device)
    progress = st.empty()
    progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
    image_progress = st.empty()  # TODO
    input = x
    img_quant = x
    img_noquant = x
    delta_img = img_quant - img_noquant
    st.write("**Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant)**")

    def display(input, img_quant, img_noquant, delta_img):
        all_img = torch.stack([input, img_quant, img_noquant, delta_img])  # 4 b 3 H W
        all_img = rearrange(all_img, 'n b c h w -> b n c h w')
        all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
        grid = make_grid(all_img, nrow=4)
        image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor * grid.shape[2]))

    display(input, img_quant, img_noquant, delta_img)
    for n in range(n_iter):
        # get reconstruction from non-quantized and quantized, compare via iteration
        # quantized stream
        z_pre_quant = model.encode_to_prequant(img_quant)
        z_quant, emb_loss, info = model.quantize(z_pre_quant)
        # non-quantized stream
        z_noquant = model.encode_to_prequant(img_noquant)
        img_quant = model.decode(z_quant)
        img_noquant = model.decode(z_noquant)
        if clip_decoded:
            img_quant = torch.clamp(img_quant, -1., 1.)
            img_noquant = torch.clamp(img_noquant, -1., 1.)
        if quantize_decoded:
            # round-trip through uint8 to mimic saving the image to disk
            device = img_quant.device
            img_quant = (2 * torch.Tensor(((img_quant.cpu().numpy() + 1.) * 127.5).astype(np.uint8)) / 255. - 1.).to(device)
            img_noquant = (2 * torch.Tensor(((img_noquant.cpu().numpy() + 1.) * 127.5).astype(np.uint8)) / 255. - 1.).to(device)
        delta_img = img_quant - img_noquant
        display(input, img_quant, img_noquant, delta_img)
        progress_cb(n + 1)
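
# Same fixed-point iteration for a KL autoencoder, optionally sampling from
# the posterior instead of taking its mode, with a running histogram of the
# latent values across iterations.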
@torch.no_grad()
def get_fixed_points_kl_ae(model, dset):
    n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
    batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
    start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
    clip_decoded = st.checkbox("Clip decoded image", False)
    quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
    sample_posterior = st.checkbox("Sample from encoder posterior", False)
    factor = st.sidebar.number_input("image size", value=1., min_value=0.1)
    if st.sidebar.button("Sample Batch"):
        indices = np.random.choice(len(dset), batch_size)
    else:
        indices = list(range(start_index, start_index + batch_size))
    st.write(f"Indices: {indices}")
    batch = default_collate([dset[i] for i in indices])
    x = model.get_input(batch, "image")
    x = x.to(model.device)
    progress = st.empty()
    progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
    st.write("**Input | Rec. (no quant) | Delta(input, iter_rec)**")
    image_progress = st.empty()
    input = x
    img_noquant = x
    delta_img = input - img_noquant

    def display(input, img_noquant, delta_img):
        all_img = torch.stack([input, img_noquant, delta_img])  # 3 b 3 H W
        all_img = rearrange(all_img, 'n b c h w -> b n c h w')
        all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
        grid = make_grid(all_img, nrow=3)
        image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor * grid.shape[2]))

    fig, ax = plt.subplots()
    distribution_progress = st.empty()

    def display_latent_distribution(latent_z, alpha=1., title=""):
        flatz = latent_z.reshape(-1).cpu().detach().numpy()
        #fig, ax = plt.subplots()
        ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black")
        ax.set_title(title)
        distribution_progress.pyplot(fig)

    display(input, img_noquant, delta_img)
    for n in range(n_iter):
        # get reconstructions
        posterior = model.encode(img_noquant)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        if n == 0:
            flatz_init = z.reshape(-1).cpu().detach().numpy()
            std_init = flatz_init.std()
            max_init, min_init = flatz_init.max(), flatz_init.min()
        display_latent_distribution(z, alpha=np.sqrt(1 / (n + 1)),
                                    title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}")
        img_noquant = model.decode(z)
        if clip_decoded:
            img_noquant = torch.clamp(img_noquant, -1., 1.)
        if quantize_decoded:
            img_noquant = (2 * torch.Tensor(((img_noquant.cpu().numpy() + 1.) * 127.5).astype(np.uint8)) / 255. - 1.).to(model.device)
        delta_img = img_noquant - input
        display(input, img_noquant, delta_img)
        progress_cb(n + 1)
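
# A typical invocation (hypothetical run directory; everything after `--`
# is forwarded to this script's argument parser):
#   streamlit run scripts/vqgan_codebook_visualizer.py -- -r logs/2022-06-05T17-22-56_vqgan/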
if __name__ == "__main__":
    from ldm.models.autoencoder import AutoencoderKL

    # visualize used and all indices of a VQ model; visualize fixed points of a KL model
    sys.path.append(os.getcwd())
    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    if opt.dataset_config:
        dcfg = OmegaConf.load(opt.dataset_config)
        print("Replacing data config with:")
        # cfg.pretty() was removed in OmegaConf 2.x; use OmegaConf.to_yaml
        print(OmegaConf.to_yaml(dcfg))
        dcfg = OmegaConf.to_container(dcfg)
        config["data"] = dcfg["data"]

    st.sidebar.text(ckpt)
    gs = st.sidebar.empty()
    gs.text("Global step: ?")
    st.sidebar.text("Options")
    gpu = st.sidebar.checkbox("GPU", value=True)
    eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
    show_config = st.sidebar.checkbox("Show Config", value=False)
    if show_config:
        st.info("Checkpoint: {}".format(ckpt))
        st.json(OmegaConf.to_container(config))

    delete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.")
    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode,
                                                    delete_dataset_params=delete_dataset_parameters)
    gs.text(f"Global step: {global_step}")

    split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1])
    dset = dsets.datasets[split]
    batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20)
    num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5)
    data_size = batch_size * num_batches
    # subsample the split (assumes data_size <= len(dset) since replace=False)
    dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False))

    if not isinstance(model, AutoencoderKL):
        # VQ model
        codebook = model.quantize.embedding.weight.data
        st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**")
        st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**")
        st.write("**Select ONE of the following options**")
        if st.checkbox("Show Codebook Statistics", False):
            used_indices = get_used_indices(model, dset, batch_size=batch_size)
            visualize3d(codebook, used_indices)
        if st.checkbox("Show Batch Encodings", False):
            used_indices = get_used_indices(model, dset, batch_size=batch_size)
            get_image_embeddings(model, dset, codebook[used_indices, :], used_indices)
        if st.checkbox("Show Fixed Points of Data", False):
            get_fixed_points(model, dset)
    else:
        st.info("Detected a KL model")
        get_fixed_points_kl_ae(model, dset)