diff --git a/scripts/vqgan_codebook_visualizer.py b/scripts/vqgan_codebook_visualizer.py new file mode 100644 index 0000000..24e4986 --- /dev/null +++ b/scripts/vqgan_codebook_visualizer.py @@ -0,0 +1,542 @@ +import argparse, os, sys, glob +import torch +import numpy as np +import streamlit as st +import matplotlib.pyplot as plt +from omegaconf import OmegaConf +from PIL import Image +from main import instantiate_from_config, DataModuleFromConfig +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate +from einops import rearrange +from torchvision.utils import make_grid + + +rescale = lambda x: (x + 1.) / 2. + + +def bchw_to_st(x): + return rescale(x.detach().cpu().numpy().transpose(0,2,3,1)) + +def chw_to_st(x): + return rescale(x.detach().cpu().numpy().transpose(1,2,0)) + + +def custom_get_input(batch, key): + inputs = batch[key].permute(0, 3, 1, 2) + return inputs + + +def vq_no_codebook_forward(model, x): + h = model.encoder(x) + h = model.quant_conv(h) + h = model.post_quant_conv(h) + xrec = model.decoder(h) + return xrec + + +def save_img(x, fname): + I = (x.clip(0, 1) * 255).astype(np.uint8) + Image.fromarray(I).save(fname) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-c", + "--config", + nargs="?", + metavar="single_config.yaml", + help="path to single config. If specified, base configs will be ignored " + "(except for the last one if left unspecified).", + const=True, + default="", + ) + parser.add_argument( + "--dataset_config", + type=str, + nargs="?", + default="", + help="path to dataset config" + ) + return parser + + +def load_model_from_config(config, sd, gpu=True, eval_mode=True): + model = instantiate_from_config(config) + if sd is not None: + model.load_state_dict(sd) + if gpu: + model.cuda() + if eval_mode: + model.eval() + return {"model": model} + + +def get_data(config): + # get data + data = instantiate_from_config(config.data) + data.prepare_data() + data.setup() + return data + + +@st.cache(allow_output_mutation=True) +def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False): + # get data + if delete_dataset_params: + st.info("Deleting dataset parameters.") + del config["data"]["params"]["train"]["params"] + del config["data"]["params"]["validation"]["params"] + + dsets = get_data(config) # calls data.config ... + + # now load the specified checkpoint + if ckpt: + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"], + gpu=gpu, + eval_mode=eval_mode)["model"] + return dsets, model, global_step + + +@torch.no_grad() +def get_image_embeddings(model, dset, used_codebook, used_indices): + import plotly.graph_objects as go + batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4) + start_index = st.number_input("Start index", value=0, + min_value=0, + max_value=len(dset) - batch_size) + if st.sidebar.button("Sample Batch"): + indices = np.random.choice(len(dset), batch_size) + else: + indices = list(range(start_index, start_index + batch_size)) + + st.write(f"Indices: {indices}") + batch = default_collate([dset[i] for i in indices]) + x = model.get_input(batch, "image") + x = x.to(model.device) + + # get reconstruction from non-quantized and quantized, compare + z_pre_quant = model.encode_to_prequant(x) + z_quant, emb_loss, info = model.quantize(z_pre_quant) + indices = info[2].detach().cpu().numpy() + #indices = rearrange(indices, '(b d) -> b d', b=batch_size) + unique_indices = np.unique(indices) + st.write(f"Unique indices in batch: {unique_indices.shape[0]}") + + x1 = used_codebook[:, 0].cpu().numpy() + x2 = used_codebook[:, 1].cpu().numpy() + x3 = used_codebook[:, 2].cpu().numpy() + + zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy() + zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy() + zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy() + + zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy() + zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy() + zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy() + + fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers', marker=dict(size=1.4, line=dict(width=1., + color="Blue")), + name="All Used Codebook Entries"), + + ]) + trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers', marker=dict(size=1., line=dict(width=1., color=indices)), + name="Pre-Quant Codes") + trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers', + marker=dict(size=2., line=dict(width=10., color=indices)), name="Quantized Codes") + + fig.add_trace(trace2) + fig.add_trace(trace3) + + fig.update_layout( + autosize=False, + width=1000, + height=1000, + ) + + x_rec_no_quant = model.decode(z_pre_quant) + x_rec_quant = model.decode(z_quant) + delta_x = x_rec_no_quant - x_rec_quant + + st.text("Fitting Gaussian...") + h, w = z_quant.shape[2], z_quant.shape[3] + from sklearn.mixture import GaussianMixture + gaussian = GaussianMixture(n_components=1) + gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy()) + samples, _ = gaussian.sample(n_samples=batch_size*h*w) + samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3) + samples = rearrange(samples, 'b h w c -> b c h w') + samples = torch.tensor(samples).to(z_quant) + samples, _, _ = model.quantize(samples) + x_sample = model.decode(samples) + + all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample]) # 5 b 3 H W + + all_img = rearrange(all_img, 'n b c h w -> b n c h w') + all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') + grid = make_grid(all_img, nrow=5) + + st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **") + st.image(chw_to_st(grid), clamp=True, output_format="PNG") + + st.write(fig) + # 2d projections + import matplotlib.pyplot as plt + pairs = [(1, 0), (2, 0), (2, 1)] + fig2, ax = plt.subplots(1, 3, figsize=(21, 7)) + for d in range(3): + d1, d2 = pairs[d] + #ax[d].scatter(used_codebook[:, d1].cpu().numpy(), + # used_codebook[:, d2].cpu().numpy(), + # label="All Used Codebook Entries", s=10.0, c=used_indices) + ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(), + rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(), + label="Quantized Codes", alpha=0.9, s=8.0, c=indices) + ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(), + rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(), + label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices) + ax[d].set_title(f"dim {d2} vs dim {d1}") + ax[d].legend() + st.write(fig2) + + plt.close() + + # from scipy.spatial import Voronoi, voronoi_plot_2d + # fig3 = plt.figure(figsize=(10,10)) + # points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy() + # vor = Voronoi(points) + # # plot + # voronoi_plot_2d(vor) + # # colorize + # for region in vor.regions: + # if not -1 in region: + # polygon = [vor.vertices[i] for i in region] + # plt.fill(*zip(*polygon)) + # plt.savefig("voronoi_test.png") + # st.write(fig3) + + +@torch.no_grad() +def get_used_indices(model, dset, batch_size=20): + dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False) + data = list() + info = st.empty() + for i, batch in enumerate(dloader): + x = model.get_input(batch, "image") + x = x.to(model.device) + zq, _, zi = model.encode(x) + indices = zi[2] + indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy() + data.append(indices) + + unique = np.unique(data) + info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}") + + unique = np.unique(data) + #np.save(outpath, unique) + st.write(f"end of data: found **{unique.size} unique indices.**") + print(f"end of data: found {unique.size} unique indices.") + return unique + + +def visualize3d(codebook, used_indices): + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + codebook = codebook.cpu().numpy() + selected_codebook = codebook[used_indices, :] + z_dim = codebook.shape[1] + assert z_dim == 3 + + pairs = [(1,0), (2,0), (2,1)] + fig, ax = plt.subplots(1,3, figsize=(15,5)) + for d in range(3): + d1, d2 = pairs[d] + ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2]) + ax[d].set_title(f"dim {d2} vs dim {d1}") + st.write(fig) + + # # plot 3D + # fig = plt.figure(1) + # ax = Axes3D(fig) + # ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries") + # ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries") + # plt.legend() + # #st.write(fig) + # st.pyplot(fig) + + # plot histogram of vector norms + fig = plt.figure(2, figsize=(6,5)) + norms = np.linalg.norm(selected_codebook, axis=1) + plt.hist(norms, bins=100, edgecolor="black", lw=1.1) + plt.title("Distribution of norms of used codebook entries") + st.write(fig) + + # plot 3D with plotly + import pandas as pd + import plotly.graph_objects as go + x = selected_codebook[:, 0] + y = selected_codebook[:, 1] + z = selected_codebook[:, 2] + fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2., line=dict(width=1., + color="Blue")) + ) + ] + ) + + fig.update_layout( + autosize=False, + width=1000, + height=1000, + ) + st.write(fig) + + +@torch.no_grad() +def get_fixed_points(model, dset): + n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25) + batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4) + start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) + clip_decoded = st.checkbox("Clip decoded image", False) + quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False) + factor = st.sidebar.number_input("image size", value=1., min_value=0.1) + if st.sidebar.button("Sample Batch"): + indices = np.random.choice(len(dset), batch_size) + else: + indices = list(range(start_index, start_index + batch_size)) + + st.write(f"Indices: {indices}") + batch = default_collate([dset[i] for i in indices]) + x = model.get_input(batch, "image") + x = x.to(model.device) + + progress = st.empty() + progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}") + + image_progress = st.empty() # TODO + + input = x + img_quant = x + img_noquant = x + delta_img = img_quant - img_noquant + + st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **") + + def display(input, img_quant, img_noquant, delta_img): + all_img = torch.stack([input, img_quant, img_noquant, delta_img]) # 4 b 3 H W + all_img = rearrange(all_img, 'n b c h w -> b n c h w') + all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') + grid = make_grid(all_img, nrow=4) + image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2])) + + display(input, img_quant, img_noquant, delta_img) + for n in range(n_iter): + # get reconstruction from non-quantized and quantized, compare via iteration + + # quantized_stream + z_pre_quant = model.encode_to_prequant(img_quant) + z_quant, emb_loss, info = model.quantize(z_pre_quant) + + # non_quantized stream + z_noquant = model.encode_to_prequant(img_noquant) + + img_quant = model.decode(z_quant) + img_noquant = model.decode(z_noquant) + if clip_decoded: + img_quant = torch.clamp(img_quant, -1., 1.) + img_noquant = torch.clamp(img_noquant, -1., 1.) + if quantize_decoded: + device = img_quant.device + img_quant = (2*torch.Tensor(((img_quant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device) + img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device) + delta_img = img_quant - img_noquant + display(input, img_quant, img_noquant, delta_img) + progress_cb(n + 1) + + +@torch.no_grad() +def get_fixed_points_kl_ae(model, dset): + n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25) + batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4) + start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) + clip_decoded = st.checkbox("Clip decoded image", False) + quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False) + sample_posterior = st.checkbox("Sample from encoder posterior", False) + factor = st.sidebar.number_input("image size", value=1., min_value=0.1) + if st.sidebar.button("Sample Batch"): + indices = np.random.choice(len(dset), batch_size) + else: + indices = list(range(start_index, start_index + batch_size)) + + st.write(f"Indices: {indices}") + batch = default_collate([dset[i] for i in indices]) + x = model.get_input(batch, "image") + x = x.to(model.device) + + progress = st.empty() + progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}") + + st.write("** Input | Rec. (no quant) | Delta(input, iter_rec) **") + image_progress = st.empty() + + input = x + img_noquant = x + delta_img = input - img_noquant + + def display(input, img_noquant, delta_img): + all_img = torch.stack([input, img_noquant, delta_img]) # 3 b 3 H W + all_img = rearrange(all_img, 'n b c h w -> b n c h w') + all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') + grid = make_grid(all_img, nrow=3) + image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2])) + + fig, ax = plt.subplots() + + distribution_progress = st.empty() + def display_latent_distribution(latent_z, alpha=1., title=""): + flatz = latent_z.reshape(-1).cpu().detach().numpy() + #fig, ax = plt.subplots() + ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black") + ax.set_title(title) + distribution_progress.pyplot(fig) + + display(input, img_noquant, delta_img) + for n in range(n_iter): + # get reconstructions + + posterior = model.encode(img_noquant) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + if n==0: + flatz_init = z.reshape(-1).cpu().detach().numpy() + std_init = flatz_init.std() + max_init, min_init = flatz_init.max(), flatz_init.min() + + display_latent_distribution(z, alpha=np.sqrt(1/(n+1)), + title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}") + + img_noquant = model.decode(z) + if clip_decoded: + img_noquant = torch.clamp(img_noquant, -1., 1.) + if quantize_decoded: + img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(model.device) + delta_img = img_noquant - input + display(input, img_noquant, delta_img) + progress_cb(n + 1) + + + +if __name__ == "__main__": + from ldm.models.autoencoder import AutoencoderKL + # VISUALIZE USED AND ALL INDICES of VQ-Model. VISUALIZE FIXED POINTS OF KL MODEL + sys.path.append(os.getcwd()) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + + ckpt = None + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + try: + idx = len(paths)-paths[::-1].index("logs")+1 + except ValueError: + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml"))) + opt.base = base_configs+opt.base + + if opt.config: + if type(opt.config) == str: + opt.base = [opt.config] + else: + opt.base = [opt.base[-1]] + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + if opt.dataset_config: + dcfg = OmegaConf.load(opt.dataset_config) + print("Replacing data config with:") + print(dcfg.pretty()) + dcfg = OmegaConf.to_container(dcfg) + config["data"] = dcfg["data"] + + st.sidebar.text(ckpt) + gs = st.sidebar.empty() + gs.text(f"Global step: ?") + st.sidebar.text("Options") + gpu = st.sidebar.checkbox("GPU", value=True) + eval_mode = st.sidebar.checkbox("Eval Mode", value=True) + show_config = st.sidebar.checkbox("Show Config", value=False) + if show_config: + st.info("Checkpoint: {}".format(ckpt)) + st.json(OmegaConf.to_container(config)) + + delelete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.") + dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode, + delete_dataset_params=delelete_dataset_parameters) + gs.text(f"Global step: {global_step}") + + split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1]) + dset = dsets.datasets[split] + + batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20) + num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5) + data_size = batch_size*num_batches + dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False)) + + if not isinstance(model, AutoencoderKL): + # VQ MODEL + codebook = model.quantize.embedding.weight.data + st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**") + st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**") + st.write("**Select ONE of the following options**") + + if st.checkbox("Show Codebook Statistics", False): + used_indices = get_used_indices(model, dset, batch_size=batch_size) + visualize3d(codebook, used_indices) + if st.checkbox("Show Batch Encodings", False): + used_indices = get_used_indices(model, dset, batch_size=batch_size) + get_image_embeddings(model, dset, codebook[used_indices, :], used_indices) + if st.checkbox("Show Fixed Points of Data", False): + get_fixed_points(model, dset) + + else: + st.info("Detected a KL model") + get_fixed_points_kl_ae(model, dset) \ No newline at end of file