mask drop and typo
parent a416813c32
commit 47aa45a345
2 changed files with 8 additions and 4 deletions
@@ -232,15 +232,18 @@ class AddLR(object):
 
 
 class AddMask(PRNGMixin):
-    def __init__(self, mode="512train"):
+    def __init__(self, mode="512train", p_drop=0.):
         super().__init__()
         assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
         self.make_mask = MASK_MODES[mode]
+        self.p_drop = p_drop
 
     def __call__(self, sample):
         # sample['jpg'] is tensor hwc in [-1, 1] at this point
         x = sample['jpg']
         mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
+        if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
+            mask = np.ones_like(mask)
         mask[mask < 0.5] = 0
         mask[mask > 0.5] = 1
         mask = torch.from_numpy(mask[..., None])
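
The "mask drop" half of this commit is the new p_drop branch above: with
probability p_drop, the freshly sampled inpainting mask is replaced by an
all-ones mask, so training occasionally sees the fully-masked case. A minimal
standalone sketch of that behaviour, assuming only NumPy (apply_mask_drop and
toy_mask are hypothetical stand-ins, not part of the diff):

    import numpy as np

    def apply_mask_drop(mask, prng, p_drop):
        # prng.choice(2, p=[1 - p_drop, p_drop]) returns 1 with
        # probability p_drop; when it does, drop the sampled mask
        # and substitute an all-ones mask, as in AddMask.__call__.
        if prng.choice(2, p=[1 - p_drop, p_drop]):
            mask = np.ones_like(mask)
        # Binarize, as in the existing code.
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        return mask

    prng = np.random.RandomState(0)
    toy_mask = prng.rand(64, 64).astype(np.float32)  # stand-in for a MASK_MODES mask
    assert apply_mask_drop(toy_mask, prng, p_drop=1.0).min() == 1.0
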
@@ -122,7 +122,6 @@ class DDPM(pl.LightningModule):
         if self.ucg_training:
             self.ucg_prng = np.random.RandomState()
-
 
     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if exists(given_betas):
@@ -1603,7 +1602,9 @@ class LatentInpaintDiffusion(LatentDiffusion):
     To disable finetuning mode, set finetune_keys to None
     """
     def __init__(self,
-                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight"),
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
                  concat_keys=("mask", "masked_image"),
                  masked_image_key="masked_image",
                  keep_finetune_dims=4,  # if model was trained without concat mode before and we would like to keep these channels
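
A note on the dotless second key above ("model_ema.diffusion_modelinput_blocks00weight"):
it is most likely intentional rather than a typo, since the EMA wrapper
registers its shadow weights as PyTorch buffers, and buffer names may not
contain '.'. A small sketch of how such a key would be derived (the
dot-stripping convention is an assumption about LitEma, not shown in this diff):

    # Parameter name as seen inside self.model, and the corresponding
    # EMA checkpoint key with dots stripped (buffer-name restriction).
    param_name = "diffusion_model.input_blocks.0.0.weight"
    ema_key = "model_ema." + param_name.replace(".", "")
    assert ema_key == "model_ema.diffusion_modelinput_blocks00weight"
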
@@ -1653,7 +1654,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
     @torch.no_grad()
     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
         # note: restricted to non-trainable encoders currently
-        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpaiting'
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                               force_c_encode=True, return_original_cond=True, bs=bs)
 