tweak helper scripts

This commit is contained in:
Ruben van de Ven 2023-01-12 15:54:18 +01:00
parent 35e9eb4193
commit 658e4c06f2
2 changed files with 11 additions and 3 deletions

View file

@ -90,7 +90,7 @@ def open_image_folder(source_dir, *, max_images: Optional[int]):
arch_fname = os.path.relpath(fname, source_dir)
arch_fname = arch_fname.replace('\\', '/')
img = np.array(PIL.Image.open(fname))
yield dict(img=img, label=labels.get(arch_fname))
yield dict(img=img, label=labels.get(arch_fname), filename=fname)
if idx >= max_idx-1:
break
return max_idx, iterate_images()
@ -119,7 +119,7 @@ def open_image_zip(source, *, max_images: Optional[int]):
with z.open(fname, 'r') as file:
img = PIL.Image.open(file) # type: ignore
img = np.array(img)
yield dict(img=img, label=labels.get(fname))
yield dict(img=img, label=labels.get(fname), filename=fname)
if idx >= max_idx-1:
break
return max_idx, iterate_images()

10
runs.py
View file

@ -1,13 +1,14 @@
import os
import datetime
import json
from typing import List
from typing import List, Optional
from PIL import Image
from enum import Enum
import logging
import numpy as np
import dnnlib
import legacy
from dataset_tool import open_dataset
logger = logging.getLogger('runs')
@ -131,6 +132,13 @@ class Run():
def dataset_is_conditional(self):
return bool(self.training_options["training_set_kwargs"]["use_labels"])
def dataset_iterator(self, max_images: Optional[int] = None):
max_images, iterator = open_dataset(
self.training_options["training_set_kwargs"]["path"],
max_images=max_images
)
return iterator
@property
def resolution(self):
return self.training_options["training_set_kwargs"]["resolution"]