diff --git a/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-768.yaml b/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-768.yaml
new file mode 100644
index 0000000..184905a
--- /dev/null
+++ b/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-768.yaml
@@ -0,0 +1,129 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.001
+    linear_end: 0.015
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 48
+    channels: 16
+    cond_stage_trainable: false  # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.22765929  # magic number
+
+    ckpt_path: # TODO: add
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 48  # not really needed
+        in_channels: 16
+        out_channels: 16
+        model_channels: 320  # TODO: scale model here
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 16
+        monitor: val/rec_loss
+        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
+        ddconfig:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
+          num_res_blocks: 2
+          attn_resolutions: [ 16 ]
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+
+data:
+  target: ldm.data.laion.WebDataModuleFromConfig
+  params:
+    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
+    batch_size: 10
+    num_workers: 4
+    multinode: True
+    min_size: 384  # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
+    train:
+      shards: '{000000..231317}.tar -'
+      shuffle: 10000
+      image_key: jpg
+      image_transforms:
+      - target: torchvision.transforms.Resize
+        params:
+          size: 768
+          interpolation: 3
+      - target: torchvision.transforms.RandomCrop
+        params:
+          size: 768
+
+    # NOTE: use enough shards to avoid empty validation loops in workers
+    validation:
+      shards: '{231318..231349}.tar -'
+      shuffle: 0
+      image_key: jpg
+      image_transforms:
+      - target: torchvision.transforms.Resize
+        params:
+          size: 768
+          interpolation: 3
+      - target: torchvision.transforms.CenterCrop
+        params:
+          size: 768
+
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 5000
+        max_images: 4
+        increase_log_steps: False
+        log_first_step: False
+        log_images_kwargs:
+          use_ema_scope: False
+          inpaint: False
+          plot_progressive_rows: False
+          plot_diffusion_rows: False
+          N: 4
+          unconditional_guidance_scale: 3.0
+          unconditional_guidance_label: [""]
+
+  trainer:
+    benchmark: True
+    val_check_interval: 5000000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 2
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index d9ec628..e161501 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -7,6 +7,7 @@ from tqdm import tqdm, trange
 from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
+import time

 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -63,6 +64,12 @@ if __name__ == "__main__":
         help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
     )

+    parser.add_argument(
+        "--skip_save",
+        action='store_true',
+        help="do not save individual samples. For speed measurements.",
+    )
+
     parser.add_argument(
         "--ddim_steps",
         type=int,
@@ -103,6 +110,19 @@ if __name__ == "__main__":
         help="image width, in pixel space",
     )

+    parser.add_argument(
+        "--C",
+        type=int,
+        default=4,
+        help="latent channels",
+    )
+    parser.add_argument(
+        "--f",
+        type=int,
+        default=8,
+        help="downsampling factor, most often 8 or 16",
+    )
+
     parser.add_argument(
         "--n_samples",
         type=int,
@@ -184,6 +204,7 @@ if __name__ == "__main__":

     with torch.no_grad():
         with model.ema_scope():
+            tic = time.time()
             for n in trange(opt.n_iter, desc="Sampling"):
                 all_samples = list()
                 for prompts in tqdm(data, desc="data"):
@@ -193,7 +214,7 @@ if __name__ == "__main__":
                     if isinstance(prompts, tuple):
                         prompts = list(prompts)
                     c = model.get_learned_conditioning(prompts)
-                    shape = [4, opt.H//8, opt.W//8]
+                    shape = [opt.C, opt.H//opt.f, opt.W//opt.f]
                     samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                      conditioning=c,
                                                      batch_size=opt.n_samples,
@@ -207,10 +228,11 @@ if __name__ == "__main__":
                     x_samples_ddim = model.decode_first_stage(samples_ddim)
                     x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)

-                    for x_sample in x_samples_ddim:
-                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
-                        base_count += 1
+                    if not opt.skip_save:
+                        for x_sample in x_samples_ddim:
+                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                            Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
+                            base_count += 1
                     all_samples.append(x_samples_ddim)

             if not opt.skip_grid:
@@ -224,4 +246,8 @@ if __name__ == "__main__":
                 Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                 grid_count += 1

-    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
+    toc = time.time()
+
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+          f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
+          f" \nEnjoy.")
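
For reference, the new --C/--f flags generalize the previously hard-coded latent geometry, and the f16-768 config above is one concrete instantiation: a 768x768 image with f=16 and C=16 gives a 16x48x48 latent, matching channels: 16 and image_size: 48 in the model params. A minimal sketch; the latent_shape helper is hypothetical and not part of scripts/txt2img.py:

# Hypothetical helper mirroring shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
# in scripts/txt2img.py; shown only to make the flag semantics concrete.
def latent_shape(H, W, C, f):
    return [C, H // f, W // f]

# Old hard-coded behaviour (4 latent channels, factor 8), e.g. at 512x512:
assert latent_shape(512, 512, C=4, f=8) == [4, 64, 64]
# The f16-768 config above (channels: 16, image_size: 48):
assert latent_shape(768, 768, C=16, f=16) == [16, 48, 48]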
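The scheduler_config block is intended to give 10000 linear warmup steps and then hold the learning-rate multiplier at 1.0, since f_max and f_min are both 1.0 and cycle_lengths is effectively infinite. A rough sketch of the implied multiplier curve, assuming LambdaLinearScheduler ramps linearly from f_start to f_max over warm_up_steps; this is an illustration, not the ldm.lr_scheduler implementation:

# Illustrative only: approximates the multiplier implied by the scheduler_config
# values (warm_up_steps=10000, f_start=1e-6, f_max=f_min=1.0).
def lr_multiplier(step, warm_up_steps=10_000, f_start=1e-6, f_max=1.0, f_min=1.0):
    if step < warm_up_steps:
        # linear warmup from f_start to f_max
        return f_start + (f_max - f_start) * step / warm_up_steps
    # with f_max == f_min the multiplier stays constant after warmup; the huge
    # cycle_lengths value just prevents the schedule from ever starting a new cycle
    return f_min

# effective lr = base_learning_rate * multiplier = 1.0e-04 once warmup is done
assert abs(lr_multiplier(20_000) - 1.0) < 1e-12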