From f6016af80acc628ada39749d55a818e21faf089c Mon Sep 17 00:00:00 2001 From: Robin Rombach Date: Thu, 28 Jul 2022 00:07:45 +0200 Subject: [PATCH] larger masks --- ldm/data/inpainting/synthetic_mask.py | 28 +++++++++++++++++++++------ ldm/data/laion.py | 7 ++++--- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/ldm/data/inpainting/synthetic_mask.py b/ldm/data/inpainting/synthetic_mask.py index 9dcc3d6..bb4c38f 100644 --- a/ldm/data/inpainting/synthetic_mask.py +++ b/ldm/data/inpainting/synthetic_mask.py @@ -38,6 +38,18 @@ settings = { "max_s_box": 300, "marg": 10, }, + "512train-large": { # TODO: experimental + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 450, + "max_w_irr": 400, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 75, + "max_s_box": 450, + "marg": 10, + }, } @@ -128,14 +140,18 @@ def gen_large_mask(prng, img_h, img_w, return mask -make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, - **settings["256train"]) +make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"]) +make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) +make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) +make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) -make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, - **settings["256narrow"]) -make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, - **settings["512train"]) +MASK_MODES = { + "256train": make_lama_mask, + "256narrow": make_narrow_lama_mask, + "512train": make_512_lama_mask, + "512train-large": make_512_lama_mask_large +} if __name__ == "__main__": import sys diff --git a/ldm/data/laion.py b/ldm/data/laion.py index 4c9b73d..7a41d4a 100644 --- a/ldm/data/laion.py +++ b/ldm/data/laion.py @@ -16,7 +16,7 @@ from webdataset.handlers import warn_and_continue from ldm.util import instantiate_from_config -from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask +from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES from ldm.data.base import PRNGMixin @@ -232,9 +232,10 @@ class AddLR(object): class AddMask(PRNGMixin): - def __init__(self, size=512): + def __init__(self, mode="512train"): super().__init__() - self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask + assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"' + self.make_mask = MASK_MODES[mode] def __call__(self, sample): # sample['jpg'] is tensor hwc in [-1, 1] at this point