Merge remote-tracking branch 'origin/main'
commit 9300c0ccfc

3 changed files with 220 additions and 30 deletions
@@ -0,0 +1,150 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    #ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [ 8, 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 5 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ]  # converts the self-attention to a cross-attention layer if true
        num_attention_blocks: [1, 1, 1, 3]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.dummy.DummyData
      params:
        length: 20000
        size: [512, 512, 3]
    validation:
      target: ldm.data.dummy.DummyData
      params:
        length: 10000
        size: [512, 512, 3]

#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    train:
#      shards: '{00000..17279}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 512
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{17280..17535}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 512


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 1000 # TODO: 1e10 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
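For context: a config in this layout is consumed via OmegaConf together with the repo's instantiate_from_config helper, which imports the class named under each "target" and passes it the matching "params". Below is a minimal, hedged sketch, assuming the repository root is on PYTHONPATH; the save path for the file above is hypothetical.

# Minimal sketch; the config path is hypothetical, and main.py performs the same
# steps with additional learning-rate scaling (batch size, GPU count, grad accumulation).
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/example-ldm.yaml")      # hypothetical path for the file above
model = instantiate_from_config(config.model)            # builds ldm.models.diffusion.ddpm.LatentDiffusion
model.learning_rate = config.model.base_learning_rate
data = instantiate_from_config(config.data)              # builds main.DataModuleFromConfig with the dummy datasets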
@@ -194,9 +194,12 @@ class CrossAttention(nn.Module):


 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                                    context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                     heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
@@ -209,7 +212,7 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
         return x
@@ -224,7 +227,8 @@ class SpatialTransformer(nn.Module):
     Finally, reshape to image
     """
     def __init__(self, in_channels, n_heads, d_head,
-                 depth=1, dropout=0., context_dim=None):
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False):
         super().__init__()
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
@@ -237,7 +241,8 @@ class SpatialTransformer(nn.Module):
                                  padding=0)

         self.transformer_blocks = nn.ModuleList(
-            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
+                                   disable_self_attn=disable_self_attn)
                 for d in range(depth)]
         )

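The net effect of the new disable_self_attn flag: when it is set, attn1 is built with a context_dim and _forward feeds it the conditioning, so the block's first attention layer cross-attends to the context instead of attending over its own tokens; attn2 is unchanged. A hedged usage sketch follows (it assumes this commit is checked out and that ldm plus torch are importable; all shapes are illustrative only, not taken from this diff).

import torch
from ldm.modules.attention import BasicTransformerBlock

# dim == n_heads * d_head, so the attention output width matches the input width
block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40,
                              context_dim=768, disable_self_attn=True)

x = torch.randn(1, 4096, 320)    # flattened 64x64 feature map with 320 channels
ctx = torch.randn(1, 77, 768)    # e.g. a CLIP text-embedding sequence

out = block(x, context=ctx)      # with disable_self_attn=True, attn1 and attn2 both attend to ctx
print(out.shape)                 # torch.Size([1, 4096, 320])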
@@ -18,6 +18,7 @@ from ldm.modules.diffusionmodules.util import (
     timestep_embedding,
 )
 from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists


 # dummy replace
@@ -466,6 +467,8 @@ class UNetModel(nn.Module):
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -490,7 +493,25 @@ class UNetModel(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        #self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")  # todo: convert to warning
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -525,7 +546,7 @@ class UNetModel(nn.Module):
         ch = model_channels
         ds = 1
         for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
+            for nr in range(self.num_res_blocks[level]):
                 layers = [
                     ResBlock(
                         ch,
@@ -547,17 +568,24 @@ class UNetModel(nn.Module):
                     if legacy:
                         #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-                        )
-                    )
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
+                        )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
@@ -609,7 +637,7 @@ class UNetModel(nn.Module):
                 num_heads=num_heads,
                 num_head_channels=dim_head,
                 use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                         ),
             ResBlock(
@@ -625,7 +653,7 @@ class UNetModel(nn.Module):

         self.output_blocks = nn.ModuleList([])
         for level, mult in list(enumerate(channel_mult))[::-1]:
-            for i in range(num_res_blocks + 1):
+            for i in range(self.num_res_blocks[level] + 1):
                 ich = input_block_chans.pop()
                 layers = [
                     ResBlock(
@@ -648,18 +676,25 @@ class UNetModel(nn.Module):
                     if legacy:
                         #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-                        )
-                    )
-                if level and i == num_res_blocks:
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
+                        )
+                if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(
                         ResBlock(
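Taken together, the UNetModel changes let num_res_blocks be specified per level and gate attention insertion on both attention_resolutions and the new num_attention_blocks; as the constructor's print states, attention_resolutions keeps priority. A standalone sketch (plain Python, no repo imports) of that interaction for the input blocks, using the values from the config above; ds mirrors the downsampling factor the constructor tracks.

# Values copied from the config at the top of this commit.
attention_resolutions = [8, 4, 2, 1]
num_res_blocks = [2, 2, 2, 5]               # per level, same length as channel_mult
num_attention_blocks = [1, 1, 1, 3]         # at most this many attended res blocks per level
disable_self_attentions = [False, False, False, False]

ds = 1  # downsampling factor; doubled after every level except the last
for level, n_blocks in enumerate(num_res_blocks):
    for nr in range(n_blocks):
        # attention_resolutions has priority; num_attention_blocks can only further
        # restrict where a SpatialTransformer / AttentionBlock is inserted
        has_attn = ds in attention_resolutions and nr < num_attention_blocks[level]
        print(f"level {level} (ds={ds}), res block {nr}: "
              f"attention={'yes' if has_attn else 'no'}, "
              f"self-attn disabled={disable_self_attentions[level]}")
    if level != len(num_res_blocks) - 1:
        ds *= 2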