clip encoder

This commit is contained in:
rromb 2022-05-31 12:28:00 +02:00
parent 3c53c6c15b
commit e637491838

View file

@ -99,7 +99,7 @@ class BERTEmbedder(AbstractEncoder):
return self(text)
from transformers import T5Tokenizer, T5EncoderModel
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
@ -115,10 +115,9 @@ class FrozenT5Embedder(AbstractEncoder):
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self._freeze()
self.freeze()
@staticmethod
def _freeze(self):
def freeze(self):
self.transformer = self.transformer.eval()
self.train = disabled_train
for param in self.parameters():
@ -128,7 +127,36 @@ class FrozenT5Embedder(AbstractEncoder):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = model(input_ids=tokens)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self, version="clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
@ -172,8 +200,14 @@ class SpatialRescaler(nn.Module):
if __name__ == "__main__":
from ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
model = FrozenT5Embedder()
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
model = FrozenCLIPEmbedder().cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
print("done.")