laion explorations
This commit is contained in:
parent
96c50fbe93
commit
d9c9747122
1 changed file with 68 additions and 2 deletions
|
@ -300,8 +300,7 @@ def example01():
|
||||||
print("next epoch.")
|
print("next epoch.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def example02():
|
||||||
#example01()
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
|
@ -318,3 +317,70 @@ if __name__ == "__main__":
|
||||||
print(batch.keys())
|
print(batch.keys())
|
||||||
print(batch["jpg"].shape)
|
print(batch["jpg"].shape)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def example03():
    """Measure how many improved-aesthetics samples are at least 512x512.

    Streams shards 000000..060207 of the improved-aesthetics subset through
    an ``aws s3 cp`` pipe, keeps only samples that carry both a ``"jpg"``
    and a ``"txt"`` entry, and counts those whose ``json`` metadata reports
    ``original_width`` and ``original_height`` of 512 or more.

    A running "large / total" ratio is printed for the first sample and
    then after every further 1000 samples. Nothing is returned.
    """
    shard_spec = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"

    def has_image_and_caption(sample):
        # Any failure of the membership test counts as "reject".
        try:
            return ("jpg" in sample) and ("txt" in sample)
        except Exception:
            return False

    def is_large(sample):
        # Missing or malformed metadata counts as "not large".
        try:
            meta = sample['json']
            return meta['original_width'] >= 512 and meta['original_height'] >= 512
        except Exception:
            return False

    pipeline = wds.WebDataset(shard_spec).select(has_image_and_caption)
    # Decoding errors are reported and skipped rather than aborting the scan.
    pipeline = pipeline.decode('pil', handler=wds.warn_and_continue)

    n_large = 0
    for n_total, sample in enumerate(pipeline, start=1):
        if is_large(sample):
            n_large += 1

        # n_total is 1-based, so this fires on samples 1, 1001, 2001, ...
        if (n_total - 1) % 1000 == 0:
            print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def example04():
    """Probe every improved-aesthetics shard and report the unreadable ones.

    Walks the shard indices from 60207 down to 0. For each index it prints
    the index, opens that single shard through an ``aws s3 cp`` pipe, and
    tries to pull the first sample that has both a ``"jpg"`` and a ``"txt"``
    entry. If that attempt raises (including ``StopIteration`` when the
    shard yields no matching sample), ``Error @ <index>`` is printed and the
    scan continues with the next shard. Nothing is returned.
    """
    def filter_keys(x):
        # Any failure of the membership test counts as "reject".
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    # The filter does not depend on the shard, so it is defined once above
    # instead of being rebuilt on every iteration.
    for i_shard in reversed(range(60208)):
        print(i_shard)
        tars = f"pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{i_shard:06}.tar -"
        dataset = (wds.WebDataset(tars)
                   .select(filter_keys)
                   .decode('pil', handler=wds.warn_and_continue))
        try:
            # Only the success or failure of the first read matters; the
            # sample itself is discarded.
            next(iter(dataset))
        except Exception:
            print(f"Error @ {i_shard}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: exactly one experiment is active at a time; the
    # others are left commented out so they can be switched back on.
    #example01()
    #example02()
    #example03()
    example04()
|
||||||
|
|
Loading…
Reference in a new issue