data utils

This commit is contained in:
Robin Rombach 2022-07-24 13:23:50 +02:00
parent 6997027a41
commit 76e2f4b739

View file

@ -1,3 +1,5 @@
import os
import numpy as np
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
@ -21,3 +23,18 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
@abstractmethod
def __iter__(self):
pass
class PRNGMixin(object):
"""
Adds a prng property which is a numpy RandomState which gets
reinitialized whenever the pid changes to avoid synchronized sampling
behavior when used in conjunction with multiprocessing.
"""
@property
def prng(self):
currentpid = os.getpid()
if getattr(self, "_initpid", None) != currentpid:
self._initpid = currentpid
self._prng = np.random.RandomState()
return self._prng