diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 7a41d4a..bd05dc8 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -232,15 +232,18 @@ class AddLR(object):
 
 
 class AddMask(PRNGMixin):
-    def __init__(self, mode="512train"):
+    def __init__(self, mode="512train", p_drop=0.):
         super().__init__()
         assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
         self.make_mask = MASK_MODES[mode]
+        self.p_drop = p_drop
 
     def __call__(self, sample):
         # sample['jpg'] is tensor hwc in [-1, 1] at this point
         x = sample['jpg']
         mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
+        if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
+            mask = np.ones_like(mask)
         mask[mask < 0.5] = 0
         mask[mask > 0.5] = 1
         mask = torch.from_numpy(mask[..., None])
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index b078983..e0d2e35 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -122,7 +122,6 @@ class DDPM(pl.LightningModule):
         if self.ucg_training:
             self.ucg_prng = np.random.RandomState()
 
-
     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if exists(given_betas):
@@ -1603,7 +1602,9 @@ class LatentInpaintDiffusion(LatentDiffusion):
     To disable finetuning mode, set finetune_keys to None
     """
     def __init__(self,
-                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight"),
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
                  concat_keys=("mask", "masked_image"),
                  masked_image_key="masked_image",
                  keep_finetune_dims=4,  # if model was trained without concat mode before and we would like to keep these channels
@@ -1653,7 +1654,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
     @torch.no_grad()
     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
         # note: restricted to non-trainable encoders currently
-        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpaiting'
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                               force_c_encode=True, return_original_cond=True, bs=bs)
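
For context, here is a minimal standalone sketch of the mask-dropout path added to `AddMask` above. `MASK_MODES` is not shown in this diff, so `random_rect_mask` below is a hypothetical stand-in generator; only the `p_drop` branch mirrors the patch. With probability `p_drop` the mask is replaced by all ones, i.e. the whole image is marked as "to be inpainted", so the masked-image conditioning carries no content: the unconditional case used for classifier-free guidance.

```python
import numpy as np
import torch

def random_rect_mask(prng, h, w):
    # Hypothetical stand-in for MASK_MODES["512train"]: one random
    # rectangle of ones on a zero background.
    mask = np.zeros((h, w), dtype=np.float32)
    y0, x0 = prng.randint(0, h // 2), prng.randint(0, w // 2)
    y1, x1 = prng.randint(h // 2, h), prng.randint(w // 2, w)
    mask[y0:y1, x0:x1] = 1.
    return mask

def make_training_mask(prng, h, w, p_drop=0.1):
    mask = random_rect_mask(prng, h, w)
    # The branch added in this diff: with probability p_drop, mark the
    # entire image as masked, so the masked image retains no pixels and
    # the sample trains the unconditional branch for CFG.
    if prng.choice(2, p=[1 - p_drop, p_drop]):
        mask = np.ones_like(mask)
    mask[mask < 0.5] = 0
    mask[mask > 0.5] = 1
    return torch.from_numpy(mask[..., None])  # h, w, 1 with values in {0, 1}

prng = np.random.RandomState(0)
mask = make_training_mask(prng, 512, 512, p_drop=0.1)
print(mask.shape, mask.unique())
```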
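The `finetune_keys` / `keep_finetune_dims=4` pair governs how a text-to-image checkpoint is loaded into the inpainting model, whose input conv receives extra concat channels (`mask`, `masked_image`). The second key, `model_ema.diffusion_modelinput_blocks00weight`, is the same tensor under the EMA naming scheme, which strips dots from parameter names, hence the odd spelling. A sketch of the implied weight surgery, assuming the standard SD latent layout (4 latent channels, so 4 + 1 + 4 = 9 input channels) and zero init for the new channels:

```python
import torch

# Pretrained input_blocks.0.0.weight from a text-to-image checkpoint:
# 320 output channels, 4 latent input channels, 3x3 kernel (assumed shapes).
old_w = torch.randn(320, 4, 3, 3)

# The inpainting UNet expects 4 (latent) + 1 (mask) + 4 (masked-image
# latent) input channels. Zero-initializing the new channels means the
# widened conv initially ignores the concat inputs, so the model starts
# out behaving like the pretrained one.
new_w = torch.zeros(320, 9, 3, 3)
new_w[:, :4] = old_w  # keep_finetune_dims=4: keep the pretrained slice
```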