mask drop and typo
commit 47aa45a345
parent a416813c32
2 changed files with 8 additions and 4 deletions
@@ -232,15 +232,18 @@ class AddLR(object):
 
 
 class AddMask(PRNGMixin):
-    def __init__(self, mode="512train"):
+    def __init__(self, mode="512train", p_drop=0.):
         super().__init__()
         assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
         self.make_mask = MASK_MODES[mode]
+        self.p_drop = p_drop
 
     def __call__(self, sample):
         # sample['jpg'] is tensor hwc in [-1, 1] at this point
         x = sample['jpg']
         mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
+        if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
+            mask = np.ones_like(mask)
         mask[mask < 0.5] = 0
         mask[mask > 0.5] = 1
         mask = torch.from_numpy(mask[..., None])
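
The new p_drop argument makes AddMask occasionally replace the sampled
inpainting mask with an all-ones mask, so training also sees the
fully-masked (effectively unconditional) case. A minimal sketch of the
added branch, using a random stand-in for self.make_mask (the stand-in
is illustrative, not part of the commit):

    import numpy as np

    prng = np.random.RandomState(42)
    p_drop = 0.1
    mask = prng.rand(512, 512)  # stand-in for self.make_mask(prng, h, w)

    # With probability p_drop, discard the generated mask and mark the
    # whole image as masked instead.
    if prng.choice(2, p=[1 - p_drop, p_drop]):
        mask = np.ones_like(mask)
    mask[mask < 0.5] = 0
    mask[mask > 0.5] = 1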
@@ -122,7 +122,6 @@ class DDPM(pl.LightningModule):
         if self.ucg_training:
             self.ucg_prng = np.random.RandomState()
 
-
     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if exists(given_betas):
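
The context here shows DDPM creating self.ucg_prng when ucg_training is
enabled. A hedged sketch of how such a RandomState is typically used for
unconditional-guidance training, where conditioning is dropped with some
probability; the probability and null value below are assumptions for
illustration, not taken from this commit:

    import numpy as np

    ucg_prng = np.random.RandomState()
    p = 0.1                       # assumed drop probability
    caption = "a photo of a cat"  # assumed conditioning value

    # Replace the conditioning with a null value with probability p, so
    # the model also learns the unconditional score needed for
    # classifier-free guidance.
    if ucg_prng.choice(2, p=[1 - p, p]):
        caption = ""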
@@ -1603,7 +1602,9 @@ class LatentInpaintDiffusion(LatentDiffusion):
     To disable finetuning mode, set finetune_keys to None
     """
     def __init__(self,
-                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight"),
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
                  concat_keys=("mask", "masked_image"),
                  masked_image_key="masked_image",
                  keep_finetune_dims=4,  # if model was trained without concat mode before and we would like to keep these channels
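
Per the comment on keep_finetune_dims, finetuning from a checkpoint
trained without concat mode keeps the checkpoint's original input
channels of the widened first conv. A hypothetical helper sketching that
weight surgery (expand_input_conv and its exact copy rule are
assumptions, not the repository's loading code):

    import torch

    def expand_input_conv(old_w: torch.Tensor, new_w: torch.Tensor,
                          keep_finetune_dims: int = 4) -> torch.Tensor:
        # old_w: (out, 4, kh, kw) weight from the non-concat checkpoint.
        # new_w: (out, 4 + extra, kh, kw) freshly initialised weight of
        # the inpainting model, whose extra input channels receive the
        # concatenated mask / masked_image.
        new_w = new_w.clone()
        new_w[:, :keep_finetune_dims] = old_w[:, :keep_finetune_dims]
        return new_w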
@@ -1653,7 +1654,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
     @torch.no_grad()
     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
         # note: restricted to non-trainable encoders currently
-        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpaiting'
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                               force_c_encode=True, return_original_cond=True, bs=bs)
 
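
get_input assembles the concat conditioning named by concat_keys
("mask", "masked_image"). A short sketch of the usual convention behind
those two keys, with illustrative names and the common mask semantics
(1 = region to inpaint) assumed rather than quoted from this file:

    import torch

    def make_inpaint_cond(image: torch.Tensor, mask: torch.Tensor) -> dict:
        # image: (b, c, h, w) in [-1, 1]; mask: (b, 1, h, w), where 1
        # marks the region to inpaint. Zeroing that region out yields the
        # masked image the model is conditioned on.
        masked_image = image * (mask < 0.5)
        return {"mask": mask, "masked_image": masked_image}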