add upscaling
This commit is contained in:
parent a0674ac4a2
commit c5a39aff8a

5 changed files with 282 additions and 4 deletions
@ -0,0 +1,175 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    low_scale_key: "LR_image" # TODO: adapt
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    #first_stage_key: "jpg"  # TODO: use this later
    cond_stage_key: "caption"
    #cond_stage_key: "txt" # TODO: use this later
    image_size: 64
    channels: 16
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929   # magic number

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 250
        output_size: 64
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
                - 1
                - 2
                - 4
                - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000  # timesteps for noise conditioning
        image_size: 64    # not really needed
        in_channels: 20
        out_channels: 16
        model_channels: 32 # TODO: more
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    min_size: 256   # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
#    train:
#      shards: '{000000..231317}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 1024
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{231318..231349}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 1024

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 7
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 1024
        downscale_f: 4
        degradation: "cv_nearest"

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 10
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000  # really sorry # TODO: bring back in
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
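For orientation: configs in this repo are consumed via ldm.util.instantiate_from_config, which imports the dotted `target` path and passes `params` as constructor kwargs. A minimal sketch of building the model from this file, assuming it is saved as config.yaml (the path is illustrative, not part of this commit):

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    # load the YAML and build LatentUpscaleDiffusion from its `model` section;
    # `target` names the class, `params` become __init__ kwargs
    config = OmegaConf.load("config.yaml")  # hypothetical path
    model = instantiate_from_config(config.model)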
@ -368,7 +368,7 @@ class ImageNetSR(Dataset):

         example["image"] = (image/127.5 - 1.0).astype(np.float32)
         example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+        example["caption"] = example["human_label"]  # dummy caption
         return example
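Note how this wires ImageNet SR samples into the text-conditioned setup: get_input reads first_stage_key ("image"), low_scale_key ("LR_image") and cond_stage_key ("caption") from the same batch dict. A sketch of the example structure this change assumes (values illustrative):

    # one ImageNetSR example after this change (illustrative values)
    example = {
        "image": ...,        # HR crop in [-1, 1], float32           -> first_stage_key
        "LR_image": ...,     # degraded crop in [-1, 1], float32     -> low_scale_key
        "human_label": "tench",
        "caption": "tench",  # human-readable class label reused as a dummy text prompt
    }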
@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):

     @torch.no_grad()
     def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
-                  cond_key=None, return_original_cond=False, bs=None):
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
         x = super().get_input(batch, k)
         if bs is not None:
             x = x[:bs]
@ -746,6 +746,8 @@ class LatentDiffusion(DDPM):
         if return_first_stage_outputs:
             xrec = self.decode_first_stage(z)
             out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
         if return_original_cond:
             out.append(xc)
         return out
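The new return_x flag returns the decoded-space input x alongside z and c without triggering the first-stage decode that return_first_stage_outputs would also run. Minimal usage sketch, mirroring the call LatentUpscaleDiffusion makes below:

    # inside a LatentDiffusion subclass; out is assembled as [z, c, x]
    z, c, x = self.get_input(batch, self.first_stage_key, return_x=True, force_c_encode=True)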
@ -1416,9 +1418,9 @@ class DiffusionWrapper(pl.LightningModule):
         super().__init__()
         self.diffusion_model = instantiate_from_config(diff_model_config)
         self.conditioning_key = conditioning_key
-        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']

-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == 'concat':
@ -1431,6 +1433,11 @@ class DiffusionWrapper(pl.LightningModule):
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'hybrid-adm':
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
             out = self.diffusion_model(x, t, y=cc)
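'hybrid-adm' combines the two existing conditioning paths: c_concat is stacked onto the model input channel-wise (as in 'hybrid'), while c_adm enters as a class-style embedding y (as in 'adm'). A toy call with shapes matching the config above; wrapper is assumed to be a DiffusionWrapper built with conditioning_key='hybrid-adm', and all sizes are illustrative:

    import torch

    x  = torch.randn(2, 16, 64, 64)            # f16 latent being denoised (channels: 16)
    zx = torch.randn(2, 4, 64, 64)             # noised LR latent -> concat, 16 + 4 = in_channels: 20
    cc = torch.randn(2, 77, 768)               # CLIP text context (context_dim: 768)
    t  = torch.randint(0, 1000, (2,))
    noise_level = torch.randint(0, 250, (2,))  # max_noise_level: 250, embedded via num_classes: 1000

    out = wrapper(x, t, c_concat=[zx], c_crossattn=[cc], c_adm=noise_level)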
@ -1440,6 +1447,34 @@ class DiffusionWrapper(pl.LightningModule):
         return out


+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None):
+        z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
+        x_low = batch[self.low_scale_key]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        return z, all_conds
+
+    # TODO log it
+
+
 class Layout2ImgDiffusion(LatentDiffusion):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
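Putting the pieces together: the HR image goes through the f16 first stage (z, 16 channels at 64x64), the LR image through the frozen kl-f8 low-scale model, which noises its latent and upsamples it onto the same grid, and the sampled noise level rides along as the adm conditioning. A shape walk-through under the config above (batch of 2, sizes illustrative):

    # HR image 1024x1024 -> f16 AutoencoderKL           -> z:  (2, 16, 64, 64)
    # LR image 256x256 (downscale_f: 4) -> kl-f8 encode -> (2, 4, 32, 32)
    #   -> q_sample at random level in [0, 250) -> 2x nearest upsample -> zx: (2, 4, 64, 64)
    # all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}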
@ -1,8 +1,10 @@
 import torch
 import torch.nn as nn
+import numpy as np
 from functools import partial

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
+from ldm.util import default


 class AbstractEncoder(nn.Module):
@ -201,6 +203,60 @@ class SpatialRescaler(nn.Module):
         return self(x)


+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+
+
+class LowScaleEncoder(nn.Module):
+    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
+        super().__init__()
+        self.max_noise_level = max_noise_level
+        self.model = instantiate_from_config(model_config)
+        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
+                                                            linear_end=linear_end)
+        self.out_size = output_size
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        z = self.model.encode(x).sample()
+        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        z = self.q_sample(z, noise_level)
+        #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
+        z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+        return z, noise_level
+
+
 if __name__ == "__main__":
     from ldm.util import count_params
     sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
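forward implements noise-conditioning augmentation in the spirit of cascaded diffusion models: encode the LR image, diffuse the latent to a random level below max_noise_level with the closed-form q_sample, and hand back both the corrupted latent and the level so the UNet knows how much corruption to expect. A self-contained toy version of the same arithmetic, with plain tensors instead of the autoencoder and a simplified beta schedule (purely illustrative):

    import torch

    # simplified linear schedule standing in for make_beta_schedule(...)
    betas = torch.linspace(0.00085, 0.0120, 1000)
    alphas_cumprod = torch.cumprod(1. - betas, dim=0)

    z = torch.randn(2, 4, 32, 32)              # stand-in for the kl-f8 latent
    t = torch.randint(0, 250, (2,))            # max_noise_level: 250
    ac = alphas_cumprod[t].view(-1, 1, 1, 1)   # what extract_into_tensor does here
    z_noisy = ac.sqrt() * z + (1. - ac).sqrt() * torch.randn_like(z)

    # nearest-neighbour 2x upsampling, matching the commented-out interpolate call
    z_up = z_noisy.repeat_interleave(2, -2).repeat_interleave(2, -1)  # (2, 4, 64, 64)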
scripts/prompts/weird-dalle-prompts.txt (new file, 12 lines)
@ -0,0 +1,12 @@
# TODO, check out Twitter.
Darth Vader at Woodstock (1969)
Bunny Vikings
The Demogorgon from Stranger Things holding a basketball
Hamster in my microwave
a courtroom sketch of a Ford Transit van
PS1 Hagrid at McDonalds
Karl Marx in KFC Logo
Moai Statue giving a TED talk
washing machine trail cam
minions at cross burning
Hindenburg disaster in Fortnite