add mask generator to LAION

This commit is contained in:
Robin Rombach 2022-07-24 13:23:12 +02:00
parent ddd22b549c
commit 680ab981fd

View file

@ -16,6 +16,8 @@ from webdataset.handlers import warn_and_continue
from ldm.util import instantiate_from_config
from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
from ldm.data.base import PRNGMixin
class DataWithWings(torch.utils.data.IterableDataset):
@ -229,6 +231,23 @@ class AddLR(object):
return sample
class AddMask(PRNGMixin):
def __init__(self, size=512):
super().__init__()
self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = sample['jpg']
mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
mask[mask < 0.5] = 0
mask[mask > 0.5] = 1
mask = torch.from_numpy(mask[..., None])
sample['mask'] = mask
sample['masked_image'] = x * (mask < 0.5)
return sample
def example00():
url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
dataset = wds.WebDataset(url)