tweak helper scripts
This commit is contained in:
parent
35e9eb4193
commit
658e4c06f2
2 changed files with 11 additions and 3 deletions
|
@ -90,7 +90,7 @@ def open_image_folder(source_dir, *, max_images: Optional[int]):
|
|||
arch_fname = os.path.relpath(fname, source_dir)
|
||||
arch_fname = arch_fname.replace('\\', '/')
|
||||
img = np.array(PIL.Image.open(fname))
|
||||
yield dict(img=img, label=labels.get(arch_fname))
|
||||
yield dict(img=img, label=labels.get(arch_fname), filename=fname)
|
||||
if idx >= max_idx-1:
|
||||
break
|
||||
return max_idx, iterate_images()
|
||||
|
@ -119,7 +119,7 @@ def open_image_zip(source, *, max_images: Optional[int]):
|
|||
with z.open(fname, 'r') as file:
|
||||
img = PIL.Image.open(file) # type: ignore
|
||||
img = np.array(img)
|
||||
yield dict(img=img, label=labels.get(fname))
|
||||
yield dict(img=img, label=labels.get(fname), filename=fname)
|
||||
if idx >= max_idx-1:
|
||||
break
|
||||
return max_idx, iterate_images()
|
||||
|
|
10
runs.py
10
runs.py
|
@ -1,13 +1,14 @@
|
|||
import os
|
||||
import datetime
|
||||
import json
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
import logging
|
||||
import numpy as np
|
||||
import dnnlib
|
||||
import legacy
|
||||
from dataset_tool import open_dataset
|
||||
|
||||
logger = logging.getLogger('runs')
|
||||
|
||||
|
@ -131,6 +132,13 @@ class Run():
|
|||
def dataset_is_conditional(self):
|
||||
return bool(self.training_options["training_set_kwargs"]["use_labels"])
|
||||
|
||||
def dataset_iterator(self, max_images: Optional[int] = None):
|
||||
max_images, iterator = open_dataset(
|
||||
self.training_options["training_set_kwargs"]["path"],
|
||||
max_images=max_images
|
||||
)
|
||||
return iterator
|
||||
|
||||
@property
|
||||
def resolution(self):
|
||||
return self.training_options["training_set_kwargs"]["resolution"]
|
||||
|
|
Loading…
Reference in a new issue