ae eval script
This commit is contained in:
parent
36b5177221
commit
6dc939330d
1 changed files with 542 additions and 0 deletions
542
scripts/vqgan_codebook_visualizer.py
Normal file
542
scripts/vqgan_codebook_visualizer.py
Normal file
|
@ -0,0 +1,542 @@
|
|||
import argparse, os, sys, glob
|
||||
import torch
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import matplotlib.pyplot as plt
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from main import instantiate_from_config, DataModuleFromConfig
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from einops import rearrange
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
|
||||
rescale = lambda x: (x + 1.) / 2.
|
||||
|
||||
|
||||
def bchw_to_st(x):
|
||||
return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
|
||||
|
||||
def chw_to_st(x):
|
||||
return rescale(x.detach().cpu().numpy().transpose(1,2,0))
|
||||
|
||||
|
||||
def custom_get_input(batch, key):
|
||||
inputs = batch[key].permute(0, 3, 1, 2)
|
||||
return inputs
|
||||
|
||||
|
||||
def vq_no_codebook_forward(model, x):
|
||||
h = model.encoder(x)
|
||||
h = model.quant_conv(h)
|
||||
h = model.post_quant_conv(h)
|
||||
xrec = model.decoder(h)
|
||||
return xrec
|
||||
|
||||
|
||||
def save_img(x, fname):
|
||||
I = (x.clip(0, 1) * 255).astype(np.uint8)
|
||||
Image.fromarray(I).save(fname)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resume",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="load from logdir or checkpoint in logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--base",
|
||||
nargs="*",
|
||||
metavar="base_config.yaml",
|
||||
help="paths to base configs. Loaded from left-to-right. "
|
||||
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
||||
default=list(),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
nargs="?",
|
||||
metavar="single_config.yaml",
|
||||
help="path to single config. If specified, base configs will be ignored "
|
||||
"(except for the last one if left unspecified).",
|
||||
const=True,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="",
|
||||
help="path to dataset config"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
||||
model = instantiate_from_config(config)
|
||||
if sd is not None:
|
||||
model.load_state_dict(sd)
|
||||
if gpu:
|
||||
model.cuda()
|
||||
if eval_mode:
|
||||
model.eval()
|
||||
return {"model": model}
|
||||
|
||||
|
||||
def get_data(config):
|
||||
# get data
|
||||
data = instantiate_from_config(config.data)
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
return data
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
|
||||
def load_model_and_dset(config, ckpt, gpu, eval_mode, delete_dataset_params=False):
|
||||
# get data
|
||||
if delete_dataset_params:
|
||||
st.info("Deleting dataset parameters.")
|
||||
del config["data"]["params"]["train"]["params"]
|
||||
del config["data"]["params"]["validation"]["params"]
|
||||
|
||||
dsets = get_data(config) # calls data.config ...
|
||||
|
||||
# now load the specified checkpoint
|
||||
if ckpt:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
global_step = pl_sd["global_step"]
|
||||
else:
|
||||
pl_sd = {"state_dict": None}
|
||||
global_step = None
|
||||
model = load_model_from_config(config.model,
|
||||
pl_sd["state_dict"],
|
||||
gpu=gpu,
|
||||
eval_mode=eval_mode)["model"]
|
||||
return dsets, model, global_step
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embeddings(model, dset, used_codebook, used_indices):
|
||||
import plotly.graph_objects as go
|
||||
batch_size = st.number_input("Batch size for embedding visualization", min_value=1, value=4)
|
||||
start_index = st.number_input("Start index", value=0,
|
||||
min_value=0,
|
||||
max_value=len(dset) - batch_size)
|
||||
if st.sidebar.button("Sample Batch"):
|
||||
indices = np.random.choice(len(dset), batch_size)
|
||||
else:
|
||||
indices = list(range(start_index, start_index + batch_size))
|
||||
|
||||
st.write(f"Indices: {indices}")
|
||||
batch = default_collate([dset[i] for i in indices])
|
||||
x = model.get_input(batch, "image")
|
||||
x = x.to(model.device)
|
||||
|
||||
# get reconstruction from non-quantized and quantized, compare
|
||||
z_pre_quant = model.encode_to_prequant(x)
|
||||
z_quant, emb_loss, info = model.quantize(z_pre_quant)
|
||||
indices = info[2].detach().cpu().numpy()
|
||||
#indices = rearrange(indices, '(b d) -> b d', b=batch_size)
|
||||
unique_indices = np.unique(indices)
|
||||
st.write(f"Unique indices in batch: {unique_indices.shape[0]}")
|
||||
|
||||
x1 = used_codebook[:, 0].cpu().numpy()
|
||||
x2 = used_codebook[:, 1].cpu().numpy()
|
||||
x3 = used_codebook[:, 2].cpu().numpy()
|
||||
|
||||
zp1 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
|
||||
zp2 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
|
||||
zp3 = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()
|
||||
|
||||
zq1 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 0].cpu().numpy()
|
||||
zq2 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 1].cpu().numpy()
|
||||
zq3 = rearrange(z_quant, 'b c h w -> (b h w) c')[:, 2].cpu().numpy()
|
||||
|
||||
fig = go.Figure(data=[go.Scatter3d(x=x1, y=x2, z=x3, mode='markers', marker=dict(size=1.4, line=dict(width=1.,
|
||||
color="Blue")),
|
||||
name="All Used Codebook Entries"),
|
||||
|
||||
])
|
||||
trace2 = go.Scatter3d(x=zp1, y=zp2, z=zp3, mode='markers', marker=dict(size=1., line=dict(width=1., color=indices)),
|
||||
name="Pre-Quant Codes")
|
||||
trace3 = go.Scatter3d(x=zq1, y=zq2, z=zq3, mode='markers',
|
||||
marker=dict(size=2., line=dict(width=10., color=indices)), name="Quantized Codes")
|
||||
|
||||
fig.add_trace(trace2)
|
||||
fig.add_trace(trace3)
|
||||
|
||||
fig.update_layout(
|
||||
autosize=False,
|
||||
width=1000,
|
||||
height=1000,
|
||||
)
|
||||
|
||||
x_rec_no_quant = model.decode(z_pre_quant)
|
||||
x_rec_quant = model.decode(z_quant)
|
||||
delta_x = x_rec_no_quant - x_rec_quant
|
||||
|
||||
st.text("Fitting Gaussian...")
|
||||
h, w = z_quant.shape[2], z_quant.shape[3]
|
||||
from sklearn.mixture import GaussianMixture
|
||||
gaussian = GaussianMixture(n_components=1)
|
||||
gaussian.fit(rearrange(z_pre_quant, 'b c h w -> (b h w) c').cpu().numpy())
|
||||
samples, _ = gaussian.sample(n_samples=batch_size*h*w)
|
||||
samples = rearrange(samples, '(b h w) c -> b h w c', b=batch_size, h=h, w=w, c=3)
|
||||
samples = rearrange(samples, 'b h w c -> b c h w')
|
||||
samples = torch.tensor(samples).to(z_quant)
|
||||
samples, _, _ = model.quantize(samples)
|
||||
x_sample = model.decode(samples)
|
||||
|
||||
all_img = torch.stack([x, x_rec_quant, x_rec_no_quant, delta_x, x_sample]) # 5 b 3 H W
|
||||
|
||||
all_img = rearrange(all_img, 'n b c h w -> b n c h w')
|
||||
all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
|
||||
grid = make_grid(all_img, nrow=5)
|
||||
|
||||
st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **")
|
||||
st.image(chw_to_st(grid), clamp=True, output_format="PNG")
|
||||
|
||||
st.write(fig)
|
||||
# 2d projections
|
||||
import matplotlib.pyplot as plt
|
||||
pairs = [(1, 0), (2, 0), (2, 1)]
|
||||
fig2, ax = plt.subplots(1, 3, figsize=(21, 7))
|
||||
for d in range(3):
|
||||
d1, d2 = pairs[d]
|
||||
#ax[d].scatter(used_codebook[:, d1].cpu().numpy(),
|
||||
# used_codebook[:, d2].cpu().numpy(),
|
||||
# label="All Used Codebook Entries", s=10.0, c=used_indices)
|
||||
ax[d].scatter(rearrange(z_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
|
||||
rearrange(z_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
|
||||
label="Quantized Codes", alpha=0.9, s=8.0, c=indices)
|
||||
ax[d].scatter(rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d1].cpu().numpy(),
|
||||
rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, d2].cpu().numpy(),
|
||||
label="Pre-Quant Codes", alpha=0.5, s=1.0, c=indices)
|
||||
ax[d].set_title(f"dim {d2} vs dim {d1}")
|
||||
ax[d].legend()
|
||||
st.write(fig2)
|
||||
|
||||
plt.close()
|
||||
|
||||
# from scipy.spatial import Voronoi, voronoi_plot_2d
|
||||
# fig3 = plt.figure(figsize=(10,10))
|
||||
# points = rearrange(z_pre_quant, 'b c h w -> (b h w) c')[:, :2].cpu().numpy()
|
||||
# vor = Voronoi(points)
|
||||
# # plot
|
||||
# voronoi_plot_2d(vor)
|
||||
# # colorize
|
||||
# for region in vor.regions:
|
||||
# if not -1 in region:
|
||||
# polygon = [vor.vertices[i] for i in region]
|
||||
# plt.fill(*zip(*polygon))
|
||||
# plt.savefig("voronoi_test.png")
|
||||
# st.write(fig3)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_used_indices(model, dset, batch_size=20):
|
||||
dloader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=batch_size, drop_last=False)
|
||||
data = list()
|
||||
info = st.empty()
|
||||
for i, batch in enumerate(dloader):
|
||||
x = model.get_input(batch, "image")
|
||||
x = x.to(model.device)
|
||||
zq, _, zi = model.encode(x)
|
||||
indices = zi[2]
|
||||
indices = indices.reshape(zq.shape[0], -1).detach().cpu().numpy()
|
||||
data.append(indices)
|
||||
|
||||
unique = np.unique(data)
|
||||
info.text(f"iteration {i} [{batch_size*i}/{len(dset)}]: unique indices found so far: {unique.size}")
|
||||
|
||||
unique = np.unique(data)
|
||||
#np.save(outpath, unique)
|
||||
st.write(f"end of data: found **{unique.size} unique indices.**")
|
||||
print(f"end of data: found {unique.size} unique indices.")
|
||||
return unique
|
||||
|
||||
|
||||
def visualize3d(codebook, used_indices):
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
codebook = codebook.cpu().numpy()
|
||||
selected_codebook = codebook[used_indices, :]
|
||||
z_dim = codebook.shape[1]
|
||||
assert z_dim == 3
|
||||
|
||||
pairs = [(1,0), (2,0), (2,1)]
|
||||
fig, ax = plt.subplots(1,3, figsize=(15,5))
|
||||
for d in range(3):
|
||||
d1, d2 = pairs[d]
|
||||
ax[d].scatter(selected_codebook[:, d1], selected_codebook[:, d2])
|
||||
ax[d].set_title(f"dim {d2} vs dim {d1}")
|
||||
st.write(fig)
|
||||
|
||||
# # plot 3D
|
||||
# fig = plt.figure(1)
|
||||
# ax = Axes3D(fig)
|
||||
# ax.scatter(codebook[:, 0], codebook[:, 1], codebook[:, 2], s=10., alpha=0.8, label="all entries")
|
||||
# ax.scatter(selected_codebook[:, 0], selected_codebook[:, 1], selected_codebook[:, 2], s=3., alpha=1.0, label="used entries")
|
||||
# plt.legend()
|
||||
# #st.write(fig)
|
||||
# st.pyplot(fig)
|
||||
|
||||
# plot histogram of vector norms
|
||||
fig = plt.figure(2, figsize=(6,5))
|
||||
norms = np.linalg.norm(selected_codebook, axis=1)
|
||||
plt.hist(norms, bins=100, edgecolor="black", lw=1.1)
|
||||
plt.title("Distribution of norms of used codebook entries")
|
||||
st.write(fig)
|
||||
|
||||
# plot 3D with plotly
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
x = selected_codebook[:, 0]
|
||||
y = selected_codebook[:, 1]
|
||||
z = selected_codebook[:, 2]
|
||||
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2., line=dict(width=1.,
|
||||
color="Blue"))
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
autosize=False,
|
||||
width=1000,
|
||||
height=1000,
|
||||
)
|
||||
st.write(fig)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_fixed_points(model, dset):
|
||||
n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
|
||||
batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
|
||||
start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
|
||||
clip_decoded = st.checkbox("Clip decoded image", False)
|
||||
quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
|
||||
factor = st.sidebar.number_input("image size", value=1., min_value=0.1)
|
||||
if st.sidebar.button("Sample Batch"):
|
||||
indices = np.random.choice(len(dset), batch_size)
|
||||
else:
|
||||
indices = list(range(start_index, start_index + batch_size))
|
||||
|
||||
st.write(f"Indices: {indices}")
|
||||
batch = default_collate([dset[i] for i in indices])
|
||||
x = model.get_input(batch, "image")
|
||||
x = x.to(model.device)
|
||||
|
||||
progress = st.empty()
|
||||
progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
|
||||
|
||||
image_progress = st.empty() # TODO
|
||||
|
||||
input = x
|
||||
img_quant = x
|
||||
img_noquant = x
|
||||
delta_img = img_quant - img_noquant
|
||||
|
||||
st.write("** Input | Rec. (w/ quant) | Rec. (no quant) | Delta(quant, no_quant) **")
|
||||
|
||||
def display(input, img_quant, img_noquant, delta_img):
|
||||
all_img = torch.stack([input, img_quant, img_noquant, delta_img]) # 4 b 3 H W
|
||||
all_img = rearrange(all_img, 'n b c h w -> b n c h w')
|
||||
all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
|
||||
grid = make_grid(all_img, nrow=4)
|
||||
image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2]))
|
||||
|
||||
display(input, img_quant, img_noquant, delta_img)
|
||||
for n in range(n_iter):
|
||||
# get reconstruction from non-quantized and quantized, compare via iteration
|
||||
|
||||
# quantized_stream
|
||||
z_pre_quant = model.encode_to_prequant(img_quant)
|
||||
z_quant, emb_loss, info = model.quantize(z_pre_quant)
|
||||
|
||||
# non_quantized stream
|
||||
z_noquant = model.encode_to_prequant(img_noquant)
|
||||
|
||||
img_quant = model.decode(z_quant)
|
||||
img_noquant = model.decode(z_noquant)
|
||||
if clip_decoded:
|
||||
img_quant = torch.clamp(img_quant, -1., 1.)
|
||||
img_noquant = torch.clamp(img_noquant, -1., 1.)
|
||||
if quantize_decoded:
|
||||
device = img_quant.device
|
||||
img_quant = (2*torch.Tensor(((img_quant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device)
|
||||
img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(device)
|
||||
delta_img = img_quant - img_noquant
|
||||
display(input, img_quant, img_noquant, delta_img)
|
||||
progress_cb(n + 1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_fixed_points_kl_ae(model, dset):
|
||||
n_iter = st.number_input("Number of Iterations for FP-Analysis", min_value=1, value=25)
|
||||
batch_size = st.number_input("Batch size for fixed-point visualization", min_value=1, value=4)
|
||||
start_index = st.number_input("Start index", value=0, min_value=0, max_value=len(dset) - batch_size)
|
||||
clip_decoded = st.checkbox("Clip decoded image", False)
|
||||
quantize_decoded = st.checkbox("Quantize decoded image (e.g. map back to uint8)", False)
|
||||
sample_posterior = st.checkbox("Sample from encoder posterior", False)
|
||||
factor = st.sidebar.number_input("image size", value=1., min_value=0.1)
|
||||
if st.sidebar.button("Sample Batch"):
|
||||
indices = np.random.choice(len(dset), batch_size)
|
||||
else:
|
||||
indices = list(range(start_index, start_index + batch_size))
|
||||
|
||||
st.write(f"Indices: {indices}")
|
||||
batch = default_collate([dset[i] for i in indices])
|
||||
x = model.get_input(batch, "image")
|
||||
x = x.to(model.device)
|
||||
|
||||
progress = st.empty()
|
||||
progress_cb = lambda k: progress.write(f"iteration {k}/{n_iter}")
|
||||
|
||||
st.write("** Input | Rec. (no quant) | Delta(input, iter_rec) **")
|
||||
image_progress = st.empty()
|
||||
|
||||
input = x
|
||||
img_noquant = x
|
||||
delta_img = input - img_noquant
|
||||
|
||||
def display(input, img_noquant, delta_img):
|
||||
all_img = torch.stack([input, img_noquant, delta_img]) # 3 b 3 H W
|
||||
all_img = rearrange(all_img, 'n b c h w -> b n c h w')
|
||||
all_img = rearrange(all_img, 'b n c h w -> (b n) c h w')
|
||||
grid = make_grid(all_img, nrow=3)
|
||||
image_progress.image(chw_to_st(grid), clamp=True, output_format="PNG", width=int(factor*grid.shape[2]))
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
distribution_progress = st.empty()
|
||||
def display_latent_distribution(latent_z, alpha=1., title=""):
|
||||
flatz = latent_z.reshape(-1).cpu().detach().numpy()
|
||||
#fig, ax = plt.subplots()
|
||||
ax.hist(flatz, bins=42, alpha=alpha, lw=.1, edgecolor="black")
|
||||
ax.set_title(title)
|
||||
distribution_progress.pyplot(fig)
|
||||
|
||||
display(input, img_noquant, delta_img)
|
||||
for n in range(n_iter):
|
||||
# get reconstructions
|
||||
|
||||
posterior = model.encode(img_noquant)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
|
||||
if n==0:
|
||||
flatz_init = z.reshape(-1).cpu().detach().numpy()
|
||||
std_init = flatz_init.std()
|
||||
max_init, min_init = flatz_init.max(), flatz_init.min()
|
||||
|
||||
display_latent_distribution(z, alpha=np.sqrt(1/(n+1)),
|
||||
title=f"initial z: std/min/max: {std_init:.2f}/{min_init:.2f}/{max_init:.2f}")
|
||||
|
||||
img_noquant = model.decode(z)
|
||||
if clip_decoded:
|
||||
img_noquant = torch.clamp(img_noquant, -1., 1.)
|
||||
if quantize_decoded:
|
||||
img_noquant = (2*torch.Tensor(((img_noquant.cpu().numpy()+1.)*127.5).astype(np.uint8))/255. - 1.).to(model.device)
|
||||
delta_img = img_noquant - input
|
||||
display(input, img_noquant, delta_img)
|
||||
progress_cb(n + 1)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
# VISUALIZE USED AND ALL INDICES of VQ-Model. VISUALIZE FIXED POINTS OF KL MODEL
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
parser = get_parser()
|
||||
opt, unknown = parser.parse_known_args()
|
||||
|
||||
ckpt = None
|
||||
if opt.resume:
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
paths = opt.resume.split("/")
|
||||
try:
|
||||
idx = len(paths)-paths[::-1].index("logs")+1
|
||||
except ValueError:
|
||||
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
||||
logdir = "/".join(paths[:idx])
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), opt.resume
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
||||
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
||||
opt.base = base_configs+opt.base
|
||||
|
||||
if opt.config:
|
||||
if type(opt.config) == str:
|
||||
opt.base = [opt.config]
|
||||
else:
|
||||
opt.base = [opt.base[-1]]
|
||||
|
||||
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
||||
cli = OmegaConf.from_dotlist(unknown)
|
||||
config = OmegaConf.merge(*configs, cli)
|
||||
|
||||
if opt.dataset_config:
|
||||
dcfg = OmegaConf.load(opt.dataset_config)
|
||||
print("Replacing data config with:")
|
||||
print(dcfg.pretty())
|
||||
dcfg = OmegaConf.to_container(dcfg)
|
||||
config["data"] = dcfg["data"]
|
||||
|
||||
st.sidebar.text(ckpt)
|
||||
gs = st.sidebar.empty()
|
||||
gs.text(f"Global step: ?")
|
||||
st.sidebar.text("Options")
|
||||
gpu = st.sidebar.checkbox("GPU", value=True)
|
||||
eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
|
||||
show_config = st.sidebar.checkbox("Show Config", value=False)
|
||||
if show_config:
|
||||
st.info("Checkpoint: {}".format(ckpt))
|
||||
st.json(OmegaConf.to_container(config))
|
||||
|
||||
delelete_dataset_parameters = st.sidebar.checkbox("Delete parameters of dataset.")
|
||||
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode,
|
||||
delete_dataset_params=delelete_dataset_parameters)
|
||||
gs.text(f"Global step: {global_step}")
|
||||
|
||||
split = st.sidebar.radio("Split", sorted(dsets.datasets.keys())[::-1])
|
||||
dset = dsets.datasets[split]
|
||||
|
||||
batch_size = st.sidebar.number_input("Batch size", min_value=1, value=20)
|
||||
num_batches = st.sidebar.number_input("Number of batches", min_value=1, value=5)
|
||||
data_size = batch_size*num_batches
|
||||
dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), size=(data_size,), replace=False))
|
||||
|
||||
if not isinstance(model, AutoencoderKL):
|
||||
# VQ MODEL
|
||||
codebook = model.quantize.embedding.weight.data
|
||||
st.write(f"VQ-Model has codebook of dimensionality **{codebook.shape[0]} x {codebook.shape[1]} (num_entries x z_dim)**")
|
||||
st.write(f"Evaluating codebook-usage on **{config['data']['params'][split]['target']}**")
|
||||
st.write("**Select ONE of the following options**")
|
||||
|
||||
if st.checkbox("Show Codebook Statistics", False):
|
||||
used_indices = get_used_indices(model, dset, batch_size=batch_size)
|
||||
visualize3d(codebook, used_indices)
|
||||
if st.checkbox("Show Batch Encodings", False):
|
||||
used_indices = get_used_indices(model, dset, batch_size=batch_size)
|
||||
get_image_embeddings(model, dset, codebook[used_indices, :], used_indices)
|
||||
if st.checkbox("Show Fixed Points of Data", False):
|
||||
get_fixed_points(model, dset)
|
||||
|
||||
else:
|
||||
st.info("Detected a KL model")
|
||||
get_fixed_points_kl_ae(model, dset)
|
Loading…
Reference in a new issue