# Modified StyleGAN2 Projector with CLIP, addl. losses, kmeans, etc. # by Peter Baylies, 2021 -- @pbaylies on Twitter # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Project given image to the latent space of pretrained network pickle.""" import copy import os from time import perf_counter import click import imageio import numpy as np import PIL.Image from PIL import ImageFilter import torch import torch.nn.functional as F import dnnlib import legacy image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda() image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda() def score_images(G, model, text, latents, device, label_class = 0, batch_size = 8): scores = [] all_images = [] for i in range(latents.shape[0]//batch_size): images = G.synthesis(torch.tensor(latents[i*batch_size:(i+1)*batch_size,:,:], dtype=torch.float32, device=device), noise_mode='const') with torch.no_grad(): image_input = (torch.clamp(images, -1, 1) + 1) * 0.5 image_input = F.interpolate(image_input, size=(256, 256), mode='area') image_input = image_input[:, :, 16:240, 16:240] # 256 -> 224, center crop image_input -= image_mean[None, :, None, None] image_input /= image_std[None, :, None, None] score = model(image_input, text)[0] scores.append(score.cpu().numpy()) all_images.append(images.cpu().numpy()) scores = np.array(scores) scores = scores.reshape(-1, *scores.shape[2:]).squeeze() scores = 1 - scores / np.linalg.norm(scores) all_images = np.array(all_images) all_images = all_images.reshape(-1, *all_images.shape[2:]) return scores, all_images def project( G, target_image: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution target_text, *, num_steps = 300, w_avg_samples = 8192, initial_learning_rate = 0.02, initial_latent = None, initial_noise_factor = 0.01, lr_rampdown_length = 0.10, lr_rampup_length = 0.5, noise_ramp_length = 0.75, latent_range = 2.0, max_noise = 0.5, min_threshold = 0.6, use_vgg = True, use_clip = True, use_pixel = True, use_penalty = True, use_center = True, regularize_noise_weight = 1e5, kmeans = True, kmeans_clusters = 64, verbose = False, device: torch.device ): if target_image is not None: assert target_image.shape == (G.img_channels, G.img_resolution, G.img_resolution) else: use_vgg = False use_pixel = False # reduce errors unless using clip if use_clip: import clip def logprint(*args): if verbose: print(*args) G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore # Compute w stats. logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) labels = None if (G.mapping.c_dim): labels = torch.from_numpy(0.5*np.random.RandomState(123).randn(w_avg_samples, G.mapping.c_dim)).to(device) w_samples = G.mapping(torch.from_numpy(z_samples).to(device), labels) # [N, L, C] w_samples = w_samples.cpu().numpy().astype(np.float32) # [N, L, C] w_samples_1d = w_samples[:, :1, :].astype(np.float32) w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, L, C] w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 kmeans_latents = None if initial_latent is not None: w_avg = initial_latent else: if kmeans and use_clip and target_text is not None: from kmeans_pytorch import kmeans # data data_size, dims, num_clusters = w_avg_samples, G.z_dim, kmeans_clusters x = w_samples_1d x = torch.from_numpy(x) # kmeans logprint(f'Performing kmeans clustering using {w_avg_samples} latents into {kmeans_clusters} clusters...') cluster_ids_x, cluster_centers = kmeans( X=x, num_clusters=num_clusters, distance='euclidean', device=device ) #logprint(f'\nGenerating images from kmeans latents...') kmeans_latents = torch.tensor(cluster_centers, dtype=torch.float32, device=device, requires_grad=True) # Setup noise inputs. noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name } # Load VGG16 feature detector. if use_vgg: url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' with dnnlib.util.open_url(url) as f: vgg16 = torch.jit.load(f).eval().to(device) # Load CLIP if use_clip: model, transform = clip.load("ViT-B/32", device=device) # Features for target image. if target_image is not None: target_images = target_image.unsqueeze(0).to(device).to(torch.float32) small_target = F.interpolate(target_images, size=(64, 64), mode='area') if use_center: center_target = F.interpolate(target_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336] target_images = F.interpolate(target_images, size=(256, 256), mode='area') target_images = target_images[:, :, 16:240, 16:240] # 256 -> 224, center crop if use_vgg: vgg_target_features = vgg16(target_images, resize_images=False, return_lpips=True) if use_center: vgg_target_center = vgg16(center_target, resize_images=False, return_lpips=True) if use_clip: if target_image is not None: with torch.no_grad(): clip_target_features = model.encode_image(((target_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float() if use_center: clip_target_center = model.encode_image(((center_target / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float() if kmeans_latents is not None and use_clip and target_text is not None: scores, kmeans_images = score_images(G, model, target_text, kmeans_latents.repeat([1, G.mapping.num_ws, 1]), device=device) ind = np.argpartition(scores, 4)[:4] w_avg = torch.median(kmeans_latents[ind],dim=0,keepdim=True)[0].repeat([1, G.mapping.num_ws, 1]) w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable w_avg_tensor = w_opt.clone() w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device) optimizer = torch.optim.AdamW([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate) # Init noise. for buf in noise_bufs.values(): buf[:] = torch.randn_like(buf) buf.requires_grad = True for step in range(num_steps): # Learning rate schedule. t = step / num_steps w_noise_scale = max_noise * w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) lr = initial_learning_rate * lr_ramp for param_group in optimizer.param_groups: param_group['lr'] = lr # Synth images from opt_w. w_noise = torch.randn_like(w_opt) * w_noise_scale ws = w_opt + w_noise synth_images = G.synthesis(torch.clamp(ws,-latent_range,latent_range), noise_mode='const') # Downsample image to 256x256 if it's larger than that. CLIP was built for 224x224 images. synth_images = (torch.clamp(synth_images, -1, 1) + 1) * (255/2) small_synth = F.interpolate(synth_images, size=(64, 64), mode='area') if use_center: center_synth = F.interpolate(synth_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336] synth_images = F.interpolate(synth_images, size=(256, 256), mode='area') # Features for synth images. synth_images = synth_images[:, :, 16:240, 16:240] # 256 -> 224, center crop dist = 0 if use_vgg: vgg_synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) vgg_dist = (vgg_target_features - vgg_synth_features).square().sum() if use_center: vgg_synth_center = vgg16(center_synth, resize_images=False, return_lpips=True) vgg_dist += (vgg_target_center - vgg_synth_center).square().sum() vgg_dist *= 6 dist += F.relu(vgg_dist*vgg_dist - min_threshold) if use_clip: clip_synth_image = ((synth_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None] clip_synth_features = model.encode_image(clip_synth_image).float() adj_center = 2.0 if use_center: clip_cynth_center_image = ((center_synth / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None] adj_center = 1.0 clip_synth_center = model.encode_image(clip_cynth_center_image).float() if target_image is not None: clip_dist = (clip_target_features - clip_synth_features).square().sum() if use_center: clip_dist += (clip_target_center - clip_synth_center).square().sum() dist += F.relu(0.5 + adj_center*clip_dist - min_threshold) if target_text is not None: clip_text = 1 - model(clip_synth_image, target_text)[0].sum() / 100 if use_center: clip_text += 1 - model(clip_cynth_center_image, target_text)[0].sum() / 100 dist += 2*F.relu(adj_center*clip_text*clip_text - min_threshold / adj_center) if use_pixel: pixel_dist = (target_images - synth_images).abs().sum() / 2000000.0 if use_center: pixel_dist += (center_target - center_synth).abs().sum() / 2000000.0 pixel_dist += (small_target - small_synth).square().sum() / 800000.0 pixel_dist /= 4 dist += F.relu(lr_ramp * pixel_dist - min_threshold) if use_penalty: l1_penalty = (w_opt - w_avg_tensor).abs().sum() / 5000.0 dist += F.relu(lr_ramp * l1_penalty - min_threshold) # Noise regularization. reg_loss = 0.0 for v in noise_bufs.values(): noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d() while True: reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2 reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2 if noise.shape[2] <= 8: break noise = F.avg_pool2d(noise, kernel_size=2) #print(vgg_dist, clip_dist, pixel_dist, l1_penalty, reg_loss * regularize_noise_weight) loss = dist + reg_loss * regularize_noise_weight # Step optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') with torch.no_grad(): torch.clamp(w_opt,-latent_range,latent_range,out=w_opt) # Save projected W for each optimization step. w_out[step] = w_opt.detach()[0] # Normalize noise. with torch.no_grad(): for buf in noise_bufs.values(): buf -= buf.mean() buf *= buf.square().mean().rsqrt() return w_out #---------------------------------------------------------------------------- @click.command() @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) @click.option('--target-image', 'target_fname', help='Target image file to project to', required=False, metavar='FILE', default=None) @click.option('--target-text', help='Target text to project to', required=False, default=None) @click.option('--initial-latent', help='Initial latent', default=None) @click.option('--lr', help='Learning rate', type=float, default=0.1, show_default=True) @click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True) @click.option('--seed', help='Random seed', type=int, default=303, show_default=True) @click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True) @click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR') @click.option('--use-vgg', help='Use VGG16 in the loss', type=bool, default=True, show_default=True) @click.option('--use-clip', help='Use CLIP in the loss', type=bool, default=True, show_default=True) @click.option('--use-pixel', help='Use L1/L2 distance on pixels in the loss', type=bool, default=True, show_default=True) @click.option('--use-penalty', help='Use a penalty on latent values distance from the mean in the loss', type=bool, default=True, show_default=True) @click.option('--use-center', help='Optimize against an additional center image crop', type=bool, default=True, show_default=True) @click.option('--use-kmeans', help='Perform kmeans clustering for selecting initial latents', type=bool, default=True, show_default=True) def run_projection( network_pkl: str, target_fname: str, target_text: str, initial_latent: str, outdir: str, save_video: bool, seed: int, lr: float, num_steps: int, use_vgg: bool, use_clip: bool, use_pixel: bool, use_penalty: bool, use_center: bool, use_kmeans: bool, ): """Project given image to the latent space of pretrained network pickle. Examples: \b python projector.py --outdir=out --target=~/mytargetimg.png \\ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl """ np.random.seed(seed) torch.manual_seed(seed) # Load networks. print('Loading networks from "%s"...' % network_pkl) device = torch.device('cuda') with dnnlib.util.open_url(network_pkl) as fp: G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore # Load target image. target_image = None if target_fname: target_pil = PIL.Image.open(target_fname).convert('RGB').filter(ImageFilter.SHARPEN) w, h = target_pil.size s = min(w, h) target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)) target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS) target_uint8 = np.array(target_pil, dtype=np.uint8) target_image = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device) if target_text: target_text = torch.cat([clip.tokenize(target_text)]).to(device) if initial_latent is not None: initial_latent = np.load(initial_latent) initial_latent = initial_latent[initial_latent.files[0]] # Optimize projection. start_time = perf_counter() projected_w_steps = project( G, target_image=target_image, target_text=target_text, initial_latent=initial_latent, initial_learning_rate=lr, num_steps=num_steps, use_vgg=use_vgg, use_clip=use_clip, use_pixel=use_pixel, use_penalty=use_penalty, use_center=use_center, kmeans=use_kmeans, device=device, verbose=True ) print (f'Elapsed: {(perf_counter()-start_time):.1f} s') os.makedirs(outdir, exist_ok=True) # Save final projected frame and W vector. if target_fname: target_pil.save(f'{outdir}/target.png') projected_w = projected_w_steps[-1] synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const') synth_image = (synth_image + 1) * (255/2) synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy() PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png') np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy()) # Render debug output: optional video and projected image and W vector. if save_video: video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M') print (f'Saving optimization progress video "{outdir}/proj.mp4"') for projected_w in projected_w_steps: synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const') synth_image = (synth_image + 1) * (255/2) synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy() if target_fname: video.append_data(np.concatenate([target_uint8, synth_image], axis=1)) else: video.append_data(synth_image) video.close() #---------------------------------------------------------------------------- if __name__ == "__main__": run_projection() # pylint: disable=no-value-for-parameter #----------------------------------------------------------------------------