add alstroemeria data
This commit is contained in:
parent 1dba1b17ee
commit 1eb0b88a8d
1 changed file with 53 additions and 0 deletions
@@ -18,6 +18,59 @@ from webdataset.handlers import warn_and_continue
 from ldm.util import instantiate_from_config
 
 
+class DataWithWings(data.IterableDataset):
+    def __init__(self, min_size, transform=None, target_transform=None):
+        self.min_size = min_size
+        self.transform = transform if transform is not None else nn.Identity()
+        self.target_transform = target_transform if target_transform is not None else nn.Identity()
+        self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
+        self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
+        self.pwatermark_threshold = 0.8
+        self.punsafe_threshold = 0.5
+        self.aesthetic_threshold = 5.
+        self.total_samples = 0
+        self.samples = 0
+        location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'
+
+        self.inner_dataset = wds.DataPipeline(
+            wds.ResampledShards(location),
+            wds.tarfile_to_samples(handler=wds.warn_and_continue),
+            wds.shuffle(1000, handler=wds.warn_and_continue),
+            wds.decode('pilrgb', handler=wds.warn_and_continue),
+            wds.map(self._add_tags, handler=wds.ignore_and_continue),
+            wds.select(self._filter_predicate),
+            wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
+            wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
+        )
+
+    @staticmethod
+    def _compute_hash(url, text):
+        if url is None:
+            url = ''
+        if text is None:
+            text = ''
+        total = (url + text).encode('utf-8')
+        return mmh3.hash64(total)[0]
+
+    def _add_tags(self, x):
+        hsh = self._compute_hash(x['json']['url'], x['txt'])
+        pwatermark, punsafe = self.kv[hsh]
+        aesthetic = self.kv_aesthetic[hsh][0]
+        return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}
+
+    def _punsafe_to_class(self, punsafe):
+        return torch.tensor(punsafe >= self.punsafe_threshold).long()
+
+    def _filter_predicate(self, x):
+        try:
+            return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
+        except:
+            return False
+
+    def __iter__(self):
+        return iter(self.inner_dataset)
+
+
 def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
     """Take a list of samples (as dictionary) and create a batch, preserving the keys.
     If `tensors` is True, `ndarray` objects are combined into
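The OnDiskKV helper used in __init__ is not part of this diff. Below is a hypothetical stand-in, assuming a read-only, memory-mapped file of fixed-size records, each a packed key ('q': one signed 64-bit integer, matching what mmh3.hash64(...)[0] returns) followed by packed values ('ee': two half-precision floats; 'e': one), sorted by key and searched by bisection; the real class's name and on-disk layout may differ.

import mmap
import struct


class OnDiskKVSketch:
    # Hypothetical stand-in for OnDiskKV: fixed-size (key, values) records,
    # sorted by key, looked up with binary search over the memory-mapped file.
    def __init__(self, file, key_format='q', value_format='ee'):
        self.key_struct = struct.Struct('<' + key_format)
        self.value_struct = struct.Struct('<' + value_format)
        self.record_size = self.key_struct.size + self.value_struct.size
        with open(file, 'rb') as f:
            self.data = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        self.n_records = len(self.data) // self.record_size

    def _key_at(self, i):
        return self.key_struct.unpack_from(self.data, i * self.record_size)[0]

    def __getitem__(self, key):
        lo, hi = 0, self.n_records
        while lo < hi:  # bisect over the sorted keys
            mid = (lo + hi) // 2
            if self._key_at(mid) < key:
                lo = mid + 1
            else:
                hi = mid
        if lo == self.n_records or self._key_at(lo) != key:
            raise KeyError(key)
        # Values sit immediately after the key inside the record; a tuple is
        # returned, so `pwatermark, punsafe = kv[hsh]` unpacks as in _add_tags.
        return self.value_struct.unpack_from(self.data, lo * self.record_size + self.key_struct.size)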
Loading…
Reference in a new issue
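Likewise, a minimal consumption sketch, also not part of the commit (the crop size, batch size, and worker count are arbitrary choices): __iter__ delegates to the inner webdataset pipeline, which yields (jpg, txt, punsafe) tuples, so the dataset can be fed to a standard DataLoader.

from torch.utils.data import DataLoader
from torchvision import transforms

# Illustrative settings only; the commit itself fixes no batch size.
dataset = DataWithWings(
    min_size=256,
    transform=transforms.Compose([
        transforms.Resize(256),      # shorter side to 256 before cropping
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]),
)
loader = DataLoader(dataset, batch_size=64, num_workers=4)

for jpg, txt, punsafe in loader:
    # jpg: (64, 3, 256, 256) float tensor, txt: list of 64 caption strings,
    # punsafe: (64,) long tensor of 0/1 safety classes
    break

Note that _filter_predicate checks original_width/original_height from the sample metadata, so the stored images may be smaller than the originals; resizing before cropping avoids requesting a crop larger than the decoded image.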