do not fix num_heads to one in legacy mode
parent 8381e5e557
commit 71ebe40791
1 changed file with 4 additions and 4 deletions
@@ -545,7 +545,7 @@ class UNetModel(nn.Module):
                         num_heads = ch // num_head_channels
                         dim_head = num_head_channels
                     if legacy:
-                        num_heads = 1
+                        #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                     layers.append(
                         AttentionBlock(
@@ -592,7 +592,7 @@ class UNetModel(nn.Module):
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
         if legacy:
-            num_heads = 1
+            #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
             ResBlock(
@@ -646,7 +646,7 @@ class UNetModel(nn.Module):
                         num_heads = ch // num_head_channels
                         dim_head = num_head_channels
                     if legacy:
-                        num_heads = 1
+                        #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                     layers.append(
                         AttentionBlock(
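
The effect of no longer forcing num_heads to 1 can be sketched outside the model. The snippet below mirrors only the head/dim_head computation from the changed lines; the values ch=512 and num_head_channels=64, and the helper name head_config, are illustrative assumptions and do not come from this commit.

# Sketch of the head configuration logic around this change.
def head_config(ch, num_head_channels, use_spatial_transformer, legacy, force_single_head):
    num_heads = ch // num_head_channels
    dim_head = num_head_channels
    if legacy:
        if force_single_head:  # behaviour before this commit
            num_heads = 1
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
    return num_heads, dim_head

# Before: legacy mode collapsed attention to a single 512-wide head.
print(head_config(512, 64, True, True, force_single_head=True))   # -> (1, 512)
# After: the head count derived from num_head_channels is preserved.
print(head_config(512, 64, True, True, force_single_head=False))  # -> (8, 64)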