From c5a39aff8a35c129ab51dc1b7206d95f3d3f529f Mon Sep 17 00:00:00 2001
From: rromb
Date: Mon, 13 Jun 2022 00:39:48 +0200
Subject: [PATCH] add upscaling

---
 ...img-upscale-clip-encoder-f16-1024-dev.yaml | 175 ++++++++++++++++++
 ldm/data/imagenet.py                          |   2 +-
 ldm/models/diffusion/ddpm.py                  |  41 +++-
 ldm/modules/encoders/modules.py               |  56 ++++++
 scripts/prompts/weird-dalle-prompts.txt       |  12 ++
 5 files changed, 282 insertions(+), 4 deletions(-)
 create mode 100644 configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml
 create mode 100644 scripts/prompts/weird-dalle-prompts.txt

diff --git a/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml
new file mode 100644
index 0000000..7ca45b8
--- /dev/null
+++ b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml
@@ -0,0 +1,175 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
+  params:
+    low_scale_key: "LR_image"  # TODO: adapt
+    linear_start: 0.001
+    linear_end: 0.015
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "image"
+    #first_stage_key: "jpg"  # TODO: use this later
+    cond_stage_key: "caption"
+    #cond_stage_key: "txt"  # TODO: use this later
+    image_size: 64
+    channels: 16
+    cond_stage_trainable: false
+    conditioning_key: "hybrid-adm"
+    monitor: val/loss_simple_ema
+    scale_factor: 0.22765929  # magic number
+
+    low_scale_config:
+      target: ldm.modules.encoders.modules.LowScaleEncoder
+      params:
+        linear_start: 0.00085
+        linear_end: 0.0120
+        timesteps: 1000
+        max_noise_level: 250
+        output_size: 64
+        model_config:
+          target: ldm.models.autoencoder.AutoencoderKL
+          params:
+            embed_dim: 4
+            monitor: val/rec_loss
+            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
+            ddconfig:
+              double_z: true
+              z_channels: 4
+              resolution: 256
+              in_channels: 3
+              out_ch: 3
+              ch: 128
+              ch_mult:
+              - 1
+              - 2
+              - 4
+              - 4
+              num_res_blocks: 2
+              attn_resolutions: [ ]
+              dropout: 0.0
+            lossconfig:
+              target: torch.nn.Identity
+
+    scheduler_config:  # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        num_classes: 1000  # timesteps for noise conditioning
+        image_size: 64  # not really needed
+        in_channels: 20
+        out_channels: 16
+        model_channels: 32  # TODO: more
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 16
+        monitor: val/rec_loss
+        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
+        ddconfig:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
+          num_res_blocks: 2
+          attn_resolutions: [ 16 ]
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+
+#data:
+#  target: ldm.data.laion.WebDataModuleFromConfig
+#  params:
+#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
+#    batch_size: 4
+#    num_workers: 4
+#    multinode: True
+#    min_size: 256  # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
+#    train:
+#      shards: '{000000..231317}.tar -'
+#      shuffle: 10000
+#      image_key: jpg
+#      image_transforms:
+#      - target: torchvision.transforms.Resize
+#        params:
+#          size: 1024
+#          interpolation: 3
+#      - target: torchvision.transforms.RandomCrop
+#        params:
+#          size: 1024
+#
+#    # NOTE use enough shards to avoid empty validation loops in workers
+#    validation:
+#      shards: '{231318..231349}.tar -'
+#      shuffle: 0
+#      image_key: jpg
+#      image_transforms:
+#      - target: torchvision.transforms.Resize
+#        params:
+#          size: 1024
+#          interpolation: 3
+#      - target: torchvision.transforms.CenterCrop
+#        params:
+#          size: 1024
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 8
+    num_workers: 7
+    wrap: false
+    train:
+      target: ldm.data.imagenet.ImageNetSRTrain
+      params:
+        size: 1024
+        downscale_f: 4
+        degradation: "cv_nearest"
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 10
+        max_images: 4
+        increase_log_steps: False
+        log_first_step: False
+        log_images_kwargs:
+          use_ema_scope: False
+          inpaint: False
+          plot_progressive_rows: False
+          plot_diffusion_rows: False
+          N: 4
+          unconditional_guidance_scale: 3.0
+          unconditional_guidance_label: [""]
+
+  trainer:
+    benchmark: True
+    # val_check_interval: 5000000  # really sorry  # TODO: bring back in
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py
index 1c473f9..6623196 100644
--- a/ldm/data/imagenet.py
+++ b/ldm/data/imagenet.py
@@ -368,7 +368,7 @@ class ImageNetSR(Dataset):
 
         example["image"] = (image/127.5 - 1.0).astype(np.float32)
         example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
-
+        example["caption"] = example["human_label"]  # dummy caption
         return example
 
 
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 9ca6ff6..8846cab 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
 
     @torch.no_grad()
     def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
-                  cond_key=None, return_original_cond=False, bs=None):
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
         x = super().get_input(batch, k)
         if bs is not None:
             x = x[:bs]
@@ -746,6 +746,8 @@ class LatentDiffusion(DDPM):
         if return_first_stage_outputs:
             xrec = self.decode_first_stage(z)
             out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
         if return_original_cond:
             out.append(xc)
         return out
@@ -1416,9 +1418,9 @@ class DiffusionWrapper(pl.LightningModule):
         super().__init__()
         self.diffusion_model = instantiate_from_config(diff_model_config)
         self.conditioning_key = conditioning_key
-        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']
 
-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == 'concat':
@@ -1431,6 +1433,11 @@ class DiffusionWrapper(pl.LightningModule):
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'hybrid-adm':
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
             out = self.diffusion_model(x, t, y=cc)
@@ -1440,6 +1447,34 @@ class DiffusionWrapper(pl.LightningModule):
         return out
 
 
+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contains trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None):
+        z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
+        x_low = batch[self.low_scale_key]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        return z, all_conds
+
+    # TODO log it
+
+
 class Layout2ImgDiffusion(LatentDiffusion):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 68260c3..a87f6a0 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,8 +1,10 @@
 import torch
 import torch.nn as nn
+import numpy as np
 from functools import partial
 
 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
+from ldm.util import default
 
 
 class AbstractEncoder(nn.Module):
@@ -201,6 +203,60 @@ class SpatialRescaler(nn.Module):
         return self(x)
 
 
+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+
+
+class LowScaleEncoder(nn.Module):
+    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
+        super().__init__()
+        self.max_noise_level = max_noise_level
+        self.model = instantiate_from_config(model_config)
+        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
+                                                            linear_end=linear_end)
+        self.out_size = output_size
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        z = self.model.encode(x).sample()
+        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        z = self.q_sample(z, noise_level)
+        #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
+        z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+        return z, noise_level
+
+
 if __name__ == "__main__":
     from ldm.util import count_params
     sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
diff --git a/scripts/prompts/weird-dalle-prompts.txt b/scripts/prompts/weird-dalle-prompts.txt
new file mode 100644
index 0000000..39ebf04
--- /dev/null
+++ b/scripts/prompts/weird-dalle-prompts.txt
@@ -0,0 +1,12 @@
+# TODO, check out Twitter.
+Darth Vader at Woodstock (1969)
+Bunny Vikings
+The Demogorgon from Stranger Things holding a basketball
+Hamster in my microwave
+a courtroom sketch of a Ford Transit van
+PS1 Hagrid at McDonalds
+Karl Marx in KFC Logo
+Moai Statue giving a TED talk
+washing machine trail cam
+minions at cross burning
+Hindenburg disaster in Fortnite
\ No newline at end of file
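
For reference, a minimal self-contained sketch (not part of the patch) of the noise-augmentation step that LowScaleEncoder applies to the low-resolution latent before it is concatenated to the UNet input; the schedule values mirror low_scale_config above, while the tensor shapes and the augment_lr_latent helper are illustrative assumptions only.

import torch

# Sketch of the LowScaleEncoder augmentation: diffuse the LR latent to a random,
# capped timestep and hand that timestep to the UNet as "adm" (class-like) conditioning.

def make_linear_schedule(timesteps=1000, linear_start=0.00085, linear_end=0.0120):
    # same "linear" parameterization as ldm's make_beta_schedule
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return alphas_cumprod.sqrt(), (1.0 - alphas_cumprod).sqrt()

def augment_lr_latent(z_lr, max_noise_level=250):
    sqrt_ac, sqrt_omac = make_linear_schedule()
    t = torch.randint(0, max_noise_level, (z_lr.shape[0],))              # per-sample noise level
    noise = torch.randn_like(z_lr)
    z_noisy = (sqrt_ac[t].view(-1, 1, 1, 1) * z_lr
               + sqrt_omac[t].view(-1, 1, 1, 1) * noise)                 # q_sample
    z_noisy = z_noisy.repeat_interleave(2, -2).repeat_interleave(2, -1)  # 2x nearest upsample
    return z_noisy, t

if __name__ == "__main__":
    z_lr = torch.randn(4, 4, 32, 32)        # hypothetical batch of kl-f8 LR latents
    zx, noise_level = augment_lr_latent(z_lr)
    print(zx.shape, noise_level.shape)      # torch.Size([4, 4, 64, 64]) torch.Size([4])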