add frozen t5 encoder from huggingface
This commit is contained in:
parent
48bbc51869
commit
3c53c6c15b
1 changed files with 48 additions and 0 deletions
|
@ -99,6 +99,44 @@ class BERTEmbedder(AbstractEncoder):
|
|||
return self(text)
|
||||
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self._freeze()
|
||||
|
||||
@staticmethod
|
||||
def _freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = model(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class SpatialRescaler(nn.Module):
|
||||
def __init__(self,
|
||||
n_stages=1,
|
||||
|
@ -129,3 +167,13 @@ class SpatialRescaler(nn.Module):
|
|||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ldm.util import count_params
|
||||
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
|
||||
model = FrozenT5Embedder()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
print("done.")
|
Loading…
Reference in a new issue