diff --git a/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml b/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml new file mode 100644 index 0000000..b13c6a4 --- /dev/null +++ b/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml @@ -0,0 +1,149 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + #ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 # unused + in_channels: 4 + out_channels: 4 + model_channels: 352 + attention_resolutions: [ 8, 4, 2 ] + num_res_blocks: [ 2, 2, 2, 6 ] + channel_mult: [ 1, 2, 4, 4 ] + disable_self_attentions: [ True, True, True, False ] # converts the self-attention to a cross-attention layer if true + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 1 + num_workers: 4 + wrap: false + train: + target: ldm.data.dummy.DummyData + params: + length: 20000 + size: [512, 512, 3] + validation: + target: ldm.data.dummy.DummyData + params: + length: 10000 + size: [512, 512, 3] + +#data: +# target: ldm.data.laion.WebDataModuleFromConfig +# params: +# tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/" +# batch_size: 4 +# num_workers: 4 +# multinode: True +# train: +# shards: '{00000..17279}.tar -' +# shuffle: 10000 +# image_key: jpg +# image_transforms: +# - target: torchvision.transforms.Resize +# params: +# size: 512 +# interpolation: 3 +# - target: torchvision.transforms.RandomCrop +# params: +# size: 512 +# +# # NOTE use enough shards to avoid empty validation loops in workers +# validation: +# shards: '{17280..17535}.tar -' +# shuffle: 0 +# image_key: jpg +# image_transforms: +# - target: torchvision.transforms.Resize +# params: +# size: 512 +# interpolation: 3 +# - target: torchvision.transforms.CenterCrop +# params: +# size: 512 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 3.0 + unconditional_guidance_label: [""] + + trainer: + #replace_sampler_ddp: False + benchmark: True + val_check_interval: 1000 # TODO: 1e10 # really sorry + num_sanity_val_steps: 0 + accumulate_grad_batches: 2 diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f0f99c4..124effb 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -194,9 +194,12 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none @@ -209,7 +212,7 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x @@ -224,7 +227,8 @@ class SpatialTransformer(nn.Module): Finally, reshape to image """ def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None): + depth=1, dropout=0., context_dim=None, + disable_self_attn=False): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head @@ -237,7 +241,8 @@ class SpatialTransformer(nn.Module): padding=0) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, + disable_self_attn=disable_self_attn) for d in range(depth)] ) diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py index fcf95d1..d5fc1c9 100644 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -18,6 +18,7 @@ from ldm.modules.diffusionmodules.util import ( timestep_embedding, ) from ldm.modules.attention import SpatialTransformer +from ldm.util import exists # dummy replace @@ -466,6 +467,7 @@ class UNetModel(nn.Module): context_dim=None, # custom transformer support n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, + disable_self_attentions=None ): super().__init__() if use_spatial_transformer: @@ -490,7 +492,18 @@ class UNetModel(nn.Module): self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - self.num_res_blocks = num_res_blocks + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult @@ -525,7 +538,7 @@ class UNetModel(nn.Module): ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): + for _ in range(self.num_res_blocks[level]): layers = [ ResBlock( ch, @@ -547,6 +560,10 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False layers.append( AttentionBlock( ch, @@ -555,7 +572,8 @@ class UNetModel(nn.Module): num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -609,7 +627,7 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ), ResBlock( @@ -625,7 +643,7 @@ class UNetModel(nn.Module): self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): + for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( @@ -648,6 +666,10 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False layers.append( AttentionBlock( ch, @@ -656,10 +678,11 @@ class UNetModel(nn.Module): num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa ) ) - if level and i == num_res_blocks: + if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( ResBlock(