add mask generator to LAION
This commit is contained in:
parent
ddd22b549c
commit
680ab981fd
1 changed files with 19 additions and 0 deletions
|
@ -16,6 +16,8 @@ from webdataset.handlers import warn_and_continue
|
||||||
|
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
|
||||||
|
from ldm.data.base import PRNGMixin
|
||||||
|
|
||||||
|
|
||||||
class DataWithWings(torch.utils.data.IterableDataset):
|
class DataWithWings(torch.utils.data.IterableDataset):
|
||||||
|
@ -229,6 +231,23 @@ class AddLR(object):
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class AddMask(PRNGMixin):
|
||||||
|
def __init__(self, size=512):
|
||||||
|
super().__init__()
|
||||||
|
self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||||
|
x = sample['jpg']
|
||||||
|
mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
|
||||||
|
mask[mask < 0.5] = 0
|
||||||
|
mask[mask > 0.5] = 1
|
||||||
|
mask = torch.from_numpy(mask[..., None])
|
||||||
|
sample['mask'] = mask
|
||||||
|
sample['masked_image'] = x * (mask < 0.5)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def example00():
|
def example00():
|
||||||
url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
|
url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
|
||||||
dataset = wds.WebDataset(url)
|
dataset = wds.WebDataset(url)
|
||||||
|
|
Loading…
Reference in a new issue