From 680ab981fdf3db3c1b919f59d2110557932f16c3 Mon Sep 17 00:00:00 2001
From: Robin Rombach
Date: Sun, 24 Jul 2022 13:23:12 +0200
Subject: [PATCH] add mask generator to LAION

Add an AddMask sample transform to ldm/data/laion.py. It derives from
PRNGMixin and uses the mask helpers imported from
ldm.data.inpainting.synthetic_mask.

AddMask(size=512) selects make_512_lama_mask when size == 512 and
make_lama_mask for any other size. Calling it on a sample invokes the
selected generator with self.prng and the first two dimensions of
sample['jpg'], thresholds the result at 0.5, and stores it under
sample['mask'] with a trailing singleton axis. sample['masked_image']
is sample['jpg'] multiplied by (mask < 0.5), i.e. positions where the
mask is not below 0.5 are zeroed. The sample dict is updated in place
and returned.

---
Review notes (not part of the commit message):
- gen_large_mask and make_narrow_lama_mask are imported but are not
  used by the code added in this patch.
- The two threshold assignments leave values exactly equal to 0.5
  unchanged in sample['mask'], while masked_image zeroes those
  positions (mask < 0.5 is false there). Presumably the generators
  only return 0/1 values -- TODO confirm.
- Blank context lines below were restored from the hunk line counts.

 ldm/data/laion.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 7c07d38..aeee987 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -16,6 +16,8 @@ from webdataset.handlers import warn_and_continue
 
 
 from ldm.util import instantiate_from_config
+from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
+from ldm.data.base import PRNGMixin
 
 
 class DataWithWings(torch.utils.data.IterableDataset):
@@ -229,6 +231,23 @@ class AddLR(object):
         return sample
 
 
+class AddMask(PRNGMixin):
+    def __init__(self, size=512):
+        super().__init__()
+        self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
+
+    def __call__(self, sample):
+        # sample['jpg'] is tensor hwc in [-1, 1] at this point
+        x = sample['jpg']
+        mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
+        mask[mask < 0.5] = 0
+        mask[mask > 0.5] = 1
+        mask = torch.from_numpy(mask[..., None])
+        sample['mask'] = mask
+        sample['masked_image'] = x * (mask < 0.5)
+        return sample
+
+
 def example00():
     url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
     dataset = wds.WebDataset(url)