add upscaling

parent a0674ac4a2
commit c5a39aff8a

5 changed files with 282 additions and 4 deletions
@@ -0,0 +1,175 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    low_scale_key: "LR_image"  # TODO: adapt
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    #first_stage_key: "jpg"  # TODO: use this later
    cond_stage_key: "caption"
    #cond_stage_key: "txt"  # TODO: use this later
    image_size: 64
    channels: 16
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929  # magic number

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 250
        output_size: 64
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
                - 1
                - 2
                - 4
                - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config:  # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000  # timesteps for noise conditioning
        image_size: 64  # not really needed
        in_channels: 20
        out_channels: 16
        model_channels: 32  # TODO: more
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    min_size: 256  # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
#    train:
#      shards: '{000000..231317}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 1024
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{231318..231349}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 1024

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 7
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 1024
        downscale_f: 4
        degradation: "cv_nearest"

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 10
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000  # really sorry  # TODO: bring back in
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
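For orientation, here is a minimal, hedged sketch (not part of the commit) of how a config like the one above is typically loaded in this repo via OmegaConf and `instantiate_from_config`. The YAML path is hypothetical, and building the model assumes the referenced kl-f8/kl-f16 checkpoints and CLIP weights are available locally.

```python
# Hedged sketch: load the new upscaling config and build the LatentUpscaleDiffusion model.
# The YAML path below is hypothetical; the commit does not show where the config file lives.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/latent-diffusion/upscaling.yaml")  # hypothetical path
model = instantiate_from_config(config.model)  # resolves model.target and passes model.params

# Channel bookkeeping implied by the config: the UNet's in_channels of 20 is the
# 16-channel kl-f16 target latent (channels: 16) concatenated with the 4-channel
# kl-f8 latent produced by the LowScaleEncoder; out_channels stays at 16.
```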
@@ -368,7 +368,7 @@ class ImageNetSR(Dataset):

         example["image"] = (image/127.5 - 1.0).astype(np.float32)
         example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+        example["caption"] = example["human_label"]  # dummy caption
         return example
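As a quick check of how the data side lines up with the model keys above (first_stage_key: "image", cond_stage_key: "caption", low_scale_key: "LR_image"), the following hedged sketch inspects a single ImageNetSRTrain example after this change; it assumes the ImageNet data has been prepared as the repo expects.

```python
# Hedged sketch: inspect one training example (requires the prepared ImageNet data).
from ldm.data.imagenet import ImageNetSRTrain

dataset = ImageNetSRTrain(size=1024, downscale_f=4, degradation="cv_nearest")
example = dataset[0]
print(example["image"].shape)     # (1024, 1024, 3) float32 in [-1, 1]
print(example["LR_image"].shape)  # (256, 256, 3)   float32 in [-1, 1]
print(example["caption"])         # human-readable class label, used as a dummy caption
```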
@@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):

     @torch.no_grad()
     def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
-                  cond_key=None, return_original_cond=False, bs=None):
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
         x = super().get_input(batch, k)
         if bs is not None:
             x = x[:bs]
@@ -746,6 +746,8 @@ class LatentDiffusion(DDPM):
         if return_first_stage_outputs:
             xrec = self.decode_first_stage(z)
             out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
         if return_original_cond:
             out.append(xc)
         return out
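A small, hedged before/after sketch of the new `return_x` flag; `model` and `batch` are assumed objects, not defined in the commit. This is the call that LatentUpscaleDiffusion makes further down.

```python
# Assumed objects: `model` is a LatentDiffusion instance, `batch` a dataloader batch.
z, c = model.get_input(batch, model.first_stage_key, force_c_encode=True)                     # before
z, c, x = model.get_input(batch, model.first_stage_key, force_c_encode=True, return_x=True)   # after: x is the raw model input
```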
@@ -1416,9 +1418,9 @@ class DiffusionWrapper(pl.LightningModule):
         super().__init__()
         self.diffusion_model = instantiate_from_config(diff_model_config)
         self.conditioning_key = conditioning_key
-        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']

-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == 'concat':
@@ -1431,6 +1433,11 @@ class DiffusionWrapper(pl.LightningModule):
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'hybrid-adm':
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
             out = self.diffusion_model(x, t, y=cc)
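To make the new 'hybrid-adm' branch concrete, here is a runnable shape sketch with dummy tensors (the batch size of 2 is arbitrary). It mirrors the concatenation the wrapper performs and matches the config above, where in_channels is 20 = 16 + 4 and the noise level is routed through the class-embedding input `y`.

```python
import torch

x = torch.randn(2, 16, 64, 64)             # noisy target latent (channels: 16)
zx = torch.randn(2, 4, 64, 64)             # noise-augmented low-res latent (c_concat)
cc = torch.randn(2, 77, 768)               # FrozenCLIPEmbedder output (context_dim: 768)
noise_level = torch.randint(0, 250, (2,))  # c_adm: one augmentation level per sample

xc = torch.cat([x, zx], dim=1)             # what the 'hybrid-adm' branch feeds the UNet
print(xc.shape)                            # torch.Size([2, 20, 64, 64]) -> in_channels: 20
# out = unet(xc, t, context=cc, y=noise_level)  # dispatch performed by DiffusionWrapper
```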
@@ -1440,6 +1447,34 @@ class DiffusionWrapper(pl.LightningModule):
         return out


+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None):
+        z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
+        x_low = batch[self.low_scale_key]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        return z, all_conds
+
+    # TODO log it
+
+
 class Layout2ImgDiffusion(LatentDiffusion):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
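For the full picture, a hedged data-flow sketch of what `LatentUpscaleDiffusion.get_input` returns under the config above. `model` and `batch` are assumed to exist (the batch coming from the ImageNetSRTrain loader), and the shapes follow from the 1024px / 4x data setup with the kl-f16 and kl-f8 first stages.

```python
# Assumed objects: `model` built from the config above, `batch` from the ImageNetSRTrain loader.
z, all_conds = model.get_input(batch, model.first_stage_key)

# z:                           (b, 16, 64, 64)  target latent from the kl-f16 first stage
# all_conds["c_concat"][0]:    (b, 4, 64, 64)   noisy kl-f8 latent of the LR image
# all_conds["c_crossattn"][0]: (b, 77, 768)     CLIP embedding of the dummy caption
# all_conds["c_adm"]:          (b,)             integer noise levels below max_noise_level (250)
```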
@@ -1,8 +1,10 @@
 import torch
 import torch.nn as nn
+import numpy as np
 from functools import partial

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
+from ldm.util import default


 class AbstractEncoder(nn.Module):

@@ -201,6 +203,60 @@ class SpatialRescaler(nn.Module):
         return self(x)


+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+
+
+class LowScaleEncoder(nn.Module):
+    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
+        super().__init__()
+        self.max_noise_level = max_noise_level
+        self.model = instantiate_from_config(model_config)
+        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
+                                                            linear_end=linear_end)
+        self.out_size = output_size
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        z = self.model.encode(x).sample()
+        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        z = self.q_sample(z, noise_level)
+        #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
+        z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+        return z, noise_level
+
+
 if __name__ == "__main__":
     from ldm.util import count_params
     sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
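The spatial bookkeeping inside LowScaleEncoder.forward is easy to miss, so here is a runnable shape sketch with dummy tensors (the kl-f8 encoder is replaced by a random tensor of the right shape): a 256px LR crop becomes a 32x32 latent, noise is added at a random level below max_noise_level, and the 2x repeat_interleave brings it to the 64x64 grid of the kl-f16 target latent.

```python
import torch

b = 2
x_low = torch.randn(b, 3, 256, 256)        # LR_image after rearrange to (b, c, h, w)
z = torch.randn(b, 4, 256 // 8, 256 // 8)  # stand-in for the kl-f8 latent: (b, 4, 32, 32)
noise_level = torch.randint(0, 250, (b,))  # per-sample augmentation strength
# q_sample(z, noise_level) would keep the shape; only the noise content changes.
z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
print(z.shape)  # torch.Size([2, 4, 64, 64]) -> aligns with the 64x64 kl-f16 target latent
```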
scripts/prompts/weird-dalle-prompts.txt (new file, 12 lines)
@@ -0,0 +1,12 @@
# TODO, check out Twitter.
Darth Vader at Woodstock (1969)
Bunny Vikings
The Demogorgon from Stranger Things holding a basketball
Hamster in my microwave
a courtroom sketch of a Ford Transit van
PS1 Hagrid ad MCDonalds
Karl Marx in KFC Logo
Moai Statue giving a TED talk
washing machine trail cam
minions at cross burning
Hindenburg disaster in Fortnite