diff --git a/ldm/data/base.py b/ldm/data/base.py index b196c2f..742794e 100644 --- a/ldm/data/base.py +++ b/ldm/data/base.py @@ -1,3 +1,5 @@ +import os +import numpy as np from abc import abstractmethod from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset @@ -20,4 +22,19 @@ class Txt2ImgIterableBaseDataset(IterableDataset): @abstractmethod def __iter__(self): - pass \ No newline at end of file + pass + + +class PRNGMixin(object): + """ + Adds a prng property which is a numpy RandomState which gets + reinitialized whenever the pid changes to avoid synchronized sampling + behavior when used in conjunction with multiprocessing. + """ + @property + def prng(self): + currentpid = os.getpid() + if getattr(self, "_initpid", None) != currentpid: + self._initpid = currentpid + self._prng = np.random.RandomState() + return self._prng