2021-12-21 03:23:41 +01:00
import torch
import torch . nn as nn
2022-06-13 00:39:48 +02:00
import numpy as np
2021-12-21 03:23:41 +01:00
from functools import partial
from ldm . modules . x_transformer import Encoder , TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
2022-06-13 00:39:48 +02:00
from ldm . util import default
2021-12-21 03:23:41 +01:00
class AbstractEncoder(nn.Module):
    """Common base for conditioning encoders.

    Subclasses are expected to override :meth:`encode`; the base implementation
    only signals that it is missing. ``nn.Module.__init__`` is inherited as-is.
    """

    def encode(self, *args, **kwargs):
        """Encode the given inputs. Must be provided by a subclass."""
        raise NotImplementedError
2022-06-05 19:22:07 +02:00
class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder whose ``encode`` returns its input object unchanged."""

    def encode(self, x):
        # No transformation: the conditioning is consumed exactly as supplied.
        return x
2021-12-21 03:23:41 +01:00
class ClassEmbedder(nn.Module):
    """Look up a learned embedding vector for each integer class label.

    Args:
        embed_dim: size of each embedding vector.
        n_classes: number of rows in the embedding table.
        key: default batch key under which the labels are stored.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        """Embed ``batch[key]`` (``key`` defaults to ``self.key``).

        A singleton axis is inserted after the first dimension before the lookup,
        so each label becomes a length-1 sequence (intended for cross-attention).
        """
        labels = batch[self.key if key is None else key]
        return self.embedding(labels[:, None])
class TransformerEmbedder(AbstractEncoder):
    """Embed token ids with a stack of trainable transformer encoder layers.

    Args:
        n_embed: model width passed to ``Encoder(dim=...)``.
        n_layer: number of encoder layers (``Encoder(depth=...)``).
        vocab_size: number of distinct token ids.
        max_seq_len: maximum sequence length supported by the wrapper.
        device: device the incoming token ids are moved to in ``forward``.
    """

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        encoder_layers = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=encoder_layers)

    def forward(self, tokens):
        """Return the transformer embeddings (``return_embeddings=True``) for ``tokens``."""
        on_device = tokens.to(self.device)
        return self.transformer(on_device, return_embeddings=True)

    def encode(self, x):
        # Go through nn.Module.__call__ so registered hooks still fire.
        return self(x)
class BERTTokenizer(AbstractEncoder):
    """Tokenize text with the pretrained huggingface ``bert-base-uncased`` tokenizer.

    Vocab size: 30522 (?)

    Args:
        device: device the resulting token-id tensor is moved to.
        vq_interface: if True, ``encode`` wraps the ids in a
            ``(None, None, [None, None, ids])`` tuple instead of returning them directly.
        max_length: every sequence is padded/truncated to this many tokens.
    """

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        """Return the ``input_ids`` tensor for ``text``, padded/truncated to ``max_length``."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded["input_ids"].to(self.device)

    @torch.no_grad()
    def encode(self, text):
        ids = self(text)
        if self.vq_interface:
            # Presumably mirrors the return layout of a VQ model's encode(); the ids sit
            # in the third slot of the info list — confirm against callers.
            return None, None, [None, None, ids]
        return ids

    def decode(self, text):
        # Identity: nothing to undo.
        return text
class BERTEmbedder(AbstractEncoder):
    """Tokenize text with the pretrained BERT tokenizer and embed it with trainable transformer encoder layers.

    Args:
        n_embed: model width passed to ``Encoder(dim=...)``.
        n_layer: number of encoder layers (``Encoder(depth=...)``).
        vocab_size: number of token ids; 30522 matches the BERT tokenizer vocabulary.
        max_seq_len: tokenizer padding/truncation length and the wrapper's max sequence length.
        device: device the token ids produced by the internal tokenizer are placed on.
        use_tokenizer: if True, ``forward`` takes raw text and tokenizes it; otherwise it
            expects already-tokenized ids.
        embedding_dropout: dropout applied to the token embeddings (``emb_dropout``).
    """

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            # Forward ``device`` so the token ids land on the same device as this module.
            # Previously the tokenizer always used its own default ("cuda"), which broke
            # embedders configured for any other device.
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len, device=device)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        """Return transformer embeddings for ``text`` (raw strings, or ids if ``use_tokenizer`` is False)."""
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)  # already moved to self.device by the tokenizer
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # Sequence length follows max_seq_len (77 by default).
        return self(text)
2022-05-31 12:28:00 +02:00
from transformers import T5Tokenizer , T5EncoderModel , CLIPTokenizer , CLIPTextModel
2022-05-31 11:42:53 +02:00
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores ``mode``.

    Binding this over a module's ``train`` makes train/eval switching a no-op,
    so the module stays in whatever mode it is currently in.
    Returns the module itself, matching ``nn.Module.train``'s return value.
    """
    del mode  # intentionally unused
    return self
