add upscaling

This commit is contained in:
rromb 2022-06-13 00:39:48 +02:00
parent a0674ac4a2
commit c5a39aff8a
5 changed files with 282 additions and 4 deletions

View file

@ -0,0 +1,175 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
params:
low_scale_key: "LR_image" # TODO: adapt
linear_start: 0.001
linear_end: 0.015
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "image"
#first_stage_key: "jpg" # TODO: use this later
cond_stage_key: "caption"
#cond_stage_key: "txt" # TODO: use this later
image_size: 64
channels: 16
cond_stage_trainable: false
conditioning_key: "hybrid-adm"
monitor: val/loss_simple_ema
scale_factor: 0.22765929 # magic number
low_scale_config:
target: ldm.modules.encoders.modules.LowScaleEncoder
params:
linear_start: 0.00085
linear_end: 0.0120
timesteps: 1000
max_noise_level: 250
output_size: 64
model_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
num_classes: 1000 # timesteps for noise conditoining
image_size: 64 # not really needed
in_channels: 20
out_channels: 16
model_channels: 32 # TODO: more
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 16
monitor: val/rec_loss
ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
ddconfig:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ 16 ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
#data:
# target: ldm.data.laion.WebDataModuleFromConfig
# params:
# tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
# batch_size: 4
# num_workers: 4
# multinode: True
# min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
# train:
# shards: '{000000..231317}.tar -'
# shuffle: 10000
# image_key: jpg
# image_transforms:
# - target: torchvision.transforms.Resize
# params:
# size: 1024
# interpolation: 3
# - target: torchvision.transforms.RandomCrop
# params:
# size: 1024
#
# # NOTE use enough shards to avoid empty validation loops in workers
# validation:
# shards: '{231318..231349}.tar -'
# shuffle: 0
# image_key: jpg
# image_transforms:
# - target: torchvision.transforms.Resize
# params:
# size: 1024
# interpolation: 3
# - target: torchvision.transforms.CenterCrop
# params:
# size: 1024
data:
target: main.DataModuleFromConfig
params:
batch_size: 8
num_workers: 7
wrap: false
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 1024
downscale_f: 4
degradation: "cv_nearest"
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 10
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 3.0
unconditional_guidance_label: [""]
trainer:
benchmark: True
# val_check_interval: 5000000 # really sorry # TODO: bring back in
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View file

@ -368,7 +368,7 @@ class ImageNetSR(Dataset):
example["image"] = (image/127.5 - 1.0).astype(np.float32) example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
example["caption"] = example["human_label"] # dummy caption
return example return example

View file

@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
@torch.no_grad() @torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None): cond_key=None, return_original_cond=False, bs=None, return_x=False):
x = super().get_input(batch, k) x = super().get_input(batch, k)
if bs is not None: if bs is not None:
x = x[:bs] x = x[:bs]
@ -746,6 +746,8 @@ class LatentDiffusion(DDPM):
if return_first_stage_outputs: if return_first_stage_outputs:
xrec = self.decode_first_stage(z) xrec = self.decode_first_stage(z)
out.extend([x, xrec]) out.extend([x, xrec])
if return_x:
out.extend([x])
if return_original_cond: if return_original_cond:
out.append(xc) out.append(xc)
return out return out
@ -1416,9 +1418,9 @@ class DiffusionWrapper(pl.LightningModule):
super().__init__() super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config) self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
if self.conditioning_key is None: if self.conditioning_key is None:
out = self.diffusion_model(x, t) out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat': elif self.conditioning_key == 'concat':
@ -1431,6 +1433,11 @@ class DiffusionWrapper(pl.LightningModule):
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc) out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
elif self.conditioning_key == 'adm': elif self.conditioning_key == 'adm':
cc = c_crossattn[0] cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc) out = self.diffusion_model(x, t, y=cc)
@ -1440,6 +1447,34 @@ class DiffusionWrapper(pl.LightningModule):
return out return out
class LatentUpscaleDiffusion(LatentDiffusion):
def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
super().__init__(*args, **kwargs)
# assumes that neither the cond_stage nor the low_scale_model contain trainable params
assert not self.cond_stage_trainable
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None):
z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
x_low = batch[self.low_scale_key]
x_low = rearrange(x_low, 'b h w c -> b c h w')
x_low = x_low.to(memory_format=torch.contiguous_format).float()
zx, noise_level = self.low_scale_model(x_low)
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
return z, all_conds
# TODO log it
class Layout2ImgDiffusion(LatentDiffusion): class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class # TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs): def __init__(self, cond_stage_key, *args, **kwargs):

View file

@ -1,8 +1,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from functools import partial from functools import partial
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
from ldm.util import default
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
@ -201,6 +203,60 @@ class SpatialRescaler(nn.Module):
return self(x) return self(x)
from ldm.util import instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
class LowScaleEncoder(nn.Module):
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
linear_end=linear_end)
self.out_size = output_size
def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def forward(self, x):
z = self.model.encode(x).sample()
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
z = self.q_sample(z, noise_level)
#z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
if __name__ == "__main__": if __name__ == "__main__":
from ldm.util import count_params from ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"] sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]

View file

@ -0,0 +1,12 @@
# TODO, check out Twitter.
Darth Vader at Woodstock (1969)
Bunny Vikings
The Demogorgon from Stranger Thinhs holding a basketball
Hamster in my microwave
a courtroom sketch of a Ford Transit van
PS1 Hagrid ad MCDonalds
Karl Marx in KFC Logo
Moai Statue giving a TED talk
wahing machine trail cam
minions at cross burning
Hindenburg disaster in Fortnite