first (in)stable steps

This commit is contained in:
rromb 2022-05-27 11:46:04 +02:00
parent f7a6152022
commit 9a419a1b14
4 changed files with 310 additions and 66 deletions

View file

@ -0,0 +1,130 @@
model:
base_learning_rate: 1.0e-04 # TODO: run with scale_lr False
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 32
channels: 4
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 128 # 320 # TODO increase
attention_resolutions: [ 4, 2, 1 ] # is equal to fixed spatial resolution: 32 , 16 , 8
num_res_blocks: 2
channel_mult: [ 1,2,4,4 ]
#num_head_channels: 32
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 1280
use_checkpoint: True
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ckpt_path: "/home/robin/projects/latent-diffusion/models/first_stage_models/kl-f8/model.ckpt"
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.BERTEmbedder
params:
n_embed: 1280
n_layer: 3 #32 # TODO: increase
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
batch_size: 60
num_workers: 4
n_nodes: 2 # TODO: runs with two gpus
train:
shards: '{000000..000010}.tar -' # TODO: wild guess, change
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
shuffle: 5000
n_examples: 16519100 # TODO: find out
validation:
shards: '{000011..000012}.tar -' # TODO: wild guess, change
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
shuffle: 0
n_examples: 60000 # TODO: find out
val_num_workers: 2
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 5000 # 5000
max_images: 8
increase_log_steps: False
log_first_step: True
trainer:
replace_sampler_ddp: False
benchmark: True
val_check_interval: 20000 # every 20k training steps
num_sanity_val_steps: 0

View file

