diff --git a/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml
new file mode 100644
index 0000000..9a45eed
--- /dev/null
+++ b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml
@@ -0,0 +1,166 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
+  params:
+    low_scale_key: "lr"
+    linear_start: 0.001
+    linear_end: 0.015
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 16
+    cond_stage_trainable: false
+    conditioning_key: "hybrid-adm"
+    monitor: val/loss_simple_ema
+    scale_factor: 0.22765929 # magic number
+
+    low_scale_config:
+      target: ldm.modules.encoders.modules.LowScaleEncoder
+      params:
+        linear_start: 0.00085
+        linear_end: 0.0120
+        timesteps: 1000
+        max_noise_level: 250
+        output_size: 64
+        model_config:
+          target: ldm.models.autoencoder.AutoencoderKL
+          params:
+            embed_dim: 4
+            monitor: val/rec_loss
+            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
+            ddconfig:
+              double_z: true
+              z_channels: 4
+              resolution: 256
+              in_channels: 3
+              out_ch: 3
+              ch: 128
+              ch_mult:
+                - 1
+                - 2
+                - 4
+                - 4
+              num_res_blocks: 2
+              attn_resolutions: [ ]
+              dropout: 0.0
+            lossconfig:
+              target: torch.nn.Identity
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        num_classes: 1000 # timesteps for noise conditioning
+        image_size: 64 # not really needed
+        in_channels: 20
+        out_channels: 16
+        model_channels: 32 # TODO: more
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 16
+        monitor: val/rec_loss
+        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
+        ddconfig:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
+          num_res_blocks: 2
+          attn_resolutions: [ 16 ]
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+
+data:
+  target: ldm.data.laion.WebDataModuleFromConfig
+  params:
+    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
+    batch_size: 4
+    num_workers: 1
+    train:
+      shards: '{00000..17279}.tar -'
+      shuffle: 10000
+      image_key: jpg
+      image_transforms:
+      - target: torchvision.transforms.Resize
+        params:
+          size: 1024
+          interpolation: 3
+      - target: torchvision.transforms.RandomCrop
+        params:
+          size: 1024
+      postprocess:
+        target: ldm.data.laion.AddLR
+        params:
+          factor: 4
+
+    # NOTE use enough shards to avoid empty validation loops in workers
+    validation:
+      shards: '{17280..17535}.tar -'
+      shuffle: 0
+      image_key: jpg
+      image_transforms:
+      - target: torchvision.transforms.Resize
+        params:
+          size: 1024
+          interpolation: 3
+      - target: torchvision.transforms.CenterCrop
+        params:
+          size: 1024
+      postprocess:
+        target: ldm.data.laion.AddLR
+        params:
+          factor: 4
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 10
+        max_images: 4
+        increase_log_steps: False
+        log_first_step: False
+        log_images_kwargs:
+          use_ema_scope: False
+          inpaint: False
+          plot_progressive_rows: False
+          plot_diffusion_rows: False
+          N: 4
+          unconditional_guidance_scale: 3.0
+          unconditional_guidance_label: [""]
+
+  trainer:
+    benchmark: True
+    val_check_interval: 5000000 # really sorry
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 1eaf08f..f23ebb1 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -138,6 +138,11 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         img_key = dataset_config.get('image_key', 'jpeg')
         transform_dict.update({img_key: image_transforms})

+        if 'postprocess' in dataset_config:
+            postprocess = instantiate_from_config(dataset_config['postprocess'])
+        else:
+            postprocess = None
+
         shuffle = dataset_config.get('shuffle', 0)
         shardshuffle = shuffle > 0

@@ -156,8 +161,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**transform_dict, handler=wds.warn_and_continue)
+               )
+        if postprocess is not None:
+            dset = dset.map(postprocess)
+        dset = (dset
                .batched(self.batch_size, partial=False,
-                       collation_fn=dict_collation_fn)
+                        collation_fn=dict_collation_fn)
                )

         loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
@@ -189,6 +198,29 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         return self.make_loader(self.test, train=False)


+from ldm.modules.image_degradation import degradation_fn_bsr_light
+
+class AddLR(object):
+    def __init__(self, factor):
+        self.factor = factor
+
+    def pt2np(self, x):
+        x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
+        return x
+
+    def np2pt(self, x):
+        x = torch.from_numpy(x)/127.5-1.0
+        return x
+
+    def __call__(self, sample):
+        # sample['jpg'] is tensor hwc in [-1, 1] at this point
+        x = self.pt2np(sample['jpg'])
+        x = degradation_fn_bsr_light(x, sf=self.factor)['image']
+        x = self.np2pt(x)
+        sample['lr'] = x
+        return sample
+
+
 def example00():
     url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
     dataset = wds.WebDataset(url)
@@ -270,7 +302,8 @@ if __name__ == "__main__":
     from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

-    config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
+    #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
+    config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()