import argparse, os, sys, glob import torch import numpy as np import streamlit as st import matplotlib.pyplot as plt from omegaconf import OmegaConf from PIL import Image from main import instantiate_from_config, DataModuleFromConfig from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from einops import rearrange from torchvision.utils import make_grid rescale = lambda x: (x + 1.) / 2. def bchw_to_st(x): return rescale(x.detach().cpu().numpy().transpose(0,2,3,1)) def chw_to_st(x): return rescale(x.detach().cpu().numpy().transpose(1,2,0)) def custom_get_input(batch, key): inputs = batch[key].permute(0, 3, 1, 2) return inputs def vq_no_codebook_forward(model, x): h = model.encoder(x) h = model.quant_conv(h) h = model.post_quant_conv(h) xrec = model.decoder(h) return xrec def save_img(x, fname): I = (x.clip(0, 1) * 255).astype(np.uint8) Image.fromarray(I).save(fname) def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( "-r", "--resume", type=str, nargs="?", help="load from logdir or checkpoint in logdir", ) parser.add_argument( "-b", "--base", nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. " "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( "-c", "--config", nargs="?", metavar="single_config.yaml", help="path to single config. If specified, base configs will be ignored " "(except for the last one if left unspecified).", const=True, default="", ) parser.add_argument( "--dataset_config", type=str, nargs="?", default="", help="path to dataset config" ) return parser def load_model_from_config(config, sd, gpu=True, eval_mode=True): model = instantiate_from_config(config) if sd is not None: model.load_state_dict(sd) if gpu: model.cuda() if eval_mode: model.eval() return {"model": model} def get_data(config): # get data data = instantiate_from_config(config.data) data.prepare_data() data.setup() return data @st.cache(allow_output_mutation=True) def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False): # get data if delete_dataset_params: st.info("Deleting dataset parameters.") del config["data"]["params"]["train"]["params"] del config["data"]["params"]["validation"]["params"] dsets = get_data(config) # calls data.config ... # now load the specified checkpoint if ckpt: pl_sd = torch.load(ckpt, map_location="cpu") global_step = pl_sd["global_step"] else: pl_sd = {"state_dict": None} global_step = None model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"] return dsets, model, global_step @torch.no_grad() def get_image_embeddings(model, dset, used_codebook, used_indices): import plotly.graph_objects as go batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4) start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) if st.sidebar.button("Sample Batch"): indices = np.random.choice(len(dset), batch_size) else: indices = list(range(start_index, start_index + batch_size)) st.write(f"Indices: {indices}") batch = default_collate([dset[i] for i in indices]) x = model.get_input(batch, "image") x = x.to(model.device) # get reconstruction from non-quantized and quantized, compare z_pre_quant = model.encode_to_prequant(x) z_quant, emb_loss, info = model.quantize(z_pre_quant) indices = info[2].detach().cpu().numpy() #indices = rearrange(indices, '(b d) -> b d', b=batch_size) unique_indices = np.unique(indices) st.write(f"Unique indices in batch: {unique_indices.shape[0]}") x1 = used_codebook[:, 0].cpu().numpy() x2 = used_codebook[:, 1].cpu().numpy() x3 = used_codebook[:, 2].cpu().numpy() zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy() zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy() zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy() zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy() zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy() zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy() fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers', marker=dict(size=1.4, line=dict(width=1., color="Blue")), name="All Used Codebook Entries"), ]) trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers', marker=dict(size=1., line=dict(width=1., color=indices)), name="Pre-Quant Codes") trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers', marker=dict(size=2., line=dict(width=10., color=indices)), name="Quantized Codes") fig.add_trace(trace2) fig.add_trace(trace3) fig.update_layout( autosize=False, width=1000, height=1000, ) x_rec_no_quant = model.decode(z_pre_quant) x_rec_quant = model.decode(z_quant) delta_x = x_rec_no_quant - x_rec_quant st.text("Fitting Gaussian...") h, w = z_quant.shape[2], z_quant.shape[3] from sklearn.mixture import GaussianMixture gaussian = GaussianMixture(n_components=1) gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy()) samples, _ = gaussian.sample(n_samples=batch_size*h*w) samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3) samples = rearrange(samples, 'b h w c -> b c h w') samples = torch.tensor(samples).to(z_quant) samples, _, _ = model.quantize(samples) x_sample = model.decode(samples) all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample]) # 5 b 3 H W all_img = rearrange(all_img, 'n b c h w -> b n c h w') all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') grid = make_grid(all_img, nrow=5) st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **") st.image(chw_to_st(grid), clamp=True, output_format="PNG") st.write(fig) # 2d projections import matplotlib.pyplot as plt pairs = [(1, 0), (2, 0), (2, 1)] fig2, ax = plt.subplots(1, 3, figsize=(21, 7)) for d in range(3): d1, d2 = pairs[d] #ax[d].scatter(used_codebook[:, d1].cpu().numpy(), # used_codebook[:, d2].cpu().numpy(), # label="All Used Codebook Entries", s=10.0, c=used_indices) ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(), rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(), label="Quantized Codes", alpha=0.9, s=8.0, c=indices) ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(), rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(), label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices) ax[d].set_title(f"dim {d2} vs dim {d1}") ax[d].legend() st.write(fig2) plt.close() # from scipy.spatial import Voronoi, voronoi_plot_2d # fig3 = plt.figure(figsize=(10,10)) # points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy() # vor = Voronoi(points) # # plot # voronoi_plot_2d(vor) # # colorize # for region in vor.regions: # if not -1 in region: # polygon = [vor.vertices[i] for i in region] # plt.fill(*zip(*polygon)) # plt.savefig("voronoi_test.png") # st.write(fig3) @torch.no_grad() def get_used_indices(model, dset, batch_size=20): dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False) data = list() info = st.empty() for i, batch in enumerate(dloader): x = model.get_input(batch, "image") x = x.to(model.device) zq, _, zi = model.encode(x) indices = zi[2] indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy() data.append(indices) unique = np.unique(data) info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}") unique = np.unique(data) #np.save(outpath, unique) st.write(f"end of data: found **{unique.size} unique indices.**") print(f"end of data: found {unique.size} unique indices.") return unique def visualize3d(codebook, used_indices): import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D codebook = codebook.cpu().numpy() selected_codebook = codebook[used_indices, :] z_dim = codebook.shape[1] assert z_dim == 3 pairs = [(1,0), (2,0), (2,1)] fig, ax = plt.subplots(1,3, figsize=(15,5)) for d in range(3): d1, d2 = pairs[d] ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2]) ax[d].set_title(f"dim {d2} vs dim {d1}") st.write(fig) # # plot 3D # fig = plt.figure(1) # ax = Axes3D(fig) # ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries") # ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries") # plt.legend() # #st.write(fig) # st.pyplot(fig) # plot histogram of vector norms fig = plt.figure(2, figsize=(6,5)) norms = np.linalg.norm(selected_codebook, axis=1) plt.hist(norms, bins=100, edgecolor="black", lw=1.1) plt.title("Distribution of norms of used codebook entries") st.write(fig) # plot 3D with plotly import pandas as pd import plotly.graph_objects as go x = selected_codebook[:, 0] y = selected_codebook[:, 1] z = selected_codebook[:, 2] fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2., line=dict(width=1., color="Blue")) ) ] ) fig.update_layout( autosize=False, width=1000, height=1000, ) st.write(fig) @torch.no_grad() def get_fixed_points(model, dset): n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25) batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4) start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) clip_decoded = st.checkbox("Clip decoded image", False) quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False) factor = st.sidebar.number_input("image size", value=1., min_value=0.1) if st.sidebar.button("Sample Batch"): indices = np.random.choice(len(dset), batch_size) else: indices = list(range(start_index, start_index + batch_size)) st.write(f"Indices: {indices}") batch = default_collate([dset[i] for i in indices]) x = model.get_input(batch, "image") x = x.to(model.device) progress = st.empty() progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}") image_progress = st.empty() # TODO input = x img_quant = x img_noquant = x delta_img = img_quant - img_noquant st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **") def display(input, img_quant, img_noquant, delta_img): all_img = torch.stack([input, img_quant, img_noquant, delta_img]) # 4 b 3 H W all_img = rearrange(all_img, 'n b c h w -> b n c h w') all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') grid = make_grid(all_img, nrow=4) image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2])) display(input, img_quant, img_noquant, delta_img) for n in range(n_iter): # get reconstruction from non-quantized and quantized, compare via iteration # quantized_stream z_pre_quant = model.encode_to_prequant(img_quant) z_quant, emb_loss, info = model.quantize(z_pre_quant) # non_quantized stream z_noquant = model.encode_to_prequant(img_noquant) img_quant = model.decode(z_quant) img_noquant = model.decode(z_noquant) if clip_decoded: img_quant = torch.clamp(img_quant, -1., 1.) img_noquant = torch.clamp(img_noquant, -1., 1.) if quantize_decoded: device = img_quant.device img_quant = (2*torch.Tensor(((img_quant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device) img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device) delta_img = img_quant - img_noquant display(input, img_quant, img_noquant, delta_img) progress_cb(n + 1) @torch.no_grad() def get_fixed_points_kl_ae(model, dset): n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25) batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4) start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size) clip_decoded = st.checkbox("Clip decoded image", False) quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False) sample_posterior = st.checkbox("Sample from encoder posterior", False) factor = st.sidebar.number_input("image size", value=1., min_value=0.1) if st.sidebar.button("Sample Batch"): indices = np.random.choice(len(dset), batch_size) else: indices = list(range(start_index, start_index + batch_size)) st.write(f"Indices: {indices}") batch = default_collate([dset[i] for i in indices]) x = model.get_input(batch, "image") x = x.to(model.device) progress = st.empty() progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}") st.write("** Input | Rec. (no quant) | Delta(input, iter_rec) **") image_progress = st.empty() input = x img_noquant = x delta_img = input - img_noquant def display(input, img_noquant, delta_img): all_img = torch.stack([input, img_noquant, delta_img]) # 3 b 3 H W all_img = rearrange(all_img, 'n b c h w -> b n c h w') all_img = rearrange(all_img, 'b n c h w -> (b n) c h w') grid = make_grid(all_img, nrow=3) image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2])) fig, ax = plt.subplots() distribution_progress = st.empty() def display_latent_distribution(latent_z, alpha=1., title=""): flatz = latent_z.reshape(-1).cpu().detach().numpy() #fig, ax = plt.subplots() ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black") ax.set_title(title) distribution_progress.pyplot(fig) display(input, img_noquant, delta_img) for n in range(n_iter): # get reconstructions posterior = model.encode(img_noquant) if sample_posterior: z = posterior.sample() else: z = posterior.mode() if n==0: flatz_init = z.reshape(-1).cpu().detach().numpy() std_init = flatz_init.std() max_init, min_init = flatz_init.max(), flatz_init.min() display_latent_distribution(z, alpha=np.sqrt(1/(n+1)), title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}") img_noquant = model.decode(z) if clip_decoded: img_noquant = torch.clamp(img_noquant, -1., 1.) if quantize_decoded: img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(model.device) delta_img = img_noquant - input display(input, img_noquant, delta_img) progress_cb(n + 1) if __name__ == "__main__": from ldm.models.autoencoder import AutoencoderKL # VISUALIZE USED AND ALL INDICES of VQ-Model. VISUALIZE FIXED POINTS OF KL MODEL sys.path.append(os.getcwd()) parser = get_parser() opt, unknown = parser.parse_known_args() ckpt = None if opt.resume: if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): paths = opt.resume.split("/") try: idx = len(paths)-paths[::-1].index("logs")+1 except ValueError: idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt logdir = "/".join(paths[:idx]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume logdir = opt.resume.rstrip("/") ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml"))) opt.base = base_configs+opt.base if opt.config: if type(opt.config) == str: opt.base = [opt.config] else: opt.base = [opt.base[-1]] configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) config = OmegaConf.merge(*configs, cli) if opt.dataset_config: dcfg = OmegaConf.load(opt.dataset_config) print("Replacing data config with:") print(dcfg.pretty()) dcfg = OmegaConf.to_container(dcfg) config["data"] = dcfg["data"] st.sidebar.text(ckpt) gs = st.sidebar.empty() gs.text(f"Global step: ?") st.sidebar.text("Options") gpu = st.sidebar.checkbox("GPU", value=True) eval_mode = st.sidebar.checkbox("Eval Mode", value=True) show_config = st.sidebar.checkbox("Show Config", value=False) if show_config: st.info("Checkpoint: {}".format(ckpt)) st.json(OmegaConf.to_container(config)) delelete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.") dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=delelete_dataset_parameters) gs.text(f"Global step: {global_step}") split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1]) dset = dsets.datasets[split] batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20) num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5) data_size = batch_size*num_batches dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False)) if not isinstance(model, AutoencoderKL): # VQ MODEL codebook = model.quantize.embedding.weight.data st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**") st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**") st.write("**Select ONE of the following options**") if st.checkbox("Show Codebook Statistics", False): used_indices = get_used_indices(model, dset, batch_size=batch_size) visualize3d(codebook, used_indices) if st.checkbox("Show Batch Encodings", False): used_indices = get_used_indices(model, dset, batch_size=batch_size) get_image_embeddings(model, dset, codebook[used_indices, :], used_indices) if st.checkbox("Show Fixed Points of Data", False): get_fixed_points(model, dset) else: st.info("Detected a KL model") get_fixed_points_kl_ae(model, dset)