clip encoder
This commit is contained in:
parent
3c53c6c15b
commit
e637491838
1 changed files with 41 additions and 7 deletions
|
@ -99,7 +99,7 @@ class BERTEmbedder(AbstractEncoder):
|
|||
return self(text)
|
||||
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
|
@ -115,10 +115,9 @@ class FrozenT5Embedder(AbstractEncoder):
|
|||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self._freeze()
|
||||
self.freeze()
|
||||
|
||||
@staticmethod
|
||||
def _freeze(self):
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
|
@ -128,7 +127,36 @@ class FrozenT5Embedder(AbstractEncoder):
|
|||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = model(input_ids=tokens)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
def __init__(self, version="clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
@ -172,8 +200,14 @@ class SpatialRescaler(nn.Module):
|
|||
if __name__ == "__main__":
|
||||
from ldm.util import count_params
|
||||
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
|
||||
model = FrozenT5Embedder()
|
||||
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
|
||||
model = FrozenCLIPEmbedder().cuda()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
|
||||
print("done.")
|
Loading…
Reference in a new issue