laion explorations
This commit is contained in:
parent
96c50fbe93
commit
d9c9747122
1 changed files with 68 additions and 2 deletions
|
@ -300,8 +300,7 @@ def example01():
|
|||
print("next epoch.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#example01()
|
||||
def example02():
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data import IterableDataset
|
||||
|
@ -318,3 +317,70 @@ if __name__ == "__main__":
|
|||
print(batch.keys())
|
||||
print(batch["jpg"].shape)
|
||||
break
|
||||
|
||||
|
||||
def example03():
    """Stream the improved-aesthetics shards and report the share of large samples.

    Samples that lack a "jpg" or a "txt" entry are dropped before decoding.
    A sample is counted as "large" when both ``original_width`` and
    ``original_height`` of its "json" entry are at least 512.  A running
    tally is printed every 1000 samples, and once more when the stream ends
    so the final counts are not lost.
    """
    # improved aesthetics
    tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
    dataset = wds.WebDataset(tars)

    def filter_keys(x):
        """Return True when the sample has both an image and a caption entry."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            # Best-effort filter: a malformed sample is dropped, not fatal.
            return False

    def filter_size(x):
        """Return True when the recorded original size is at least 512x512."""
        try:
            return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
        except Exception:
            # Missing or non-comparable metadata counts as "not large".
            return False

    def report(n_large, n_total):
        print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")

    dataset = (dataset
               .select(filter_keys)
               .decode('pil', handler=wds.warn_and_continue))

    n_total = 0
    n_large = 0
    last_reported = 0
    for i, example in enumerate(dataset):
        n_total += 1
        if filter_size(example):
            n_large += 1

        if i % 1000 == 0:
            report(n_large, n_total)
            last_reported = n_total

    # Emit the final tally unless the last progress line already covered it
    # (also skips the division when the stream yielded nothing).
    if n_total != last_reported:
        report(n_large, n_total)
|
||||
|
||||
|
||||
|
||||
def example04(n_shards=60208):
    """Probe every improved-aesthetics shard and print the ones that fail.

    Shards are visited from the highest index down to 0.  For each shard the
    index is printed, a single-shard dataset is built, and its first sample
    is requested; any exception raised while doing so is reported together
    with the shard index.

    Args:
        n_shards: Number of shards to probe (indices ``0 .. n_shards - 1``).
            Defaults to the full set of 60208 shards.
    """
    def filter_keys(x):
        """Return True when the sample has both an image and a caption entry."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            # Best-effort filter: a malformed sample is dropped, not fatal.
            return False

    # improved aesthetics
    for i_shard in reversed(range(n_shards)):
        print(i_shard)
        tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
        dataset = (wds.WebDataset(tars)
                   .select(filter_keys)
                   .decode('pil', handler=wds.warn_and_continue))
        try:
            # Only the first sample is pulled; its value is not needed.
            next(iter(dataset))
        except Exception as err:
            # Include the exception so an empty shard (StopIteration) can be
            # told apart from other failures.
            print(f"Error @ {i_shard}: {err!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: only the most recent exploration is enabled; the
    # earlier ones are left commented out.
    # example01()
    # example02()
    # example03()
    example04()
|
||||
|
|
Loading…
Reference in a new issue