handle listconfig

This commit is contained in:
rromb 2022-06-01 09:52:17 +02:00
parent b3a604d870
commit fff19bf82e

View file

@ -17,6 +17,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from torchvision.utils import make_grid from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma from ldm.modules.ema import LitEma
@ -1188,6 +1189,8 @@ class LatentDiffusion(DDPM):
def get_unconditional_conditioning(self, batch_size, null_label=None): def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None: if null_label is not None:
xc = null_label xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list): if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc) c = self.get_learned_conditioning(xc)
else: else: