add mask generator to LAION
This commit is contained in:
parent
ddd22b549c
commit
680ab981fd
1 changed files with 19 additions and 0 deletions
|
@ -16,6 +16,8 @@ from webdataset.handlers import warn_and_continue
|
|||
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
|
||||
from ldm.data.base import PRNGMixin
|
||||
|
||||
|
||||
class DataWithWings(torch.utils.data.IterableDataset):
|
||||
|
@ -229,6 +231,23 @@ class AddLR(object):
|
|||
return sample
|
||||
|
||||
|
||||
class AddMask(PRNGMixin):
|
||||
def __init__(self, size=512):
|
||||
super().__init__()
|
||||
self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
|
||||
|
||||
def __call__(self, sample):
|
||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||
x = sample['jpg']
|
||||
mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask > 0.5] = 1
|
||||
mask = torch.from_numpy(mask[..., None])
|
||||
sample['mask'] = mask
|
||||
sample['masked_image'] = x * (mask < 0.5)
|
||||
return sample
|
||||
|
||||
|
||||
def example00():
|
||||
url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
|
||||
dataset = wds.WebDataset(url)
|
||||
|
|
Loading…
Reference in a new issue