do not fix num_heads to one in legacy mode
This commit is contained in:
parent
8381e5e557
commit
71ebe40791
1 changed files with 4 additions and 4 deletions
|
@ -545,7 +545,7 @@ class UNetModel(nn.Module):
|
||||||
num_heads = ch // num_head_channels
|
num_heads = ch // num_head_channels
|
||||||
dim_head = num_head_channels
|
dim_head = num_head_channels
|
||||||
if legacy:
|
if legacy:
|
||||||
num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
AttentionBlock(
|
||||||
|
@ -592,7 +592,7 @@ class UNetModel(nn.Module):
|
||||||
num_heads = ch // num_head_channels
|
num_heads = ch // num_head_channels
|
||||||
dim_head = num_head_channels
|
dim_head = num_head_channels
|
||||||
if legacy:
|
if legacy:
|
||||||
num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
self.middle_block = TimestepEmbedSequential(
|
self.middle_block = TimestepEmbedSequential(
|
||||||
ResBlock(
|
ResBlock(
|
||||||
|
@ -646,7 +646,7 @@ class UNetModel(nn.Module):
|
||||||
num_heads = ch // num_head_channels
|
num_heads = ch // num_head_channels
|
||||||
dim_head = num_head_channels
|
dim_head = num_head_channels
|
||||||
if legacy:
|
if legacy:
|
||||||
num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
AttentionBlock(
|
||||||
|
|
Loading…
Reference in a new issue