mask drop and typo

This commit is contained in:
Robin Rombach 2022-08-03 23:30:13 +02:00
parent a416813c32
commit 47aa45a345
2 changed files with 8 additions and 4 deletions

View File

@ -232,15 +232,18 @@ class AddLR(object):
class AddMask(PRNGMixin):
def __init__(self, mode="512train"):
def __init__(self, mode="512train", p_drop=0.):
super().__init__()
assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
self.make_mask = MASK_MODES[mode]
self.p_drop = p_drop
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = sample['jpg']
mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
mask = np.ones_like(mask)
mask[mask < 0.5] = 0
mask[mask > 0.5] = 1
mask = torch.from_numpy(mask[..., None])

View File

@ -122,7 +122,6 @@ class DDPM(pl.LightningModule):
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
@ -1603,7 +1602,9 @@ class LatentInpaintDiffusion(LatentDiffusion):
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight"),
finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
"model_ema.diffusion_modelinput_blocks00weight"
),
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels
@ -1653,7 +1654,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpaiting'
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)