add frozen t5 encoder from huggingface

This commit is contained in:
rromb 2022-05-31 11:42:53 +02:00
parent 48bbc51869
commit 3c53c6c15b

View file

@ -99,6 +99,44 @@ class BERTEmbedder(AbstractEncoder):
return self(text)
from transformers import T5Tokenizer, T5EncoderModel
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self._freeze()
@staticmethod
def _freeze(self):
self.transformer = self.transformer.eval()
self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = model(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
@ -129,3 +167,13 @@ class SpatialRescaler(nn.Module):
def encode(self, x):
return self(x)
if __name__ == "__main__":
from ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
model = FrozenT5Embedder()
count_params(model, True)
z = model(sentences)
print(z.shape)
print("done.")