Update ddpm.py
clean up no.1
This commit is contained in:
parent
417d5f3dee
commit
2b46bcb98c
1 changed files with 4 additions and 4 deletions
|
@ -461,7 +461,7 @@ class LatentDiffusion(DDPM):
|
||||||
self.instantiate_cond_stage(cond_stage_config)
|
self.instantiate_cond_stage(cond_stage_config)
|
||||||
self.cond_stage_forward = cond_stage_forward
|
self.cond_stage_forward = cond_stage_forward
|
||||||
self.clip_denoised = False
|
self.clip_denoised = False
|
||||||
self.bbox_tokenizer = None # # TODO: special class?
|
self.bbox_tokenizer = None
|
||||||
|
|
||||||
self.restarted_from_ckpt = False
|
self.restarted_from_ckpt = False
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
|
@ -598,7 +598,7 @@ class LatentDiffusion(DDPM):
|
||||||
weighting = weighting * L_weighting
|
weighting = weighting * L_weighting
|
||||||
return weighting
|
return weighting
|
||||||
|
|
||||||
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code !
|
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
|
||||||
"""
|
"""
|
||||||
:param x: img of size (bs, c, h, w)
|
:param x: img of size (bs, c, h, w)
|
||||||
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
|
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
|
||||||
|
@ -793,7 +793,7 @@ class LatentDiffusion(DDPM):
|
||||||
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
||||||
|
|
||||||
# 2. apply model loop over last dim
|
# 2. apply model loop over last dim
|
||||||
if isinstance(self.first_stage_model, VQModelInterface): # todo ask what this is
|
if isinstance(self.first_stage_model, VQModelInterface):
|
||||||
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
||||||
force_not_quantize=predict_cids or force_not_quantize)
|
force_not_quantize=predict_cids or force_not_quantize)
|
||||||
for i in range(z.shape[-1])]
|
for i in range(z.shape[-1])]
|
||||||
|
@ -901,7 +901,7 @@ class LatentDiffusion(DDPM):
|
||||||
|
|
||||||
if hasattr(self, "split_input_params"):
|
if hasattr(self, "split_input_params"):
|
||||||
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
||||||
assert not return_ids # todo dont know what this is -> I exclude --> Good
|
assert not return_ids
|
||||||
ks = self.split_input_params["ks"] # eg. (128, 128)
|
ks = self.split_input_params["ks"] # eg. (128, 128)
|
||||||
stride = self.split_input_params["stride"] # eg. (64, 64)
|
stride = self.split_input_params["stride"] # eg. (64, 64)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue