handle listconfig
This commit is contained in:
parent
b3a604d870
commit
fff19bf82e
1 changed files with 3 additions and 0 deletions
|
@ -17,6 +17,7 @@ from functools import partial
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||||
|
from omegaconf import ListConfig
|
||||||
|
|
||||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||||
from ldm.modules.ema import LitEma
|
from ldm.modules.ema import LitEma
|
||||||
|
@ -1188,6 +1189,8 @@ class LatentDiffusion(DDPM):
|
||||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||||
if null_label is not None:
|
if null_label is not None:
|
||||||
xc = null_label
|
xc = null_label
|
||||||
|
if isinstance(xc, ListConfig):
|
||||||
|
xc = list(xc)
|
||||||
if isinstance(xc, dict) or isinstance(xc, list):
|
if isinstance(xc, dict) or isinstance(xc, list):
|
||||||
c = self.get_learned_conditioning(xc)
|
c = self.get_learned_conditioning(xc)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue