larger masks
This commit is contained in:
parent
8ce2c914d3
commit
f6016af80a
2 changed files with 26 additions and 9 deletions
|
@ -38,6 +38,18 @@ settings = {
|
||||||
"max_s_box": 300,
|
"max_s_box": 300,
|
||||||
"marg": 10,
|
"marg": 10,
|
||||||
},
|
},
|
||||||
|
"512train-large": { # TODO: experimental
|
||||||
|
"p_irr": 0.5,
|
||||||
|
"min_n_irr": 1,
|
||||||
|
"max_n_irr": 5,
|
||||||
|
"max_l_irr": 450,
|
||||||
|
"max_w_irr": 400,
|
||||||
|
"min_n_box": 1,
|
||||||
|
"max_n_box": 4,
|
||||||
|
"min_s_box": 75,
|
||||||
|
"max_s_box": 450,
|
||||||
|
"marg": 10,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,14 +140,18 @@ def gen_large_mask(prng, img_h, img_w,
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
|
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
|
||||||
**settings["256train"])
|
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
|
||||||
|
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
|
||||||
|
make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
|
||||||
|
|
||||||
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
|
|
||||||
**settings["256narrow"])
|
|
||||||
|
|
||||||
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w,
|
MASK_MODES = {
|
||||||
**settings["512train"])
|
"256train": make_lama_mask,
|
||||||
|
"256narrow": make_narrow_lama_mask,
|
||||||
|
"512train": make_512_lama_mask,
|
||||||
|
"512train-large": make_512_lama_mask_large
|
||||||
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
|
|
|
@ -16,7 +16,7 @@ from webdataset.handlers import warn_and_continue
|
||||||
|
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.data.inpainting.synthetic_mask import gen_large_mask, make_lama_mask, make_narrow_lama_mask, make_512_lama_mask
|
from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
|
||||||
from ldm.data.base import PRNGMixin
|
from ldm.data.base import PRNGMixin
|
||||||
|
|
||||||
|
|
||||||
|
@ -232,9 +232,10 @@ class AddLR(object):
|
||||||
|
|
||||||
|
|
||||||
class AddMask(PRNGMixin):
|
class AddMask(PRNGMixin):
|
||||||
def __init__(self, size=512):
|
def __init__(self, mode="512train"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.make_mask = make_512_lama_mask if size == 512 else make_lama_mask
|
assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
|
||||||
|
self.make_mask = MASK_MODES[mode]
|
||||||
|
|
||||||
def __call__(self, sample):
|
def __call__(self, sample):
|
||||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||||
|
|
Loading…
Reference in a new issue