Merge remote-tracking branch 'origin/main'

Author: pesser
Date:   2022-07-06 23:06:41 +00:00
commit 9300c0ccfc
3 changed files with 220 additions and 30 deletions


@@ -0,0 +1,150 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    #ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [ 8, 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 5 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
        num_attention_blocks: [1, 1, 1, 3]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.dummy.DummyData
      params:
        length: 20000
        size: [512, 512, 3]
    validation:
      target: ldm.data.dummy.DummyData
      params:
        length: 10000
        size: [512, 512, 3]

#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    train:
#      shards: '{00000..17279}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 512
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{17280..17535}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 512

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 1000 # TODO: 1e10 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
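
A note on usage: a config like this is not executed directly; in this repo it is typically loaded with OmegaConf and turned into objects via ldm.util.instantiate_from_config (training usually goes through main.py --base <config> -t --gpus 0,). A minimal sketch, assuming the ldm package is importable and the YAML above is saved locally; the path below is made up for illustration:

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/example-txt2img-512.yaml")  # hypothetical path for the file above
    model = instantiate_from_config(config.model)  # builds ldm.models.diffusion.ddpm.LatentDiffusion
    print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")

Note that instantiating the model also builds the frozen CLIP text encoder and the KL autoencoder named in cond_stage_config and first_stage_config, so their dependencies (and, for FrozenCLIPEmbedder, the pretrained transformer weights) must be available.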

ldm/modules/attention.py

@@ -194,9 +194,12 @@ class CrossAttention(nn.Module):
 
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                                    context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                     heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
@@ -209,7 +212,7 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
         return x
@@ -224,7 +227,8 @@ class SpatialTransformer(nn.Module):
     Finally, reshape to image
     """
     def __init__(self, in_channels, n_heads, d_head,
-                 depth=1, dropout=0., context_dim=None):
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False):
         super().__init__()
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
@@ -237,7 +241,8 @@ class SpatialTransformer(nn.Module):
                                  padding=0)
 
         self.transformer_blocks = nn.ModuleList(
-            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
+                                   disable_self_attn=disable_self_attn)
                 for d in range(depth)]
         )
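
The effect of the new disable_self_attn flag: when set, the block's first attention layer no longer attends over its own tokens but over the conditioning context, i.e. attn1 becomes a second cross-attention. A self-contained sketch of just that switch, using a toy module built on torch.nn.MultiheadAttention rather than the repo's CrossAttention; the dimensions are illustrative:

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        """Toy stand-in for BasicTransformerBlock, showing only the attn1 switch."""
        def __init__(self, dim, context_dim, n_heads, disable_self_attn=False):
            super().__init__()
            self.disable_self_attn = disable_self_attn
            kdim = context_dim if disable_self_attn else dim  # mirrors the context_dim passed to attn1
            self.attn1 = nn.MultiheadAttention(dim, n_heads, kdim=kdim, vdim=kdim, batch_first=True)
            self.attn2 = nn.MultiheadAttention(dim, n_heads, kdim=context_dim, vdim=context_dim, batch_first=True)
            self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

        def forward(self, x, context):
            src = context if self.disable_self_attn else x       # mirrors the context=... switch in _forward
            x = self.attn1(self.norm1(x), src, src)[0] + x       # self-attention unless disabled
            x = self.attn2(self.norm2(x), context, context)[0] + x  # always cross-attention over the context
            return x

    x = torch.randn(2, 64, 320)    # (batch, image tokens, inner dim)
    ctx = torch.randn(2, 77, 768)  # (batch, text tokens, context_dim), e.g. CLIP embeddings
    print(TinyBlock(320, 768, 8, disable_self_attn=True)(x, ctx).shape)  # torch.Size([2, 64, 320])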

ldm/modules/diffusionmodules/openaimodel.py

@@ -18,6 +18,7 @@ from ldm.modules.diffusionmodules.util import (
     timestep_embedding,
 )
 from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
 
 
 # dummy replace
@@ -466,6 +467,8 @@ class UNetModel(nn.Module):
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -490,7 +493,25 @@ class UNetModel(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        #self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")  # todo: convert to warning
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
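
In plain terms, the constructor now accepts num_res_blocks either as an int (replicated across levels) or as a per-level list that must match channel_mult, and num_attention_blocks may not exceed it at any level. A small self-contained illustration in plain Python, with the values copied from the unet_config above:

    # values from the unet_config above
    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2, 2, 2, 5]          # may also be a single int, e.g. 2
    num_attention_blocks = [1, 1, 1, 3]

    if isinstance(num_res_blocks, int):    # int -> same count at every level
        num_res_blocks = len(channel_mult) * [num_res_blocks]
    if len(num_res_blocks) != len(channel_mult):
        raise ValueError("num_res_blocks must be an int or a per-level list matching channel_mult")

    # each level may not request more attention blocks than it has residual blocks
    assert all(r >= a for r, a in zip(num_res_blocks, num_attention_blocks))
    print(num_res_blocks)                  # [2, 2, 2, 5]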
@@ -525,7 +546,7 @@ class UNetModel(nn.Module):
         ch = model_channels
         ds = 1
         for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
+            for nr in range(self.num_res_blocks[level]):
                 layers = [
                     ResBlock(
                         ch,
@@ -547,6 +568,12 @@ class UNetModel(nn.Module):
                     if legacy:
                         #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                         layers.append(
                             AttentionBlock(
                                 ch,
@@ -555,7 +582,8 @@ class UNetModel(nn.Module):
                                 num_head_channels=dim_head,
                                 use_new_attention_order=use_new_attention_order,
                             ) if not use_spatial_transformer else SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
                             )
                         )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -609,7 +637,7 @@ class UNetModel(nn.Module):
                 num_heads=num_heads,
                 num_head_channels=dim_head,
                 use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
             ),
             ResBlock(
@@ -625,7 +653,7 @@ class UNetModel(nn.Module):
 
         self.output_blocks = nn.ModuleList([])
         for level, mult in list(enumerate(channel_mult))[::-1]:
-            for i in range(num_res_blocks + 1):
+            for i in range(self.num_res_blocks[level] + 1):
                 ich = input_block_chans.pop()
                 layers = [
                     ResBlock(
@@ -648,6 +676,12 @@ class UNetModel(nn.Module):
                     if legacy:
                         #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                         layers.append(
                             AttentionBlock(
                                 ch,
@@ -656,10 +690,11 @@ class UNetModel(nn.Module):
                                 num_head_channels=dim_head,
                                 use_new_attention_order=use_new_attention_order,
                             ) if not use_spatial_transformer else SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
                             )
                         )
-                if level and i == num_res_blocks:
+                if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(
                         ResBlock(
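
Putting the loop changes together: an attention layer is attached to residual block nr of level i only if the current downsampling factor ds is listed in attention_resolutions and nr < num_attention_blocks[i]; as the constructor's print states, num_attention_blocks alone cannot enable attention at a resolution that attention_resolutions excludes. A self-contained sketch of that rule with the values from the config above (ds doubles after every level except the last, as in the encoder loop):

    attention_resolutions = [8, 4, 2, 1]
    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2, 2, 2, 5]
    num_attention_blocks = [1, 1, 1, 3]

    ds = 1
    for level in range(len(channel_mult)):
        for nr in range(num_res_blocks[level]):
            # the new cap only restricts; it cannot add attention where ds is not listed
            attn = ds in attention_resolutions and nr < num_attention_blocks[level]
            print(f"level {level} (ds={ds}): res block {nr} -> attention={attn}")
        if level != len(channel_mult) - 1:
            ds *= 2

With this config every level passes the attention_resolutions gate (ds takes the values 1, 2, 4, 8), so the new cap is what matters: in this sketch, one attended residual block on each of the first three levels and three on the last.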