add alstroemeria data

This commit is contained in:
rromb 2022-05-31 15:20:27 +02:00
parent 1dba1b17ee
commit 1eb0b88a8d

View file

@ -18,6 +18,59 @@ from webdataset.handlers import warn_and_continue
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
class DataWithWings(data.IterableDataset):
def __init__(self, min_size, transform=None, target_transform=None):
self.min_size = min_size
self.transform = transform if transform is not None else nn.Identity()
self.target_transform = target_transform if target_transform is not None else nn.Identity()
self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
self.pwatermark_threshold = 0.8
self.punsafe_threshold = 0.5
self.aesthetic_threshold = 5.
self.total_samples = 0
self.samples = 0
location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'
self.inner_dataset = wds.DataPipeline(
wds.ResampledShards(location),
wds.tarfile_to_samples(handler=wds.warn_and_continue),
wds.shuffle(1000, handler=wds.warn_and_continue),
wds.decode('pilrgb', handler=wds.warn_and_continue),
wds.map(self._add_tags, handler=wds.ignore_and_continue),
wds.select(self._filter_predicate),
wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
)
@staticmethod
def _compute_hash(url, text):
if url is None:
url = ''
if text is None:
text = ''
total = (url + text).encode('utf-8')
return mmh3.hash64(total)[0]
def _add_tags(self, x):
hsh = self._compute_hash(x['json']['url'], x['txt'])
pwatermark, punsafe = self.kv[hsh]
aesthetic = self.kv_aesthetic[hsh][0]
return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}
def _punsafe_to_class(self, punsafe):
return torch.tensor(punsafe >= self.punsafe_threshold).long()
def _filter_predicate(self, x):
try:
return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
except:
return False
def __iter__(self):
return iter(self.inner_dataset)
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True): def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
"""Take a list of samples (as dictionary) and create a batch, preserving the keys. """Take a list of samples (as dictionary) and create a batch, preserving the keys.
If `tensors` is True, `ndarray` objects are combined into If `tensors` is True, `ndarray` objects are combined into