clip encoder
parent 3c53c6c15b
commit e637491838
1 changed file with 41 additions and 7 deletions
@@ -99,7 +99,7 @@ class BERTEmbedder(AbstractEncoder):
         return self(text)


-from transformers import T5Tokenizer, T5EncoderModel
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
|
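For context, the two new imports come from Hugging Face transformers (CLIP support landed in transformers 4.6). A minimal sketch of what they load; the "openai/" Hub id is an assumption on my part, since the bare "clip-vit-large-patch14" used as the class default further down only resolves if a local directory of that name exists:

    # Sketch, assuming transformers>=4.6 and network access to the Hub.
    from transformers import CLIPTokenizer, CLIPTextModel

    # "openai/clip-vit-large-patch14" is the canonical Hub id; the bare name
    # in the class default below would be treated as a local path.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    print(encoder.config.hidden_size)  # 768 for the ViT-L/14 text tower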
@@ -115,10 +115,9 @@ class FrozenT5Embedder(AbstractEncoder):
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length   # TODO: typical value?
-        self._freeze()
+        self.freeze()

-    @staticmethod
-    def _freeze(self):
+    def freeze(self):
         self.transformer = self.transformer.eval()
         self.train = disabled_train
         for param in self.parameters():
|
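The dropped @staticmethod is the substance of this hunk: a static method receives no instance, so the old self._freeze() call would raise TypeError: _freeze() missing 1 required positional argument: 'self'. The disabled_train assignment then pins the frozen module in eval mode; a minimal sketch of the pattern (the BatchNorm module is only an illustration):

    import torch.nn as nn

    def disabled_train(self, mode=True):
        """Replacement for nn.Module.train that ignores requested mode changes."""
        return self

    m = nn.BatchNorm1d(8).eval()
    m.train = disabled_train.__get__(m)  # bind explicitly for this sketch
    m.train()                            # would normally switch back to training mode
    print(m.training)                    # False: the module stays frozen in eval mode

Note that the diff assigns the bare function (self.train = disabled_train), which skips descriptor binding; that still behaves as a no-op whenever train is invoked with an explicit mode argument (as PyTorch does when propagating train(mode) to child modules, and as eval() does via train(False)), while the bound form above also survives a bare .train() call.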
@@ -128,7 +127,36 @@ class FrozenT5Embedder(AbstractEncoder):
         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
         tokens = batch_encoding["input_ids"].to(self.device)
-        outputs = model(input_ids=tokens)
+        outputs = self.transformer(input_ids=tokens)

+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from huggingface)"""
+    def __init__(self, version="clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length  # TODO: typical value?
+        self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)
+
         z = outputs.last_hidden_state
         return z
|
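A quick smoke test for the new class; a sketch, where the explicit Hub id and the CPU device are assumptions (the default version string relies on a local checkpoint directory, and the default device is "cuda"):

    import torch

    model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device="cpu")
    with torch.no_grad():
        z = model.encode(["a hedgehog drinking a whiskey"])
    print(z.shape)  # expected: torch.Size([1, 77, 768]), i.e. max_length x hidden_size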
@@ -172,8 +200,14 @@ class SpatialRescaler(nn.Module):
 if __name__ == "__main__":
     from ldm.util import count_params
     sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
-    model = FrozenT5Embedder()
+    model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
     count_params(model, True)
     z = model(sentences)
     print(z.shape)
+
+    model = FrozenCLIPEmbedder().cuda()
+    count_params(model, True)
+    z = model(sentences)
+    print(z.shape)
+
     print("done.")
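Assuming the usual checkpoint configs (d_model 2048 for google/t5-v1_1-xl, hidden size 768 for the CLIP ViT-L/14 text tower, max_length 77 in both embedders), the updated test block should print something like:

    torch.Size([3, 77, 2048])
    torch.Size([3, 77, 768])
    done.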