larger masks

This commit is contained in:
Robin Rombach 2022-07-28 00:07:45 +02:00
parent 8ce2c914d3
commit f6016af80a
2 changed files with 26 additions and 9 deletions

View File

@ -38,6 +38,18 @@ settings = {
"max_s_box": 300,
"marg": 10,
},
"512train-large": { # TODO: experimental
"p_irr": 0.5,
"min_n_irr": 1,
"max_n_irr": 5,
"max_l_irr": 450,
"max_w_irr": 400,
"min_n_box": 1,
"max_n_box": 4,
"min_s_box": 75,
"max_s_box": 450,
"marg": 10,
},
}
@ -128,14 +140,18 @@ def gen_large_mask(prng, img_h, img_w,
return mask
# Convenience mask factories: each takes (prng, h, w) and forwards them to
# gen_large_mask together with the keyword arguments stored under one preset
# name in `settings`.
# NOTE(review): make_lama_mask, make_narrow_lama_mask and make_512_lama_mask
# are each bound twice below (once wrapped over two lines, once on a single
# line) using the same preset key both times. This looks like the before/after
# lines of a diff rendered without +/- markers -- confirm which form belongs
# in the actual file.
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
**settings["256train"])
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
# Factory for the "512train-large" preset.
make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
# The two bindings below repeat make_narrow_lama_mask / make_512_lama_mask with
# the same preset keys as the single-line forms above.
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
**settings["256narrow"])
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
**settings["512train"])
# Registry mapping a preset name to the mask factory that unpacks the
# same-named entry of `settings`. Callers can select a generator by string key
# instead of importing the individual make_* callables.
MASK_MODES = {
"256train": make_lama_mask,
"256narrow": make_narrow_lama_mask,
"512train": make_512_lama_mask,
"512train-large": make_512_lama_mask_large
}
if __name__ == "__main__":
import sys

View File

@ -16,7 +16,7 @@ from webdataset.handlers import warn_and_continue
from ldm.util import instantiate_from_config
from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
from ldm.data.base import PRNGMixin
@ -232,9 +232,10 @@ class AddLR(object):
class AddMask(PRNGMixin):
def __init__(self, size=512):
def __init__(self, mode="512train"):
super().__init__()
self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
self.make_mask = MASK_MODES[mode]
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point