larger masks
This commit is contained in:
parent
8ce2c914d3
commit
f6016af80a
2 changed files with 26 additions and 9 deletions
|
@ -38,6 +38,18 @@ settings = {
|
|||
"max_s_box": 300,
|
||||
"marg": 10,
|
||||
},
|
||||
"512train-large": { # TODO: experimental
|
||||
"p_irr": 0.5,
|
||||
"min_n_irr": 1,
|
||||
"max_n_irr": 5,
|
||||
"max_l_irr": 450,
|
||||
"max_w_irr": 400,
|
||||
"min_n_box": 1,
|
||||
"max_n_box": 4,
|
||||
"min_s_box": 75,
|
||||
"max_s_box": 450,
|
||||
"marg": 10,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
@ -128,14 +140,18 @@ def gen_large_mask(prng, img_h, img_w,
|
|||
return mask
|
||||
|
||||
|
||||
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
|
||||
**settings["256train"])
|
||||
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
|
||||
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
|
||||
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
|
||||
make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
|
||||
|
||||
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
|
||||
**settings["256narrow"])
|
||||
|
||||
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
|
||||
**settings["512train"])
|
||||
MASK_MODES = {
|
||||
"256train": make_lama_mask,
|
||||
"256narrow": make_narrow_lama_mask,
|
||||
"512train": make_512_lama_mask,
|
||||
"512train-large": make_512_lama_mask_large
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
|
|
@ -16,7 +16,7 @@ from webdataset.handlers import warn_and_continue
|
|||
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
|
||||
from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
|
||||
from ldm.data.base import PRNGMixin
|
||||
|
||||
|
||||
|
@ -232,9 +232,10 @@ class AddLR(object):
|
|||
|
||||
|
||||
class AddMask(PRNGMixin):
|
||||
def __init__(self, size=512):
|
||||
def __init__(self, mode="512train"):
|
||||
super().__init__()
|
||||
self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
|
||||
assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
|
||||
self.make_mask = MASK_MODES[mode]
|
||||
|
||||
def __call__(self, sample):
|
||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||
|
|
Loading…
Reference in a new issue