do not fix num_heads to one in legacy mode

rromb 2022-04-15 17:24:11 +02:00
parent 8381e5e557
commit 71ebe40791


@@ -545,7 +545,7 @@ class UNetModel(nn.Module):
                     num_heads = ch // num_head_channels
                     dim_head = num_head_channels
                 if legacy:
-                    num_heads = 1
+                    #num_heads = 1
                     dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                 layers.append(
                     AttentionBlock(
@@ -592,7 +592,7 @@ class UNetModel(nn.Module):
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
         if legacy:
-            num_heads = 1
+            #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
             ResBlock(
@@ -646,7 +646,7 @@ class UNetModel(nn.Module):
                     num_heads = ch // num_head_channels
                     dim_head = num_head_channels
                 if legacy:
-                    num_heads = 1
+                    #num_heads = 1
                     dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                 layers.append(
                     AttentionBlock(
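
For context, a minimal sketch (not part of the commit) of how the head configuration resolves after this change. The branch on num_head_channels == -1 and the helper name resolve_heads are assumptions for illustration; the diff only shows the else-branch and the legacy block.

def resolve_heads(ch, num_heads, num_head_channels, legacy, use_spatial_transformer):
    # Sketch of the attention-head setup in UNetModel; not the actual class.
    if num_head_channels == -1:  # assumed surrounding branch, not shown in the diff
        dim_head = ch // num_heads
    else:
        num_heads = ch // num_head_channels
        dim_head = num_head_channels
    if legacy:
        # Before this commit, legacy mode also forced num_heads = 1 here,
        # which made dim_head equal to the full channel width when using
        # the spatial transformer.
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
    return num_heads, dim_head

# e.g. ch=640, num_head_channels=64 -> (10, 64) now; previously (1, 640) in legacy mode
print(resolve_heads(640, 8, 64, legacy=True, use_spatial_transformer=True))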