tweak helper scripts

This commit is contained in:
Ruben van de Ven 2023-01-12 15:54:18 +01:00
parent 35e9eb4193
commit 658e4c06f2
2 changed files with 11 additions and 3 deletions

View File

@ -90,7 +90,7 @@ def open_image_folder(source_dir, *, max_images: Optional[int]):
arch_fname = os.path.relpath(fname, source_dir) arch_fname = os.path.relpath(fname, source_dir)
arch_fname = arch_fname.replace('\\', '/') arch_fname = arch_fname.replace('\\', '/')
img = np.array(PIL.Image.open(fname)) img = np.array(PIL.Image.open(fname))
yield dict(img=img, label=labels.get(arch_fname)) yield dict(img=img, label=labels.get(arch_fname), filename=fname)
if idx >= max_idx-1: if idx >= max_idx-1:
break break
return max_idx, iterate_images() return max_idx, iterate_images()
@ -119,7 +119,7 @@ def open_image_zip(source, *, max_images: Optional[int]):
with z.open(fname, 'r') as file: with z.open(fname, 'r') as file:
img = PIL.Image.open(file) # type: ignore img = PIL.Image.open(file) # type: ignore
img = np.array(img) img = np.array(img)
yield dict(img=img, label=labels.get(fname)) yield dict(img=img, label=labels.get(fname), filename=fname)
if idx >= max_idx-1: if idx >= max_idx-1:
break break
return max_idx, iterate_images() return max_idx, iterate_images()

10
runs.py
View File

@ -1,13 +1,14 @@
import os import os
import datetime import datetime
import json import json
from typing import List from typing import List, Optional
from PIL import Image from PIL import Image
from enum import Enum from enum import Enum
import logging import logging
import numpy as np import numpy as np
import dnnlib import dnnlib
import legacy import legacy
from dataset_tool import open_dataset
logger = logging.getLogger('runs') logger = logging.getLogger('runs')
@ -131,6 +132,13 @@ class Run():
def dataset_is_conditional(self): def dataset_is_conditional(self):
return bool(self.training_options["training_set_kwargs"]["use_labels"]) return bool(self.training_options["training_set_kwargs"]["use_labels"])
def dataset_iterator(self, max_images: Optional[int] = None):
max_images, iterator = open_dataset(
self.training_options["training_set_kwargs"]["path"],
max_images=max_images
)
return iterator
@property @property
def resolution(self): def resolution(self):
return self.training_options["training_set_kwargs"]["resolution"] return self.training_options["training_set_kwargs"]["resolution"]