From e2a6bee13cf1532797511b55a50f5884ccc30713 Mon Sep 17 00:00:00 2001
From: rromb
Date: Wed, 6 Jul 2022 23:44:26 +0200
Subject: [PATCH] add further nerf-attention

---
 ...2img-2B-clip-encoder-high-res-512-dev.yaml |  9 +--
 ldm/modules/diffusionmodules/openaimodel.py   | 60 +++++++++++--------
 2 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml b/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml
index b13c6a4..18673ad 100644
--- a/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml
+++ b/configs/stable-diffusion/txt2img-2B-clip-encoder-high-res-512-dev.yaml
@@ -32,11 +32,12 @@ model:
         image_size: 64 # unused
         in_channels: 4
         out_channels: 4
-        model_channels: 352
-        attention_resolutions: [ 8, 4, 2 ]
-        num_res_blocks: [ 2, 2, 2, 6 ]
+        model_channels: 384
+        attention_resolutions: [ 8, 4, 2, 1 ]
+        num_res_blocks: [ 2, 2, 2, 5 ]
         channel_mult: [ 1, 2, 4, 4 ]
-        disable_self_attentions: [ True, True, True, False ]  # converts the self-attention to a cross-attention layer if true
+        disable_self_attentions: [ False, False, False, False ]  # converts the self-attention to a cross-attention layer if true
+        num_attention_blocks: [1, 1, 1, 3]
         num_heads: 8
         use_spatial_transformer: True
         transformer_depth: 1
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
index d5fc1c9..6b994cc 100644
--- a/ldm/modules/diffusionmodules/openaimodel.py
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -467,7 +467,8 @@ class UNetModel(nn.Module):
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
-        disable_self_attentions=None
+        disable_self_attentions=None,
+        num_attention_blocks=None
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -503,6 +504,13 @@ class UNetModel(nn.Module):
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")  # todo: convert to warning

         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -538,7 +546,7 @@ class UNetModel(nn.Module):
         ch = model_channels
         ds = 1
         for level, mult in enumerate(channel_mult):
-            for _ in range(self.num_res_blocks[level]):
+            for nr in range(self.num_res_blocks[level]):
                 layers = [
                     ResBlock(
                         ch,
@@ -564,18 +572,20 @@ class UNetModel(nn.Module):
                         disabled_sa = disable_self_attentions[level]
                     else:
                         disabled_sa = False
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disabled_sa
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
                         )
-                    )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
@@ -670,18 +680,20 @@ class UNetModel(nn.Module):
                         disabled_sa = disable_self_attentions[level]
                     else:
                         disabled_sa = False
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disabled_sa
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
                         )
-                    )
                 if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(