class FrozenT5Embedder(AbstractEncoder):
    """Encode text with a frozen pretrained T5 encoder (huggingface transformers).

    Other checkpoints: ``google/t5-v1_1-xl`` and ``google/t5-v1_1-xxl``.

    Args:
        version: huggingface checkpoint name for both tokenizer and encoder.
        device: device the token ids are moved to before encoding.
        max_length: every sequence is padded/truncated to this many tokens.
    """

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and turn off gradients for all parameters."""
        self.transformer = self.transformer.eval()
        # ``train`` is not overridden here (assigning disabled_train is commented out
        # upstream), so calling ``.train()`` on this module later re-enables training
        # mode on the encoder.
        self.requires_grad_(False)

    def forward(self, text):
        """Return the encoder's ``last_hidden_state`` for ``text``.

        No attention mask is passed, so padding positions are attended to as well.
        """
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Encode text with a frozen pretrained CLIP text transformer (huggingface transformers).

    An alternative checkpoint is ``openai/clip-vit-base-patch32``.

    Args:
        version: huggingface checkpoint name for both tokenizer and text model.
        device: device the token ids are moved to before encoding.
        max_length: every sequence is padded/truncated to this many tokens.
    """

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Switch the text model to eval mode and turn off gradients for all parameters."""
        self.transformer = self.transformer.eval()
        # ``train`` is not overridden here (assigning disabled_train is commented out
        # upstream), so calling ``.train()`` on this module later re-enables training
        # mode on the text model.
        self.requires_grad_(False)

    def forward(self, text):
        """Return the text model's ``last_hidden_state`` for ``text``.

        No attention mask is passed, so padding positions are attended to as well.
        """
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)
2021-12-21 03:23:41 +01:00
class SpatialRescaler(nn.Module):
    """Resize feature maps by repeated interpolation, optionally followed by a 1x1 conv.

    Args:
        n_stages: how many times the interpolation is applied (must be >= 0).
        method: interpolation mode forwarded to ``torch.nn.functional.interpolate``.
        multiplier: ``scale_factor`` used at every stage.
        in_channels: input channels of the optional 1x1 channel mapper.
        out_channels: if not None, a 1x1 ``nn.Conv2d`` maps ``in_channels`` to
            ``out_channels`` after resizing.
        bias: whether that 1x1 conv has a bias term.
    """

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Apply ``n_stages`` rounds of interpolation, then the optional channel mapping."""
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)
        return self.channel_mapper(x) if self.remap_output else x

    def encode(self, x):
        return self(x)
2022-05-31 11:42:53 +02:00
2022-06-13 00:39:48 +02:00
from ldm . util import instantiate_from_config
from ldm . modules . diffusionmodules . util import make_beta_schedule , extract_into_tensor , noise_like
class LowScaleEncoder(nn.Module):
    """Encode an image with a first-stage model and noise-augment the resulting latent.

    ``forward`` encodes ``x`` via ``self.model.encode(x).sample()``, multiplies by
    ``scale_factor``, diffuses the latent to a random per-sample noise level in
    ``[0, max_noise_level)`` and optionally resizes it to ``output_size``.

    Args:
        model_config: config passed to ``instantiate_from_config`` to build the first-stage model.
        linear_start: start of the beta range of the noise schedule.
        linear_end: end of the beta range of the noise schedule.
        timesteps: number of steps in the noise schedule.
        max_noise_level: exclusive upper bound for sampled noise levels; must satisfy
            ``0 < max_noise_level <= timesteps`` so every sampled level indexes the schedule.
        output_size: ``size`` passed to nearest-neighbour interpolation, or None to skip resizing.
        scale_factor: multiplier applied to latents after encoding (divided out again in ``decode``).

    Raises:
        ValueError: if ``max_noise_level`` is outside ``(0, timesteps]``.
    """

    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
                 scale_factor=1.0):
        super().__init__()
        self.max_noise_level = max_noise_level
        # Build the (cheap, buffer-only) schedule first so a bad max_noise_level is
        # reported before the first-stage model is instantiated.
        self.register_schedule(timesteps=timesteps, linear_start=linear_start,
                               linear_end=linear_end)
        if not 0 < max_noise_level <= self.num_timesteps:
            # torch.randint(0, max_noise_level) draws levels up to max_noise_level - 1,
            # which are used as indices into the schedule buffers of length num_timesteps.
            raise ValueError(f"max_noise_level must be in (0, {self.num_timesteps}], "
                             f"got {max_noise_level}")
        self.model = instantiate_from_config(model_config)
        # register_schedule() only registers buffers and returns None; the attribute is
        # kept (always None) for backward compatibility with code that reads it.
        self.augmentation_schedule = None
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta/alpha noise schedule and register it as float32 buffers.

        Also sets ``num_timesteps``, ``linear_start`` and ``linear_end``. Returns None.
        """
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # betas is 1-D; its length is the actual number of timesteps.
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``: sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise.

        ``noise`` defaults to standard normal noise shaped like ``x_start``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        """Encode ``x``, noise-augment the latent and optionally resize it.

        Returns:
            ``(z, noise_level)``: the augmented latent and the per-sample noise levels
            (int64 tensor of shape ``(x.shape[0],)``).
        """
        z = self.model.encode(x).sample()
        z = z * self.scale_factor
        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        z = self.q_sample(z, noise_level)
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
        return z, noise_level

    def decode(self, z):
        """Undo the latent scaling and decode with the first-stage model."""
        z = z / self.scale_factor
        return self.model.decode(z)
2022-06-13 00:39:48 +02:00
2022-05-31 11:42:53 +02:00
if __name__ == "__main__":
    # Manual smoke test: requires a CUDA device and downloads the pretrained checkpoints.
    from ldm.util import count_params

    sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen",
                 "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
    builders = (
        lambda: FrozenT5Embedder(version="google/t5-v1_1-xl"),
        FrozenCLIPEmbedder,
    )
    for build in builders:
        model = build().cuda()
        count_params(model, True)
        z = model(sentences)
        print(z.shape)
    print("done.")