40 lines
1.2 KiB
Python
40 lines
1.2 KiB
Python
import os
|
|
import numpy as np
|
|
from abc import abstractmethod
|
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
|
|
|
|
|
class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable.

    Concrete subclasses must implement ``__iter__``. This base class only
    stores bookkeeping attributes and reports ``num_records`` as its length.

    Args:
        num_records: Value reported by ``len()`` on the dataset.
        valid_ids: Stored as both ``valid_ids`` and ``sample_ids``. The two
            attributes refer to the same object; it is not copied.
        size: Stored as ``size``. This is presumably the target image side
            length (TODO: confirm against the subclasses).
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # Chained assignment binds valid_ids first, then sample_ids, to the
        # same object. In-place changes through one name are visible via the other.
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size

        # Announce the dataset size once the attributes are in place.
        class_name = self.__class__.__name__
        example_count = self.__len__()
        print(f'{class_name} dataset contains {example_count} examples.')

    def __len__(self):
        """Return the ``num_records`` attribute."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Iterate over the dataset. Concrete subclasses must override this."""
|
|
|
|
|
|
class PRNGMixin:
    """
    Mixin that exposes a per-process numpy ``RandomState`` as ``self.prng``.

    The generator is cached on the instance together with the PID of the
    process that created it. When the property is read from a different
    process (for example, a multiprocessing worker that received a copy of
    the instance), a new ``RandomState`` is built. This keeps processes from
    producing synchronized random streams.
    """

    @property
    def prng(self):
        """Return this process's ``RandomState``, rebuilding it after a PID change."""
        pid = os.getpid()
        if pid == getattr(self, "_initpid", None):
            # Same process that built the cached generator, so reuse it.
            return self._prng
        # First access, or a new process. Record the PID, then build a
        # generator with no explicit seed.
        self._initpid = pid
        self._prng = np.random.RandomState()
        return self._prng
|