add upscaling

parent a0674ac4a2
commit c5a39aff8a

5 changed files with 282 additions and 4 deletions
@@ -0,0 +1,175 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    low_scale_key: "LR_image" # TODO: adapt
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    #first_stage_key: "jpg" # TODO: use this later
    cond_stage_key: "caption"
    #cond_stage_key: "txt" # TODO: use this later
    image_size: 64
    channels: 16
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 250
        output_size: 64
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
              - 1
              - 2
              - 4
              - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000 # timesteps for noise conditioning
        image_size: 64 # not really needed
        in_channels: 20
        out_channels: 16
        model_channels: 32 # TODO: more
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
#    train:
#      shards: '{000000..231317}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 1024
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{231318..231349}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 1024

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 7
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 1024
        downscale_f: 4
        degradation: "cv_nearest"

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 10
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000 # really sorry # TODO: bring back in
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
ldm/data/imagenet.py

@@ -368,7 +368,7 @@ class ImageNetSR(Dataset):

        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        example["caption"] = example["human_label"]  # dummy caption
        return example
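The /127.5 - 1.0 step maps uint8 pixels in [0, 255] to floats in roughly [-1, 1], the range the autoencoders expect; a tiny illustrative check:

import numpy as np

img = np.array([0, 128, 255], dtype=np.uint8)
print(img / 127.5 - 1.0)  # [-1.          0.00392157  1.        ]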
ldm/models/diffusion/ddpm.py

@@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):

    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
-                 cond_key=None, return_original_cond=False, bs=None):
+                 cond_key=None, return_original_cond=False, bs=None, return_x=False):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]

@@ -746,6 +746,8 @@ class LatentDiffusion(DDPM):

        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
+       if return_x:
+           out.extend([x])
        if return_original_cond:
            out.append(xc)
        return out
@@ -1416,9 +1418,9 @@ class DiffusionWrapper(pl.LightningModule):

        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
-       assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+       assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']

-   def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+   def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':

@@ -1431,6 +1433,11 @@ class DiffusionWrapper(pl.LightningModule):

            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
+       elif self.conditioning_key == 'hybrid-adm':
+           assert c_adm is not None
+           xc = torch.cat([x] + c_concat, dim=1)
+           cc = torch.cat(c_crossattn, 1)
+           out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
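A quick shape check of the new 'hybrid-adm' branch against the config above (a sketch with made-up tensors, not repo code): the kl-f16 latent contributes 16 channels, the concatenated kl-f8 LR latent adds 4 more, which matches the UNet's in_channels: 20, and the noise level enters through the class-embedding path y (num_classes: 1000 covers the 0..999 timestep range).

import torch

b, h, w = 2, 64, 64
x = torch.randn(b, 16, h, w)               # kl-f16 latent of the HR image (channels: 16)
zx = torch.randn(b, 4, h, w)               # noised kl-f8 latent of the LR image
c = torch.randn(b, 77, 768)                # CLIP text embedding (context_dim: 768)
t = torch.randint(0, 1000, (b,))
noise_level = torch.randint(0, 250, (b,))  # max_noise_level: 250

# what the 'hybrid-adm' branch does internally:
xc = torch.cat([x, zx], dim=1)             # (b, 20, h, w) -> matches in_channels: 20
# out = self.diffusion_model(xc, t, context=c, y=noise_level)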
@@ -1440,6 +1447,34 @@ class DiffusionWrapper(pl.LightningModule):

        return out


+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None):
+        z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
+        x_low = batch[self.low_scale_key]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        return z, all_conds
+
+    # TODO log it


class Layout2ImgDiffusion(LatentDiffusion):
    # TODO: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
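Tying this to the data side: with the config above, ImageNetSRTrain yields dicts whose keys line up with first_stage_key ("image"), low_scale_key ("LR_image"), and cond_stage_key ("caption"), so LatentUpscaleDiffusion.get_input can consume a batch directly. A minimal sketch with a fake batch (the 1024 and 256 sizes follow from size: 1024 and downscale_f: 4; illustrative only, not repo code):

import torch

# fake batch in the layout ImageNetSR produces: HWC floats in [-1, 1]
batch = {
    "image":    torch.rand(2, 1024, 1024, 3) * 2 - 1,  # first_stage_key
    "LR_image": torch.rand(2, 256, 256, 3) * 2 - 1,    # low_scale_key (1024 / downscale_f)
    "caption":  ["a photo of a hedgehog"] * 2,          # cond_stage_key
}

# z, all_conds = model.get_input(batch, "image")
# all_conds == {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level},
# which DiffusionWrapper routes through the 'hybrid-adm' branch shown above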
ldm/modules/encoders/modules.py

@@ -1,8 +1,10 @@

import torch
import torch.nn as nn
+import numpy as np
+from functools import partial

from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from ldm.util import default


class AbstractEncoder(nn.Module):
@@ -201,6 +203,60 @@ class SpatialRescaler(nn.Module):

        return self(x)


+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+
+
+class LowScaleEncoder(nn.Module):
+    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
+        super().__init__()
+        self.max_noise_level = max_noise_level
+        self.model = instantiate_from_config(model_config)
+        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
+                                                            linear_end=linear_end)
+        self.out_size = output_size
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        z = self.model.encode(x).sample()
+        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        z = self.q_sample(z, noise_level)
+        #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
+        z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+        return z, noise_level


if __name__ == "__main__":
    from ldm.util import count_params
    sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen",
                 "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
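The noise augmentation in q_sample is the standard DDPM forward process, z_t = sqrt(alpha_bar_t) * z + sqrt(1 - alpha_bar_t) * eps. A tiny self-contained check of what LowScaleEncoder.forward applies to the LR latent (the sqrt-spaced beta schedule mirrors, as far as I can tell, make_beta_schedule's "linear" option; numbers are illustrative):

import torch

# linear beta schedule as in register_schedule (linear_start/end from the config)
timesteps = 1000
betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, timesteps) ** 2
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

z = torch.randn(4, 4, 32, 32)      # kl-f8 latent of the LR image
t = torch.randint(0, 250, (4,))    # noise level in [0, max_noise_level)
noise = torch.randn_like(z)

sqrt_ac = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
sqrt_om = (1. - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
z_noisy = sqrt_ac * z + sqrt_om * noise  # what q_sample computes

# the encoder then upsamples 2x via repeat_interleave to match the HR latent grid
z_up = z_noisy.repeat_interleave(2, -2).repeat_interleave(2, -1)  # (4, 4, 64, 64)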
scripts/prompts/weird-dalle-prompts.txt (new file, 12 additions)
@@ -0,0 +1,12 @@
+# TODO, check out Twitter.
+Darth Vader at Woodstock (1969)
+Bunny Vikings
+The Demogorgon from Stranger Things holding a basketball
+Hamster in my microwave
+a courtroom sketch of a Ford Transit van
+PS1 Hagrid at McDonalds
+Karl Marx in KFC Logo
+Moai Statue giving a TED talk
+washing machine trail cam
+minions at cross burning
+Hindenburg disaster in Fortnite