add further nerf-attention
parent 0540e685e5
commit e2a6bee13c

2 changed files with 41 additions and 28 deletions
@@ -32,11 +32,12 @@ model:
         image_size: 64 # unused
         in_channels: 4
         out_channels: 4
-        model_channels: 352
-        attention_resolutions: [ 8, 4, 2 ]
-        num_res_blocks: [ 2, 2, 2, 6 ]
+        model_channels: 384
+        attention_resolutions: [ 8, 4, 2, 1 ]
+        num_res_blocks: [ 2, 2, 2, 5 ]
         channel_mult: [ 1, 2, 4, 4 ]
-        disable_self_attentions: [ True, True, True, False ] # converts the self-attention to a cross-attention layer if true
+        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
+        num_attention_blocks: [1, 1, 1, 3]
         num_heads: 8
         use_spatial_transformer: True
         transformer_depth: 1
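Aside, not part of the diff: a minimal sketch of what the new num_attention_blocks entry means for this config. It assumes the per-level downsampling factor ds is 2**level (1, 2, 4, 8), as in the UNetModel construction loops changed below, and that attention is only built where ds appears in attention_resolutions.

attention_resolutions = [8, 4, 2, 1]
num_res_blocks = [2, 2, 2, 5]
num_attention_blocks = [1, 1, 1, 3]

for level, n_blocks in enumerate(num_res_blocks):
    ds = 2 ** level  # assumed downsampling factor at this level
    # attention is kept only in the first num_attention_blocks[level] res blocks of the level
    attended = [nr for nr in range(n_blocks)
                if ds in attention_resolutions and nr < num_attention_blocks[level]]
    print(f"level {level} (ds={ds}): attention in res blocks {attended}")

# level 0 (ds=1): attention in res blocks [0]
# level 1 (ds=2): attention in res blocks [0]
# level 2 (ds=4): attention in res blocks [0]
# level 3 (ds=8): attention in res blocks [0, 1, 2]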
@@ -467,7 +467,8 @@ class UNetModel(nn.Module):
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
-        disable_self_attentions=None
+        disable_self_attentions=None,
+        num_attention_blocks=None
     ):
         super().__init__()
         if use_spatial_transformer:
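Aside, not part of the diff: a hypothetical construction call that mirrors the config hunk above using the extended signature. The import path and the context_dim value are assumptions, not taken from the commit.

from ldm.modules.diffusionmodules.openaimodel import UNetModel  # module path assumed

unet = UNetModel(
    image_size=64,                         # unused, as noted in the config
    in_channels=4,
    out_channels=4,
    model_channels=384,
    num_res_blocks=[2, 2, 2, 5],
    attention_resolutions=[8, 4, 2, 1],
    channel_mult=[1, 2, 4, 4],
    num_heads=8,
    use_spatial_transformer=True,
    transformer_depth=1,
    context_dim=768,                       # assumed; required when use_spatial_transformer=True
    disable_self_attentions=[False, False, False, False],
    num_attention_blocks=[1, 1, 1, 3],     # new keyword added by this commit
)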
@@ -503,6 +504,13 @@ class UNetModel(nn.Module):
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")  # todo: convert to warning
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
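Aside, not part of the diff: a small illustration of the "LESS priority" note in the new print statement. With the previous attention_resolutions of [ 8, 4, 2 ], the first level (ds = 2**0 = 1) gets no attention even though num_attention_blocks[0] is 1; num_attention_blocks can only remove attention, never add it where attention_resolutions does not request it.

attention_resolutions = [8, 4, 2]   # old config value, without 1
num_attention_blocks = [1, 1, 1, 3]

for level in range(4):
    ds = 2 ** level                 # assumed per-level downsampling factor, cf. "2**i" in the print
    has_attention = ds in attention_resolutions and num_attention_blocks[level] > 0
    print(f"level {level} (ds={ds}): {'attention' if has_attention else 'no attention'}")

# level 0 (ds=1): no attention  <- despite num_attention_blocks[0] == 1
# level 1 (ds=2): attention
# level 2 (ds=4): attention
# level 3 (ds=8): attention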
@@ -538,7 +546,7 @@
         ch = model_channels
         ds = 1
         for level, mult in enumerate(channel_mult):
-            for _ in range(self.num_res_blocks[level]):
+            for nr in range(self.num_res_blocks[level]):
                 layers = [
                     ResBlock(
                         ch,
@@ -564,18 +572,20 @@
                         disabled_sa = disable_self_attentions[level]
                     else:
                         disabled_sa = False
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disabled_sa
-                        )
-                    )
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
+                        )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
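Aside, not part of the diff: the added gate in isolation. exists is assumed to be this codebase's usual "is not None" helper; passing num_attention_blocks=None keeps the old behaviour, otherwise only the first num_attention_blocks[level] res blocks of a level get an attention layer.

def exists(val):
    # assumed to match the helper used by the gate above
    return val is not None

def keeps_attention(nr, level, num_attention_blocks):
    # condition placed in front of layers.append(...) in both the input- and output-block loops
    return not exists(num_attention_blocks) or nr < num_attention_blocks[level]

assert keeps_attention(4, 3, None)              # no cap given: attention everywhere, as before
assert keeps_attention(2, 3, [1, 1, 1, 3])      # block 2 of level 3 is under the cap of 3
assert not keeps_attention(3, 3, [1, 1, 1, 3])  # blocks 3 and 4 of level 3 lose their attention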
@@ -670,18 +680,20 @@
                         disabled_sa = disable_self_attentions[level]
                     else:
                         disabled_sa = False
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disabled_sa
-                        )
-                    )
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
+                        )
                 if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(