2021-12-21 02:23:41 +00:00
import torch
import torch . nn as nn
from functools import partial
from ldm . modules . x_transformer import Encoder , TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
class AbstractEncoder(nn.Module):
    """Shared base class for the conditioning encoders in this module.

    It adds nothing beyond ``nn.Module`` except an ``encode`` placeholder
    that concrete encoders are expected to override.
    """

    def __init__(self):
        # Only the standard nn.Module bookkeeping is needed here.
        super(AbstractEncoder, self).__init__()

    def encode(self, *args, **kwargs):
        """Encode the given inputs; concrete subclasses must implement this.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        raise NotImplementedError
class ClassEmbedder(nn.Module):
    """Map integer class labels taken from a batch dict to learned vectors.

    Args:
        embed_dim: size of each embedding vector.
        n_classes: number of rows in the ``nn.Embedding`` table.
        key: default entry of ``batch`` that holds the labels.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        """Embed ``batch[key]``, falling back to ``batch[self.key]`` when ``key`` is None.

        A length-1 axis is inserted at dim 1 before the lookup (this is for
        use in crossattn), so a 1-D label tensor of length B comes out with
        shape ``(B, 1, embed_dim)``.
        """
        lookup_key = self.key if key is None else key
        labels = batch[lookup_key][:, None]
        return self.embedding(labels)
class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers.

    Builds a ``TransformerWrapper`` over ``vocab_size`` tokens whose attention
    layers are an ``Encoder`` of depth ``n_layer`` and width ``n_embed``.
    """

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        # Device the incoming token ids are moved to in forward().
        self.device = device
        encoder_stack = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=encoder_stack,
        )

    def forward(self, tokens):
        """Move ``tokens`` to ``self.device`` and run the wrapper with ``return_embeddings=True``.

        NOTE(review): the output is presumably per-token embeddings rather
        than logits -- confirm against ldm.modules.x_transformer.
        """
        moved = tokens.to(self.device)  # meh: device is a fixed string, not tracked from the module
        return self.transformer(moved, return_embeddings=True)

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)
class BERTTokenizer(AbstractEncoder):
    """Tokenize text with huggingface's pretrained ``bert-base-uncased`` fast tokenizer.

    Vocab size: 30522 (?) -- unverified.

    Args:
        device: device the resulting id tensor is moved to.
        vq_interface: when True, ``encode`` wraps the ids in the
            ``(None, None, [None, None, ids])`` tuple layout.
        max_length: every sequence is truncated / padded to this many tokens.
    """

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        """Return the ``input_ids`` tensor for ``text`` (padded to ``max_length``) on ``self.device``."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded["input_ids"].to(self.device)

    @torch.no_grad()
    def encode(self, text):
        """Tokenize without tracking gradients.

        Returns the bare ids, or ``(None, None, [None, None, ids])`` when
        ``vq_interface`` is set (presumably to mimic a VQ model's encode
        output -- confirm with callers).
        """
        token_ids = self(text)
        if self.vq_interface:
            return None, None, [None, None, token_ids]
        return token_ids

    def decode(self, text):
        """Identity: hand ``text`` back unchanged."""
        return text
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers.

    Args:
        n_embed: width of the ``Encoder`` layers.
        n_layer: number of ``Encoder`` layers.
        vocab_size: ``num_tokens`` passed to ``TransformerWrapper``
            (30522 by default, matching the BERT tokenizer's vocabulary note).
        max_seq_len: max sequence length for both the tokenizer and the wrapper.
        device: device the tokenizer places token ids on.
        use_tokenizer: when False, ``forward`` passes its input straight to
            the transformer (it must already be token ids).
        embedding_dropout: forwarded as ``emb_dropout`` to ``TransformerWrapper``.
    """

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            # Pass our device through: without it the tokenizer falls back to
            # its own "cuda" default and ignores e.g. device="cpu".
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len, device=device)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        """Tokenize ``text`` (if enabled) and return the wrapper output with ``return_embeddings=True``."""
        if self.use_tknz_fn:
            # The tokenizer already moves the ids to self.device.
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        """Alias for calling the module on ``text``.

        With the tokenizer enabled, sequences are padded/truncated to
        ``max_seq_len`` (77 by default).
        """
        return self(text)
2022-05-31 09:42:53 +00:00
from transformers import T5Tokenizer , T5EncoderModel
def disabled_train(self, mode=True):
    """Replacement for ``model.train`` that ignores the requested mode.

    Overwrite ``model.train`` with this function to make sure the
    train/eval mode does not change anymore; it returns ``self`` untouched.
    """
    # `mode` mirrors nn.Module.train's signature and is intentionally unused.
    return self
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text, with all parameters frozen.

    Args:
        version: huggingface id the tokenizer and encoder are loaded from
            (others are google/t5-v1_1-xl and google/t5-v1_1-xxl).
        device: device the token ids are moved to before encoding.
        max_length: tokenizer truncation / padding length.
    """

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self._freeze()

    def _freeze(self):
        """Put the encoder in eval mode, pin ``train``, and disable gradients.

        This must be a regular method (not a staticmethod): ``__init__``
        calls it as ``self._freeze()`` and it mutates the instance.
        """
        self.transformer = self.transformer.eval()
        # NOTE(review): this stores the plain (unbound) function on the
        # instance, so `self.train()` without an argument raises TypeError and
        # `self.eval()` returns False instead of self; parents calling
        # `child.train(mode)` just get `mode` back. Kept as-is for consistency
        # with how disabled_train is used elsewhere -- confirm before binding.
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's ``last_hidden_state``."""
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        # Run the wrapped T5 encoder (not an outside global).
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
2021-12-21 02:23:41 +00:00
class SpatialRescaler(nn.Module):
    """Resample an input tensor ``n_stages`` times, optionally remapping channels.

    Each stage calls ``torch.nn.functional.interpolate`` with
    ``scale_factor=multiplier`` and the chosen ``method`` as its mode. When
    ``out_channels`` is given, a 1x1 ``nn.Conv2d`` maps ``in_channels`` to
    ``out_channels`` after the last stage.
    """

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        supported_modes = ('nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area')
        assert method in supported_modes
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if not self.remap_output:
            return
        print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
        self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Apply every resize stage, then the optional channel mapping."""
        out = x
        for _ in range(self.n_stages):
            out = self.interpolator(out, scale_factor=self.multiplier)
        return self.channel_mapper(out) if self.remap_output else out

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)
2022-05-31 09:42:53 +00:00
if __name__ == "__main__":
    # Smoke run: embed a few sentences (including non-ASCII characters)
    # with FrozenT5Embedder and print the resulting tensor shape.
    from ldm.util import count_params

    sample_sentences = [
        "a hedgehog drinking a whiskey",
        "der mond ist aufgegangen",
        "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'",
    ]
    model = FrozenT5Embedder()
    count_params(model, True)
    embeddings = model(sample_sentences)
    print(embeddings.shape)
    print("done.")