do not fix num_heads to one in legacy mode
This commit is contained in:
parent
8381e5e557
commit
71ebe40791
1 changed files with 4 additions and 4 deletions
|
@ -464,7 +464,7 @@ class UNetModel(nn.Module):
|
|||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -545,7 +545,7 @@ class UNetModel(nn.Module):
|
|||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
num_heads = 1
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
|
@ -592,7 +592,7 @@ class UNetModel(nn.Module):
|
|||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
num_heads = 1
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
|
@ -646,7 +646,7 @@ class UNetModel(nn.Module):
|
|||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
num_heads = 1
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
|
|
Loading…
Reference in a new issue