diff --git a/dataset_tool.py b/dataset_tool.py index 5a06eb4..93481b0 100644 --- a/dataset_tool.py +++ b/dataset_tool.py @@ -90,7 +90,7 @@ def open_image_folder(source_dir, *, max_images: Optional[int]): arch_fname = os.path.relpath(fname, source_dir) arch_fname = arch_fname.replace('\\', '/') img = np.array(PIL.Image.open(fname)) - yield dict(img=img, label=labels.get(arch_fname)) + yield dict(img=img, label=labels.get(arch_fname), filename=fname) if idx >= max_idx-1: break return max_idx, iterate_images() @@ -119,7 +119,7 @@ def open_image_zip(source, *, max_images: Optional[int]): with z.open(fname, 'r') as file: img = PIL.Image.open(file) # type: ignore img = np.array(img) - yield dict(img=img, label=labels.get(fname)) + yield dict(img=img, label=labels.get(fname), filename=fname) if idx >= max_idx-1: break return max_idx, iterate_images() diff --git a/runs.py b/runs.py index d66cec7..42c7b12 100644 --- a/runs.py +++ b/runs.py @@ -1,13 +1,14 @@ import os import datetime import json -from typing import List +from typing import List, Optional from PIL import Image from enum import Enum import logging import numpy as np import dnnlib import legacy +from dataset_tool import open_dataset logger = logging.getLogger('runs') @@ -131,6 +132,13 @@ class Run(): def dataset_is_conditional(self): return bool(self.training_options["training_set_kwargs"]["use_labels"]) + def dataset_iterator(self, max_images: Optional[int] = None): + max_images, iterator = open_dataset( + self.training_options["training_set_kwargs"]["path"], + max_images=max_images + ) + return iterator + @property def resolution(self): return self.training_options["training_set_kwargs"]["resolution"]