tweak helper scripts
This commit is contained in:
parent
35e9eb4193
commit
658e4c06f2
2 changed files with 11 additions and 3 deletions
|
@ -90,7 +90,7 @@ def open_image_folder(source_dir, *, max_images: Optional[int]):
|
||||||
arch_fname = os.path.relpath(fname, source_dir)
|
arch_fname = os.path.relpath(fname, source_dir)
|
||||||
arch_fname = arch_fname.replace('\\', '/')
|
arch_fname = arch_fname.replace('\\', '/')
|
||||||
img = np.array(PIL.Image.open(fname))
|
img = np.array(PIL.Image.open(fname))
|
||||||
yield dict(img=img, label=labels.get(arch_fname))
|
yield dict(img=img, label=labels.get(arch_fname), filename=fname)
|
||||||
if idx >= max_idx-1:
|
if idx >= max_idx-1:
|
||||||
break
|
break
|
||||||
return max_idx, iterate_images()
|
return max_idx, iterate_images()
|
||||||
|
@ -119,7 +119,7 @@ def open_image_zip(source, *, max_images: Optional[int]):
|
||||||
with z.open(fname, 'r') as file:
|
with z.open(fname, 'r') as file:
|
||||||
img = PIL.Image.open(file) # type: ignore
|
img = PIL.Image.open(file) # type: ignore
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
yield dict(img=img, label=labels.get(fname))
|
yield dict(img=img, label=labels.get(fname), filename=fname)
|
||||||
if idx >= max_idx-1:
|
if idx >= max_idx-1:
|
||||||
break
|
break
|
||||||
return max_idx, iterate_images()
|
return max_idx, iterate_images()
|
||||||
|
|
10
runs.py
10
runs.py
|
@ -1,13 +1,14 @@
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import dnnlib
|
import dnnlib
|
||||||
import legacy
|
import legacy
|
||||||
|
from dataset_tool import open_dataset
|
||||||
|
|
||||||
logger = logging.getLogger('runs')
|
logger = logging.getLogger('runs')
|
||||||
|
|
||||||
|
@ -131,6 +132,13 @@ class Run():
|
||||||
def dataset_is_conditional(self):
|
def dataset_is_conditional(self):
|
||||||
return bool(self.training_options["training_set_kwargs"]["use_labels"])
|
return bool(self.training_options["training_set_kwargs"]["use_labels"])
|
||||||
|
|
||||||
|
def dataset_iterator(self, max_images: Optional[int] = None):
|
||||||
|
max_images, iterator = open_dataset(
|
||||||
|
self.training_options["training_set_kwargs"]["path"],
|
||||||
|
max_images=max_images
|
||||||
|
)
|
||||||
|
return iterator
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resolution(self):
|
def resolution(self):
|
||||||
return self.training_options["training_set_kwargs"]["resolution"]
|
return self.training_options["training_set_kwargs"]["resolution"]
|
||||||
|
|
Loading…
Reference in a new issue