@ -2,7 +2,178 @@ import webdataset as wds
from PIL import Image from PIL import Image
import io import io
import os import os
import torchvision
from PIL import Image
import glob
import random
import numpy as np
import pytorch_lightning as pl
from tqdm import tqdm from tqdm import tqdm
from omegaconf import OmegaConf
from einops import rearrange
import torch
from ldm.util import instantiate_from_config
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
"""Take a list of samples (as dictionary) and create a batch, preserving the keys.
If `tensors` is True, `ndarray` objects are combined into
tensor batches.
:param dict samples: list of samples
:param bool tensors: whether to turn lists of ndarrays into a single ndarray
:returns: single sample consisting of a batch
:rtype: dict
"""
batched = {key: [] for key in samples[0]}
# assert isinstance(samples[0][first_key], (list, tuple)), type(samples[first_key])
for s in samples:
[batched[key].append(s[key]) for key in batched]
result = {}
for key in batched:
if isinstance(batched[key][0], (int, float)):
if combine_scalars:
result[key] = np.array(list(batched[key]))
elif isinstance(batched[key][0], torch.Tensor):
if combine_tensors:
# import torch
result[key] = torch.stack(list(batched[key]))
elif isinstance(batched[key][0], np.ndarray):
if combine_tensors:
result[key] = np.array(list(batched[key]))
else:
result[key] = list(batched[key])
# result.append(b)
return result
class WebDataModuleFromConfig(pl.LightningDataModule):
def __init__(self, tar_base, batch_size, train=None, validation=None,
test=None, num_workers=4, load_ddp=True, n_nodes=1,
**kwargs):
super().__init__(self)
print(f'Setting tar base to {tar_base}')
self.tar_base = tar_base
self.batch_size = batch_size
self.num_workers = num_workers
self.train = train
self.validation = validation
self.test = test
self.load_ddp = load_ddp
self.multinode = n_nodes > 1
self.n_nodes = n_nodes # n gpu ??
def make_loader(self, dataset_config, train=True):
if 'image_transforms' in dataset_config:
image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
else:
image_transforms = []
image_transforms.extend([torchvision.transforms.ToTensor(),
torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
image_transforms = torchvision.transforms.Compose(image_transforms)
if 'transforms' in dataset_config:
transforms_config = OmegaConf.to_container(dataset_config.transforms)
else:
transforms_config = dict()
transform_dict = {dkey: load_partial_from_config(transforms_config[dkey]) if transforms_config[
dkey] != 'identity' else identity
for dkey in transforms_config}
img_key = dataset_config.get('image_key', 'jpeg')
transform_dict.update({img_key: image_transforms})
shuffle = dataset_config.get('shuffle', 0)
# TODO fid strategy when n exmples not known beforehand
n_examples = dataset_config.get('n_examples', 1e6) // self.n_nodes
shards_to_load = dataset_config.shards
dset_name = 'unknown'
if isinstance(shards_to_load, str):
print(f'Loading tars based on the string {shards_to_load}')
tars = os.path.join(self.tar_base, shards_to_load)
start_shard_id, end_shard_id = dataset_config.shards.split('{')[-1].split('}')[0].split('..')
n_shards = int(end_shard_id) - int(start_shard_id) + 1
dset_name = dataset_config.shards.split('-')[0]
elif isinstance(shards_to_load, int):
print(f'Creating tar list, max shard is {shards_to_load}')
try:
tars = [tf for tf in natsorted(glob(os.path.join(self.tar_base, '*.tar'))) if
int(tf.split('/')[-1].split('.')[0]) < shards_to_load]
n_shards = len(tars)
random.shuffle(tars)
except ValueError as e:
print('tarfile names should follow the pattern <zero_padded_number>.tar . Check names of the files')
raise e
else:
raise ValueError(
'shards should be either a string containing consecutive shards or an int defining the max shard number')
print(f'Got {n_shards} shard files in datafolder for {"training" if train else "validation"}')
# if self.num_workers > 0:
# assert n_shards % self.num_workers == 0 , f'Number of workers which is {self.num_workers} does not evenly divide number of shards which is {n_shards}'
print(f'Loading webdataset based dataloader based on {n_shards} of {dset_name} dataset.')
# start creating the dataset
nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
epoch_length = n_examples // (self.batch_size)
dset = wds.WebDataset(tars, nodesplitter=nodesplitter).shuffle(shuffle)
with_epoch_args = {'nsamples': n_examples, 'nbatches': epoch_length}
if 'filters' in dataset_config:
for stage in tqdm(dataset_config.filters,
desc=f'Applying the following filters: {[f for f in dataset_config.filters]}'):
f = getattr(dset, stage)
dset = f(dset, *dataset_config.filters[stage].args,
**dataset_config.filters[stage].get('kwargs', dict()))
print(f'Dataset holding {len(dset.pipeline[0].urls)} shards')
dset = (dset
.decode('pil')
# .to_tuple("jpg;png;jpeg pickle cls hls")
# .map_tuple(image_transforms,load_partial_from_config(nns_transform) if 'target' in nns_transform else identity,identity,identity)
.map_dict(**transform_dict)
.repeat()
.batched(self.batch_size, partial=False,
collation_fn=dict_collation_fn)
.with_length(n_examples)
.with_epoch(**with_epoch_args)
)
loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
num_workers=self.num_workers)
return loader, n_examples
def train_dataloader(self):
assert self.train is not None
loader, dset_size = self.make_loader(self.train)
# if self.load_ddp:
# loader = loader.ddp_equalize(dset_size // self.batch_size)
return loader
def val_dataloader(self):
assert self.train is not None
loader, _ = self.make_loader(self.validation, train=False)
return loader
def test_dataloader(self):
assert self.train is not None
loader, _ = self.make_loader(self.test, train=False)
return loader
if __name__ == "__main__": if __name__ == "__main__":
url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -" url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"

View file

@ -664,7 +664,7 @@ class LatentDiffusion(DDPM):
if cond_key is None: if cond_key is None:
cond_key = self.cond_stage_key cond_key = self.cond_stage_key
if cond_key != self.first_stage_key: if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox']: if cond_key in ['caption', 'coordinates_bbox', "txt"]:
xc = batch[cond_key] xc = batch[cond_key]
elif cond_key == 'class_label': elif cond_key == 'class_label':
xc = batch xc = batch
@ -762,66 +762,6 @@ class LatentDiffusion(DDPM):
else: else:
return self.first_stage_model.decode(z) return self.first_stage_model.decode(z)
# same as above but without decorator
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
else:
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
@torch.no_grad() @torch.no_grad()
def encode_first_stage(self, x): def encode_first_stage(self, x):
if hasattr(self, "split_input_params"): if hasattr(self, "split_input_params"):
@ -1268,8 +1208,8 @@ class LatentDiffusion(DDPM):
if hasattr(self.cond_stage_model, "decode"): if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c) xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc log["conditioning"] = xc
elif self.cond_stage_key in ["caption"]: elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
log["conditioning"] = xc log["conditioning"] = xc
elif self.cond_stage_key == 'class_label': elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])

View file

@ -667,8 +667,11 @@ if __name__ == "__main__":
data.prepare_data() data.prepare_data()
data.setup() data.setup()
print("#### Data #####") print("#### Data #####")
try:
for k in data.datasets: for k in data.datasets:
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
except:
print("datasets not yet initialized.")
# configure learning rate # configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate