2022-05-26 21:51:29 +00:00
import webdataset as wds
from PIL import Image
import io
import os
2022-05-27 09:46:04 +00:00
import torchvision
from PIL import Image
import glob
import random
import numpy as np
import pytorch_lightning as pl
2022-05-26 21:51:29 +00:00
from tqdm import tqdm
2022-05-27 09:46:04 +00:00
from omegaconf import OmegaConf
from einops import rearrange
import torch
from ldm . util import instantiate_from_config
def dict_collation_fn ( samples , combine_tensors = True , combine_scalars = True ) :
""" Take a list of samples (as dictionary) and create a batch, preserving the keys.
If ` tensors ` is True , ` ndarray ` objects are combined into
tensor batches .
: param dict samples : list of samples
: param bool tensors : whether to turn lists of ndarrays into a single ndarray
: returns : single sample consisting of a batch
: rtype : dict
"""
batched = { key : [ ] for key in samples [ 0 ] }
# assert isinstance(samples[0][first_key], (list, tuple)), type(samples[first_key])
for s in samples :
[ batched [ key ] . append ( s [ key ] ) for key in batched ]
result = { }
for key in batched :
if isinstance ( batched [ key ] [ 0 ] , ( int , float ) ) :
if combine_scalars :
result [ key ] = np . array ( list ( batched [ key ] ) )
elif isinstance ( batched [ key ] [ 0 ] , torch . Tensor ) :
if combine_tensors :
# import torch
result [ key ] = torch . stack ( list ( batched [ key ] ) )
elif isinstance ( batched [ key ] [ 0 ] , np . ndarray ) :
if combine_tensors :
result [ key ] = np . array ( list ( batched [ key ] ) )
else :
result [ key ] = list ( batched [ key ] )
# result.append(b)
return result
class WebDataModuleFromConfig ( pl . LightningDataModule ) :
def __init__ ( self , tar_base , batch_size , train = None , validation = None ,
test = None , num_workers = 4 , load_ddp = True , n_nodes = 1 ,
* * kwargs ) :
super ( ) . __init__ ( self )
print ( f ' Setting tar base to { tar_base } ' )
self . tar_base = tar_base
self . batch_size = batch_size
self . num_workers = num_workers
self . train = train
self . validation = validation
self . test = test
self . load_ddp = load_ddp
self . multinode = n_nodes > 1
self . n_nodes = n_nodes # n gpu ??
def make_loader ( self , dataset_config , train = True ) :
if ' image_transforms ' in dataset_config :
image_transforms = [ instantiate_from_config ( tt ) for tt in dataset_config . image_transforms ]
else :
image_transforms = [ ]
image_transforms . extend ( [ torchvision . transforms . ToTensor ( ) ,
torchvision . transforms . Lambda ( lambda x : rearrange ( x * 2. - 1. , ' c h w -> h w c ' ) ) ] )
image_transforms = torchvision . transforms . Compose ( image_transforms )
if ' transforms ' in dataset_config :
transforms_config = OmegaConf . to_container ( dataset_config . transforms )
else :
transforms_config = dict ( )
transform_dict = { dkey : load_partial_from_config ( transforms_config [ dkey ] ) if transforms_config [
dkey ] != ' identity ' else identity
for dkey in transforms_config }
img_key = dataset_config . get ( ' image_key ' , ' jpeg ' )
transform_dict . update ( { img_key : image_transforms } )
shuffle = dataset_config . get ( ' shuffle ' , 0 )
# TODO fid strategy when n exmples not known beforehand
n_examples = dataset_config . get ( ' n_examples ' , 1e6 ) / / self . n_nodes
shards_to_load = dataset_config . shards
dset_name = ' unknown '
if isinstance ( shards_to_load , str ) :
print ( f ' Loading tars based on the string { shards_to_load } ' )
tars = os . path . join ( self . tar_base , shards_to_load )
start_shard_id , end_shard_id = dataset_config . shards . split ( ' { ' ) [ - 1 ] . split ( ' } ' ) [ 0 ] . split ( ' .. ' )
n_shards = int ( end_shard_id ) - int ( start_shard_id ) + 1
dset_name = dataset_config . shards . split ( ' - ' ) [ 0 ]
elif isinstance ( shards_to_load , int ) :
print ( f ' Creating tar list, max shard is { shards_to_load } ' )
try :
tars = [ tf for tf in natsorted ( glob ( os . path . join ( self . tar_base , ' *.tar ' ) ) ) if
int ( tf . split ( ' / ' ) [ - 1 ] . split ( ' . ' ) [ 0 ] ) < shards_to_load ]
n_shards = len ( tars )
random . shuffle ( tars )
except ValueError as e :
print ( ' tarfile names should follow the pattern <zero_padded_number>.tar . Check names of the files ' )
raise e
else :
raise ValueError (
' shards should be either a string containing consecutive shards or an int defining the max shard number ' )
print ( f ' Got { n_shards } shard files in datafolder for { " training " if train else " validation " } ' )
# if self.num_workers > 0:
# assert n_shards % self.num_workers == 0 , f'Number of workers which is {self.num_workers} does not evenly divide number of shards which is {n_shards}'
print ( f ' Loading webdataset based dataloader based on { n_shards } of { dset_name } dataset. ' )
# start creating the dataset
nodesplitter = wds . shardlists . split_by_node if self . multinode else wds . shardlists . single_node_only
epoch_length = n_examples / / ( self . batch_size )
dset = wds . WebDataset ( tars , nodesplitter = nodesplitter ) . shuffle ( shuffle )
with_epoch_args = { ' nsamples ' : n_examples , ' nbatches ' : epoch_length }
if ' filters ' in dataset_config :
for stage in tqdm ( dataset_config . filters ,
desc = f ' Applying the following filters: { [ f for f in dataset_config . filters ] } ' ) :
f = getattr ( dset , stage )
dset = f ( dset , * dataset_config . filters [ stage ] . args ,
* * dataset_config . filters [ stage ] . get ( ' kwargs ' , dict ( ) ) )
print ( f ' Dataset holding { len ( dset . pipeline [ 0 ] . urls ) } shards ' )
2022-05-30 14:09:27 +00:00
from webdataset . handlers import warn_and_continue
2022-05-30 14:04:35 +00:00
2022-05-27 09:46:04 +00:00
dset = ( dset
2022-05-30 14:09:27 +00:00
. decode ( ' pil ' , handler = warn_and_continue )
2022-05-27 09:46:04 +00:00
# .to_tuple("jpg;png;jpeg pickle cls hls")
# .map_tuple(image_transforms,load_partial_from_config(nns_transform) if 'target' in nns_transform else identity,identity,identity)
. map_dict ( * * transform_dict )
. repeat ( )
. batched ( self . batch_size , partial = False ,
collation_fn = dict_collation_fn )
. with_length ( n_examples )
. with_epoch ( * * with_epoch_args )
)
loader = wds . WebLoader ( dset , batch_size = None , shuffle = False ,
num_workers = self . num_workers )
return loader , n_examples
def train_dataloader ( self ) :
assert self . train is not None
loader , dset_size = self . make_loader ( self . train )
# if self.load_ddp:
# loader = loader.ddp_equalize(dset_size // self.batch_size)
return loader
def val_dataloader ( self ) :
assert self . train is not None
loader , _ = self . make_loader ( self . validation , train = False )
return loader
def test_dataloader ( self ) :
assert self . train is not None
loader , _ = self . make_loader ( self . test , train = False )
return loader
2022-05-26 21:51:29 +00:00
if __name__ == " __main__ " :
url = " pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar - "
dataset = wds . WebDataset ( url )
example = next ( iter ( dataset ) )
for k in example :
print ( k , type ( example [ k ] ) )
print ( example [ " __key__ " ] )
for k in [ " json " , " txt " ] :
print ( example [ k ] . decode ( ) )
image = Image . open ( io . BytesIO ( example [ " jpg " ] ) )
outdir = " tmp "
os . makedirs ( outdir , exist_ok = True )
2022-05-27 09:46:04 +00:00
image . save ( os . path . join ( outdir , example [ " __key__ " ] + " .png " ) )
2022-05-26 21:51:29 +00:00
def load_example ( example ) :
return {
" key " : example [ " __key__ " ] ,
" image " : Image . open ( io . BytesIO ( example [ " jpg " ] ) ) ,
" text " : example [ " txt " ] . decode ( ) ,
}
for i , example in tqdm ( enumerate ( dataset ) ) :
ex = load_example ( example )
print ( ex [ " image " ] . size , ex [ " text " ] )
if i > = 100 :
break