diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index bfc7d3c..3593545 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -18,6 +18,59 @@ from webdataset.handlers import warn_and_continue
 from ldm.util import instantiate_from_config
 
 
+class DataWithWings(data.IterableDataset):
+    def __init__(self, min_size, transform=None, target_transform=None):
+        self.min_size = min_size
+        self.transform = transform if transform is not None else nn.Identity()
+        self.target_transform = target_transform if target_transform is not None else nn.Identity()
+        self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
+        self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
+        self.pwatermark_threshold = 0.8
+        self.punsafe_threshold = 0.5
+        self.aesthetic_threshold = 5.
+        self.total_samples = 0
+        self.samples = 0
+        location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'
+
+        self.inner_dataset = wds.DataPipeline(
+            wds.ResampledShards(location),
+            wds.tarfile_to_samples(handler=wds.warn_and_continue),
+            wds.shuffle(1000, handler=wds.warn_and_continue),
+            wds.decode('pilrgb', handler=wds.warn_and_continue),
+            wds.map(self._add_tags, handler=wds.ignore_and_continue),
+            wds.select(self._filter_predicate),
+            wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
+            wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
+        )
+
+    @staticmethod
+    def _compute_hash(url, text):
+        if url is None:
+            url = ''
+        if text is None:
+            text = ''
+        total = (url + text).encode('utf-8')
+        return mmh3.hash64(total)[0]
+
+    def _add_tags(self, x):
+        hsh = self._compute_hash(x['json']['url'], x['txt'])
+        pwatermark, punsafe = self.kv[hsh]
+        aesthetic = self.kv_aesthetic[hsh][0]
+        return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}
+
+    def _punsafe_to_class(self, punsafe):
+        return torch.tensor(punsafe >= self.punsafe_threshold).long()
+
+    def _filter_predicate(self, x):
+        try:
+            return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
+        except:
+            return False
+
+    def __iter__(self):
+        return iter(self.inner_dataset)
+
+
 def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
     """Take a list of samples (as dictionary) and create a batch, preserving the keys.
     If `tensors` is True, `ndarray` objects are combined into