diff --git a/configs/stable-diffusion/txt2img-ldm-vae-f8.yaml b/configs/stable-diffusion/txt2img-ldm-vae-f8.yaml
new file mode 100644
index 0000000..9cb8e8b
--- /dev/null
+++ b/configs/stable-diffusion/txt2img-ldm-vae-f8.yaml
@@ -0,0 +1,130 @@
+model:
+  base_learning_rate: 1.0e-04   # TODO: run with scale_lr False
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 32
+    channels: 4
+    cond_stage_trainable: true
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32
+        in_channels: 4
+        out_channels: 4
+        model_channels: 128  # 320 # TODO increase
+        attention_resolutions: [ 4, 2, 1 ]  # is equal to fixed spatial resolution: 32, 16, 8
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        #num_head_channels: 32
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 1280
+        use_checkpoint: True
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ckpt_path: "/home/robin/projects/latent-diffusion/models/first_stage_models/kl-f8/model.ckpt"
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.BERTEmbedder
+      params:
+        n_embed: 1280
+        n_layer: 3  #32 # TODO: increase
+
+
+data:
+  target: ldm.data.laion.WebDataModuleFromConfig
+  params:
+    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
+    batch_size: 60
+    num_workers: 4
+    n_nodes: 2  # TODO: runs with two gpus
+    train:
+      shards: '{000000..000010}.tar -'  # TODO: wild guess, change
+      image_key: jpg
+      image_transforms:
+      - target: torchvision.transforms.Resize
+        params:
+          size: 512
+          interpolation: 3
+      - target: torchvision.transforms.RandomCrop
+        params:
+          size: 512
+
+      shuffle: 5000
+      n_examples: 16519100  # TODO: find out
+    validation:
+      shards: '{000011..000012}.tar -'  # TODO: wild guess, change
+      image_key: jpg
+      image_transforms:
+      - target: torchvision.transforms.Resize
+        params:
+          size: 512
+          interpolation: 3
+      - target: torchvision.transforms.CenterCrop
+        params:
+          size: 512
+
+      shuffle: 0
+      n_examples: 60000  # TODO: find out
+      val_num_workers: 2
+
+
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 5000  # 5000
+        max_images: 8
+        increase_log_steps: False
+        log_first_step: True
+
+
+  trainer:
+    replace_sampler_ddp: False
+    benchmark: True
+    val_check_interval: 20000  # every 20k training steps
+    num_sanity_val_steps: 0
\ No newline at end of file
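
The config is consumed the same way as the other ldm configs: main.py loads it with OmegaConf and builds every component from its target/params pair. A minimal sketch of that path (it assumes the repo root is on PYTHONPATH and that the kl-f8 checkpoint referenced in first_stage_config exists locally):

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/stable-diffusion/txt2img-ldm-vae-f8.yaml")
    model = instantiate_from_config(config.model)  # LatentDiffusion with UNet, kl-f8 VAE and BERTEmbedder
    model.learning_rate = config.model.base_learning_rate

For an actual run the file is meant to be handed to the training entry point, along the lines of "python main.py --base configs/stable-diffusion/txt2img-ldm-vae-f8.yaml -t --gpus 0,1".
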
diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 5d87d5d..05fc245 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -2,7 +2,178 @@ import webdataset as wds
 from PIL import Image
 import io
 import os
+import torchvision
+from PIL import Image
+import glob
+import random
+import numpy as np
+import pytorch_lightning as pl
+from tqdm import tqdm
+from natsort import natsorted
+from omegaconf import OmegaConf
+from einops import rearrange
+import torch
+
+
+from ldm.util import instantiate_from_config
+
+
+def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
+    """Take a list of samples (as dictionary) and create a batch, preserving the keys.
+    If `combine_tensors` is True, tensor and `ndarray` values are stacked into batches;
+    if `combine_scalars` is True, scalar values are combined into an `ndarray`.
+    :param list samples: list of samples, each a dict
+    :param bool combine_tensors: whether to stack lists of tensors/ndarrays into a single batch
+    :param bool combine_scalars: whether to turn lists of scalars into a single ndarray
+    :returns: single sample consisting of a batch
+    :rtype: dict
+    """
+    batched = {key: [] for key in samples[0]}
+    # assert isinstance(samples[0][first_key], (list, tuple)), type(samples[first_key])
+
+    for s in samples:
+        [batched[key].append(s[key]) for key in batched]
+
+
+    result = {}
+    for key in batched:
+        if isinstance(batched[key][0], (int, float)):
+            if combine_scalars:
+                result[key] = np.array(list(batched[key]))
+        elif isinstance(batched[key][0], torch.Tensor):
+            if combine_tensors:
+                # import torch
+
+                result[key] = torch.stack(list(batched[key]))
+        elif isinstance(batched[key][0], np.ndarray):
+            if combine_tensors:
+                result[key] = np.array(list(batched[key]))
+        else:
+            result[key] = list(batched[key])
+        # result.append(b)
+    return result
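+
+# Example of what dict_collation_fn produces (sample contents are assumed; the
+# "jpg"/"txt" keys and the 512x512 size follow the config above, with images
+# already rearranged to h w c by the transforms in make_loader below):
+#   samples = [{"jpg": Tensor(512, 512, 3), "txt": "a photo of a cat"},
+#              {"jpg": Tensor(512, 512, 3), "txt": "a photo of a dog"}]
+#   dict_collation_fn(samples)
+#   -> {"jpg": Tensor(2, 512, 512, 3), "txt": ["a photo of a cat", "a photo of a dog"]}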
+
+
+class WebDataModuleFromConfig(pl.LightningDataModule):
+    def __init__(self, tar_base, batch_size, train=None, validation=None,
+                 test=None, num_workers=4, load_ddp=True, n_nodes=1,
+                 **kwargs):
+        super().__init__()
+        print(f'Setting tar base to {tar_base}')
+        self.tar_base = tar_base
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.train = train
+        self.validation = validation
+        self.test = test
+        self.load_ddp = load_ddp
+        self.multinode = n_nodes > 1
+        self.n_nodes = n_nodes  # n gpu ??
+
+    def make_loader(self, dataset_config, train=True):
+        if 'image_transforms' in dataset_config:
+            image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
+        else:
+            image_transforms = []
+
+        image_transforms.extend([torchvision.transforms.ToTensor(),
+                                 torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms = torchvision.transforms.Compose(image_transforms)
+
+        if 'transforms' in dataset_config:
+            transforms_config = OmegaConf.to_container(dataset_config.transforms)
+        else:
+            transforms_config = dict()
+
+        transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
+                          if transforms_config[dkey] != 'identity' else identity
+                          for dkey in transforms_config}
+        img_key = dataset_config.get('image_key', 'jpeg')
+        transform_dict.update({img_key: image_transforms})
+
+        shuffle = dataset_config.get('shuffle', 0)
+
+        # TODO: find a strategy for when n_examples is not known beforehand
+        n_examples = dataset_config.get('n_examples', 1e6) // self.n_nodes
+
+        shards_to_load = dataset_config.shards
+        dset_name = 'unknown'
+        if isinstance(shards_to_load, str):
+            print(f'Loading tars based on the string {shards_to_load}')
+            tars = os.path.join(self.tar_base, shards_to_load)
+            start_shard_id, end_shard_id = dataset_config.shards.split('{')[-1].split('}')[0].split('..')
+            n_shards = int(end_shard_id) - int(start_shard_id) + 1
+            dset_name = dataset_config.shards.split('-')[0]
+        elif isinstance(shards_to_load, int):
+            print(f'Creating tar list, max shard is {shards_to_load}')
+            try:
+                tars = [tf for tf in natsorted(glob.glob(os.path.join(self.tar_base, '*.tar'))) if
+                        int(tf.split('/')[-1].split('.')[0]) < shards_to_load]
+                n_shards = len(tars)
+                random.shuffle(tars)
+
+            except ValueError as e:
+                print('tarfile names should follow the pattern .tar . Check names of the files')
+                raise e
+        else:
+            raise ValueError(
+                'shards should be either a string containing consecutive shards or an int defining the max shard number')
+
+        print(f'Got {n_shards} shard files in datafolder for {"training" if train else "validation"}')
+
+        # if self.num_workers > 0:
+        #     assert n_shards % self.num_workers == 0, f'Number of workers which is {self.num_workers} does not evenly divide number of shards which is {n_shards}'
+        print(f'Loading webdataset-based dataloader with {n_shards} shards of the {dset_name} dataset.')
+
+        # start creating the dataset
+        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
+        epoch_length = n_examples // (self.batch_size)
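+        # With the placeholder values from the config above (n_examples: 16519100,
+        # n_nodes: 2, batch_size: 60) this works out to 16519100 // 2 = 8259550
+        # samples per node and 8259550 // 60 = 137659 batches per epoch -- the
+        # numbers that with_length()/with_epoch() below report to the trainer.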
+
+        dset = wds.WebDataset(tars, nodesplitter=nodesplitter).shuffle(shuffle)
+
+        with_epoch_args = {'nsamples': n_examples, 'nbatches': epoch_length}
+
+        if 'filters' in dataset_config:
+            for stage in tqdm(dataset_config.filters,
+                              desc=f'Applying the following filters: {[f for f in dataset_config.filters]}'):
+                f = getattr(dset, stage)
+                dset = f(dset, *dataset_config.filters[stage].args,
+                         **dataset_config.filters[stage].get('kwargs', dict()))
+
+        print(f'Dataset holding {len(dset.pipeline[0].urls)} shards')
+
+        dset = (dset
+                .decode('pil')
+                # .to_tuple("jpg;png;jpeg pickle cls hls")
+                # .map_tuple(image_transforms, load_partial_from_config(nns_transform) if 'target' in nns_transform else identity, identity, identity)
+                .map_dict(**transform_dict)
+                .repeat()
+                .batched(self.batch_size, partial=False,
+                         collation_fn=dict_collation_fn)
+                .with_length(n_examples)
+                .with_epoch(**with_epoch_args)
+                )
+
+        loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
+                               num_workers=self.num_workers)
+
+        return loader, n_examples
+
+    def train_dataloader(self):
+        assert self.train is not None
+        loader, dset_size = self.make_loader(self.train)
+        # if self.load_ddp:
+        #     loader = loader.ddp_equalize(dset_size // self.batch_size)
+        return loader
+
+    def val_dataloader(self):
+        assert self.validation is not None
+        loader, _ = self.make_loader(self.validation, train=False)
+        return loader
+
+    def test_dataloader(self):
+        assert self.test is not None
+        loader, _ = self.make_loader(self.test, train=False)
+        return loader
+
 
 
 if __name__ == "__main__":
     url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
@@ -18,7 +189,7 @@ if __name__ == "__main__":
         image = Image.open(io.BytesIO(example["jpg"]))
         outdir = "tmp"
         os.makedirs(outdir, exist_ok=True)
-        image.save(os.path.join(outdir, example["__key__"]+".png"))
+        image.save(os.path.join(outdir, example["__key__"] + ".png"))
 
 
 def load_example(example):
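
A minimal sketch of how the new data module can be exercised on its own (it assumes the aws CLI is configured with read access to the s-datasets bucket; the shard ranges are still the wild-guess placeholders from the config):

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/stable-diffusion/txt2img-ldm-vae-f8.yaml")
    data = instantiate_from_config(config.data)  # WebDataModuleFromConfig
    loader = data.train_dataloader()             # webdataset WebLoader, batches already collated
    batch = next(iter(loader))
    print(batch["jpg"].shape, batch["txt"][:2])  # (60, 512, 512, 3) in [-1, 1], plus a list of captions

Lightning consumes these loaders as-is, which is presumably why replace_sampler_ddp is set to False in the config: sharding across ranks is handled by the nodesplitter inside make_loader rather than by a DistributedSampler.
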
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index bbedd04..49b8ecc 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -664,7 +664,7 @@ class LatentDiffusion(DDPM):
             if cond_key is None:
                 cond_key = self.cond_stage_key
             if cond_key != self.first_stage_key:
-                if cond_key in ['caption', 'coordinates_bbox']:
+                if cond_key in ['caption', 'coordinates_bbox', "txt"]:
                     xc = batch[cond_key]
                 elif cond_key == 'class_label':
                     xc = batch
@@ -762,66 +762,6 @@ class LatentDiffusion(DDPM):
             else:
                 return self.first_stage_model.decode(z)
 
-    # same as above but without decorator
-    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
-        if predict_cids:
-            if z.dim() == 4:
-                z = torch.argmax(z.exp(), dim=1).long()
-            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
-            z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
-        z = 1. / self.scale_factor * z
-
-        if hasattr(self, "split_input_params"):
-            if self.split_input_params["patch_distributed_vq"]:
-                ks = self.split_input_params["ks"]  # eg. (128, 128)
-                stride = self.split_input_params["stride"]  # eg. (64, 64)
-                uf = self.split_input_params["vqf"]
-                bs, nc, h, w = z.shape
-                if ks[0] > h or ks[1] > w:
-                    ks = (min(ks[0], h), min(ks[1], w))
-                    print("reducing Kernel")
-
-                if stride[0] > h or stride[1] > w:
-                    stride = (min(stride[0], h), min(stride[1], w))
-                    print("reducing stride")
-
-                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
-                z = unfold(z)  # (bn, nc * prod(**ks), L)
-                # 1. Reshape to img shape
-                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
-
-                # 2. apply model loop over last dim
-                if isinstance(self.first_stage_model, VQModelInterface):
-                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
-                                                                 force_not_quantize=predict_cids or force_not_quantize)
-                                   for i in range(z.shape[-1])]
-                else:
-
-                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
-                                   for i in range(z.shape[-1])]
-
-                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
-                o = o * weighting
-                # Reverse 1. reshape to img shape
-                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
-                # stitch crops together
-                decoded = fold(o)
-                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
-                return decoded
-            else:
-                if isinstance(self.first_stage_model, VQModelInterface):
-                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
-                else:
-                    return self.first_stage_model.decode(z)
-
-        else:
-            if isinstance(self.first_stage_model, VQModelInterface):
-                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
-            else:
-                return self.first_stage_model.decode(z)
-
     @torch.no_grad()
     def encode_first_stage(self, x):
         if hasattr(self, "split_input_params"):
@@ -1268,8 +1208,8 @@ class LatentDiffusion(DDPM):
             if hasattr(self.cond_stage_model, "decode"):
                 xc = self.cond_stage_model.decode(c)
                 log["conditioning"] = xc
-            elif self.cond_stage_key in ["caption"]:
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
                 log["conditioning"] = xc
             elif self.cond_stage_key == 'class_label':
                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
diff --git a/main.py b/main.py
index e8e18c1..939d9c1 100644
--- a/main.py
+++ b/main.py
@@ -667,8 +667,11 @@ if __name__ == "__main__":
     data.prepare_data()
     data.setup()
     print("#### Data #####")
-    for k in data.datasets:
-        print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
+    try:
+        for k in data.datasets:
+            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
+    except Exception:
+        print("datasets not yet initialized.")
 
     # configure learning rate
     bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
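
The learning-rate section that this last hunk leads into follows the repo's usual scale_lr behaviour (base rate multiplied by accumulate_grad_batches, GPU count and batch size unless scaling is turned off), which is presumably what the "TODO: run with scale_lr False" in the new config refers to. A rough sketch of the numbers involved (accumulate_grad_batches and one GPU per node are assumptions):

    accumulate_grad_batches = 1
    ngpu = 2                     # n_nodes: 2 in the config, one GPU per node assumed
    bs, base_lr = 60, 1.0e-04    # config.data.params.batch_size, config.model.base_learning_rate
    scaled_lr = accumulate_grad_batches * ngpu * bs * base_lr  # = 1.2e-02

With scaling enabled the effective rate would jump from 1.0e-04 to 1.2e-02, hence the suggestion in the config to train with scale_lr disabled.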