diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 6f4a9bb..ecaef85 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -99,7 +99,7 @@ class BERTEmbedder(AbstractEncoder):
         return self(text)


-from transformers import T5Tokenizer, T5EncoderModel
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -115,10 +115,9 @@ class FrozenT5Embedder(AbstractEncoder):
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
-        self._freeze()
+        self.freeze()

-    @staticmethod
-    def _freeze(self):
+    def freeze(self):
         self.transformer = self.transformer.eval()
         self.train = disabled_train
         for param in self.parameters():
@@ -128,7 +127,36 @@ class FrozenT5Embedder(AbstractEncoder):
         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
         tokens = batch_encoding["input_ids"].to(self.device)
-        outputs = model(input_ids=tokens)
+        outputs = self.transformer(input_ids=tokens)
+
+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from huggingface)"""
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length  # TODO: typical value?
+        self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)

         z = outputs.last_hidden_state
         return z
@@ -172,8 +200,14 @@ class SpatialRescaler(nn.Module):

 if __name__ == "__main__":
     from ldm.util import count_params
     sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
-    model = FrozenT5Embedder()
+    model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
     count_params(model, True)
     z = model(sentences)
     print(z.shape)
-    print("done.")
\ No newline at end of file
+
+    model = FrozenCLIPEmbedder().cuda()
+    count_params(model, True)
+    z = model(sentences)
+    print(z.shape)
+
+    print("done.")
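
A minimal usage sketch of the FrozenCLIPEmbedder added above (not part of the patch; it assumes the transformers package can fetch the openai/clip-vit-large-patch14 weights from the Hugging Face hub and that a CUDA device is available, matching the default device="cuda"):

    from ldm.modules.encoders.modules import FrozenCLIPEmbedder

    # The constructor downloads the tokenizer and text-encoder weights, then
    # freeze() puts the transformer in eval mode and disables all gradients.
    embedder = FrozenCLIPEmbedder().cuda()

    # encode() tokenizes, pads/truncates to max_length=77 tokens, and returns
    # the encoder's last hidden state: (batch, 77, 768) for CLIP ViT-L/14.
    z = embedder.encode(["a hedgehog drinking a whiskey"])
    print(z.shape)  # torch.Size([1, 77, 768])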