data utils
This commit is contained in:
parent
6997027a41
commit
76e2f4b739
1 changed files with 18 additions and 1 deletions
|
@ -1,3 +1,5 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
||||||
|
|
||||||
|
@ -21,3 +23,18 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PRNGMixin(object):
|
||||||
|
"""
|
||||||
|
Adds a prng property which is a numpy RandomState which gets
|
||||||
|
reinitialized whenever the pid changes to avoid synchronized sampling
|
||||||
|
behavior when used in conjunction with multiprocessing.
|
||||||
|
"""
|
||||||
|
@property
|
||||||
|
def prng(self):
|
||||||
|
currentpid = os.getpid()
|
||||||
|
if getattr(self, "_initpid", None) != currentpid:
|
||||||
|
self._initpid = currentpid
|
||||||
|
self._prng = np.random.RandomState()
|
||||||
|
return self._prng
|
||||||
|
|
Loading…
Reference in a new issue