Update modules.py

This commit is contained in:
Justin 2022-09-26 10:39:44 +01:00 committed by GitHub
parent 5346889189
commit 3a64aae085
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -188,6 +188,7 @@ class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
class FrozenCLIPImageEmbedder(AbstractEncoder): class FrozenCLIPImageEmbedder(AbstractEncoder):
""" """
Uses the CLIP image encoder. Uses the CLIP image encoder.
Not actually frozen...
""" """
def __init__( def __init__(
self, self,
@ -205,14 +206,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
# I didn't call this originally, but seems like it was frozen anyway
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def preprocess(self, x): def preprocess(self, x):
# Expects inputs in the range -1, 1 # Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224), x = kornia.geometry.resize(x, (224, 224),