diff --git a/pbaylies_projector.py b/pbaylies_projector.py
new file mode 100644
index 0000000..52da429
--- /dev/null
+++ b/pbaylies_projector.py
@@ -0,0 +1,400 @@
+# Modified StyleGAN2 Projector with CLIP, addl. losses, kmeans, etc.
+# by Peter Baylies, 2021 -- @pbaylies on Twitter
+
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Project given image to the latent space of pretrained network pickle."""
+
+import copy
+import os
+from time import perf_counter
+
+import click
+import imageio
+import numpy as np
+import PIL.Image
+from PIL import ImageFilter
+import torch
+import torch.nn.functional as F
+
+import dnnlib
+import legacy
+
+image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
+image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
+
+def score_images(G, model, text, latents, device, label_class = 0, batch_size = 8):
+    scores = []
+    all_images = []
+    for i in range(latents.shape[0]//batch_size):
+        images = G.synthesis(torch.tensor(latents[i*batch_size:(i+1)*batch_size,:,:], dtype=torch.float32, device=device), noise_mode='const')
+        with torch.no_grad():
+            image_input = (torch.clamp(images, -1, 1) + 1) * 0.5
+            image_input = F.interpolate(image_input, size=(256, 256), mode='area')
+            image_input = image_input[:, :, 16:240, 16:240] # 256 -> 224, center crop
+            image_input -= image_mean[None, :, None, None]
+            image_input /= image_std[None, :, None, None]
+            score = model(image_input, text)[0]
+            scores.append(score.cpu().numpy())
+            all_images.append(images.cpu().numpy())
+
+    scores = np.array(scores)
+    scores = scores.reshape(-1, *scores.shape[2:]).squeeze()
+    scores = 1 - scores / np.linalg.norm(scores)
+    all_images = np.array(all_images)
+    all_images = all_images.reshape(-1, *all_images.shape[2:])
+    return scores, all_images
+
+def project(
+    G,
+    target_image: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+    target_text,
+    *,
+    num_steps = 300,
+    w_avg_samples = 8192,
+    initial_learning_rate = 0.02,
+    initial_latent = None,
+    initial_noise_factor = 0.01,
+    lr_rampdown_length = 0.10,
+    lr_rampup_length = 0.5,
+    noise_ramp_length = 0.75,
+    latent_range = 2.0,
+    max_noise = 0.5,
+    min_threshold = 0.6,
+    use_vgg = True,
+    use_clip = True,
+    use_pixel = True,
+    use_penalty = True,
+    use_center = True,
+    regularize_noise_weight = 1e5,
+    kmeans = True,
+    kmeans_clusters = 64,
+    verbose = False,
+    device: torch.device
+):
+    if target_image is not None:
+        assert target_image.shape == (G.img_channels, G.img_resolution, G.img_resolution)
+    else:
+        use_vgg = False
+        use_pixel = False
+
+    # Only import CLIP if it will actually be used, to avoid import errors otherwise.
+    if use_clip:
+        import clip
+
+    def logprint(*args):
+        if verbose:
+            print(*args)
+
+    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
+
+    # Compute w stats.
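+    # The W midpoint (w_avg) and spread (w_std) are estimated by mapping a large
+    # batch of random z samples through the mapping network; w_std later scales the
+    # noise that is added to the latent during the early optimization steps.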
+    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
+    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+    labels = None
+    if (G.mapping.c_dim):
+        labels = torch.from_numpy(0.5*np.random.RandomState(123).randn(w_avg_samples, G.mapping.c_dim)).to(device)
+    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), labels)  # [N, L, C]
+    w_samples = w_samples.cpu().numpy().astype(np.float32)                 # [N, L, C]
+    w_samples_1d = w_samples[:, :1, :].astype(np.float32)
+
+    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, L, C]
+    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+    kmeans_latents = None
+    if initial_latent is not None:
+        w_avg = initial_latent
+    else:
+        if kmeans and use_clip and target_text is not None:
+            from kmeans_pytorch import kmeans
+            # data
+            data_size, dims, num_clusters = w_avg_samples, G.z_dim, kmeans_clusters
+            x = w_samples_1d
+            x = torch.from_numpy(x)
+
+            # kmeans
+            logprint(f'Performing kmeans clustering using {w_avg_samples} latents into {kmeans_clusters} clusters...')
+            cluster_ids_x, cluster_centers = kmeans(
+                X=x, num_clusters=num_clusters, distance='euclidean', device=device
+            )
+            #logprint(f'\nGenerating images from kmeans latents...')
+            kmeans_latents = torch.tensor(cluster_centers, dtype=torch.float32, device=device, requires_grad=True)
+
+    # Setup noise inputs.
+    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
+
+    # Load VGG16 feature detector.
+    if use_vgg:
+        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+        with dnnlib.util.open_url(url) as f:
+            vgg16 = torch.jit.load(f).eval().to(device)
+
+    # Load CLIP
+    if use_clip:
+        model, transform = clip.load("ViT-B/32", device=device)
+
+    # Features for target image.
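+    # The target is resized to 256x256 and center-cropped to 224x224 to match the
+    # input resolution of CLIP ViT-B/32 and the LPIPS-style VGG16 features.
+    # With use_center, a second crop (a 2x zoom on the middle of the image, via a
+    # 448x448 resize) is also scored, so the losses weight the image center more.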
+    if target_image is not None:
+        target_images = target_image.unsqueeze(0).to(device).to(torch.float32)
+        small_target = F.interpolate(target_images, size=(64, 64), mode='area')
+        if use_center:
+            center_target = F.interpolate(target_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336]
+        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
+        target_images = target_images[:, :, 16:240, 16:240] # 256 -> 224, center crop
+
+        if use_vgg:
+            vgg_target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+            if use_center:
+                vgg_target_center = vgg16(center_target, resize_images=False, return_lpips=True)
+
+    if use_clip:
+        if target_image is not None:
+            with torch.no_grad():
+                clip_target_features = model.encode_image(((target_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float()
+                if use_center:
+                    clip_target_center = model.encode_image(((center_target / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float()
+
+    if kmeans_latents is not None and use_clip and target_text is not None:
+        scores, kmeans_images = score_images(G, model, target_text, kmeans_latents.repeat([1, G.mapping.num_ws, 1]), device=device)
+        ind = np.argpartition(scores, 4)[:4]
+        w_avg = torch.median(kmeans_latents[ind],dim=0,keepdim=True)[0].repeat([1, G.mapping.num_ws, 1])
+
+    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
+    w_avg_tensor = w_opt.clone()
+    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
+    optimizer = torch.optim.AdamW([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
+
+    # Init noise.
+    for buf in noise_bufs.values():
+        buf[:] = torch.randn_like(buf)
+        buf.requires_grad = True
+
+    for step in range(num_steps):
+        # Learning rate schedule.
+        t = step / num_steps
+        w_noise_scale = max_noise * w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+        lr = initial_learning_rate * lr_ramp
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+        # Synth images from opt_w.
+        w_noise = torch.randn_like(w_opt) * w_noise_scale
+        ws = w_opt + w_noise
+        synth_images = G.synthesis(torch.clamp(ws,-latent_range,latent_range), noise_mode='const')
+
+        # Downsample image to 256x256 if it's larger than that. CLIP was built for 224x224 images.
+        synth_images = (torch.clamp(synth_images, -1, 1) + 1) * (255/2)
+        small_synth = F.interpolate(synth_images, size=(64, 64), mode='area')
+        if use_center:
+            center_synth = F.interpolate(synth_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336]
+        synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
+
+        # Features for synth images.
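+        # The synthesized image gets the same 256 -> 224 center crop as the target.
+        # Each loss term below (VGG/LPIPS, CLIP image and text similarity, pixel
+        # L1/L2, latent L1 penalty) is hinged against min_threshold via F.relu,
+        # so a term stops contributing to the loss once it drops below the threshold.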
+        synth_images = synth_images[:, :, 16:240, 16:240] # 256 -> 224, center crop
+
+        dist = 0
+
+        if use_vgg:
+            vgg_synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+            vgg_dist = (vgg_target_features - vgg_synth_features).square().sum()
+            if use_center:
+                vgg_synth_center = vgg16(center_synth, resize_images=False, return_lpips=True)
+                vgg_dist += (vgg_target_center - vgg_synth_center).square().sum()
+            vgg_dist *= 6
+            dist += F.relu(vgg_dist*vgg_dist - min_threshold)
+
+        if use_clip:
+            clip_synth_image = ((synth_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]
+            clip_synth_features = model.encode_image(clip_synth_image).float()
+            adj_center = 2.0
+
+            if use_center:
+                clip_synth_center_image = ((center_synth / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]
+                adj_center = 1.0
+                clip_synth_center = model.encode_image(clip_synth_center_image).float()
+
+            if target_image is not None:
+                clip_dist = (clip_target_features - clip_synth_features).square().sum()
+                if use_center:
+                    clip_dist += (clip_target_center - clip_synth_center).square().sum()
+                dist += F.relu(0.5 + adj_center*clip_dist - min_threshold)
+
+            if target_text is not None:
+                clip_text = 1 - model(clip_synth_image, target_text)[0].sum() / 100
+                if use_center:
+                    clip_text += 1 - model(clip_synth_center_image, target_text)[0].sum() / 100
+                dist += 2*F.relu(adj_center*clip_text*clip_text - min_threshold / adj_center)
+
+        if use_pixel:
+            pixel_dist = (target_images - synth_images).abs().sum() / 2000000.0
+            if use_center:
+                pixel_dist += (center_target - center_synth).abs().sum() / 2000000.0
+            pixel_dist += (small_target - small_synth).square().sum() / 800000.0
+            pixel_dist /= 4
+            dist += F.relu(lr_ramp * pixel_dist - min_threshold)
+
+        if use_penalty:
+            l1_penalty = (w_opt - w_avg_tensor).abs().sum() / 5000.0
+            dist += F.relu(lr_ramp * l1_penalty - min_threshold)
+
+        # Noise regularization.
+        reg_loss = 0.0
+        for v in noise_bufs.values():
+            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
+            while True:
+                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
+                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
+                if noise.shape[2] <= 8:
+                    break
+                noise = F.avg_pool2d(noise, kernel_size=2)
+        #print(vgg_dist, clip_dist, pixel_dist, l1_penalty, reg_loss * regularize_noise_weight)
+        loss = dist + reg_loss * regularize_noise_weight
+
+        # Step
+        optimizer.zero_grad(set_to_none=True)
+        loss.backward()
+        optimizer.step()
+        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
+        with torch.no_grad():
+            torch.clamp(w_opt,-latent_range,latent_range,out=w_opt)
+
+        # Save projected W for each optimization step.
+        w_out[step] = w_opt.detach()[0]
+
+        # Normalize noise.
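+        # Noise buffers are rescaled to zero mean and unit variance after every step,
+        # so the noise regularizer penalizes spatial correlations in the noise maps
+        # rather than simply driving them toward zero.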
+        with torch.no_grad():
+            for buf in noise_bufs.values():
+                buf -= buf.mean()
+                buf *= buf.square().mean().rsqrt()
+
+    return w_out
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--target-image', 'target_fname', help='Target image file to project to', required=False, metavar='FILE', default=None)
+@click.option('--target-text', help='Target text to project to', required=False, default=None)
+@click.option('--initial-latent', help='Initial latent', default=None)
+@click.option('--lr', help='Learning rate', type=float, default=0.1, show_default=True)
+@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
+@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
+@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
+@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
+@click.option('--use-vgg', help='Use VGG16 in the loss', type=bool, default=True, show_default=True)
+@click.option('--use-clip', help='Use CLIP in the loss', type=bool, default=True, show_default=True)
+@click.option('--use-pixel', help='Use L1/L2 distance on pixels in the loss', type=bool, default=True, show_default=True)
+@click.option('--use-penalty', help='Use a penalty on latent values distance from the mean in the loss', type=bool, default=True, show_default=True)
+@click.option('--use-center', help='Optimize against an additional center image crop', type=bool, default=True, show_default=True)
+@click.option('--use-kmeans', help='Perform kmeans clustering for selecting initial latents', type=bool, default=True, show_default=True)
+def run_projection(
+    network_pkl: str,
+    target_fname: str,
+    target_text: str,
+    initial_latent: str,
+    outdir: str,
+    save_video: bool,
+    seed: int,
+    lr: float,
+    num_steps: int,
+    use_vgg: bool,
+    use_clip: bool,
+    use_pixel: bool,
+    use_penalty: bool,
+    use_center: bool,
+    use_kmeans: bool,
+):
+    """Project given image to the latent space of pretrained network pickle.
+
+    Examples:
+
+    \b
+    python pbaylies_projector.py --outdir=out --target-image=~/mytargetimg.png \\
+        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
+    """
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    # Load networks.
+    print('Loading networks from "%s"...' % network_pkl)
+    device = torch.device('cuda')
+    with dnnlib.util.open_url(network_pkl) as fp:
+        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
+
+    # Load target image.
+    target_image = None
+    if target_fname:
+        target_pil = PIL.Image.open(target_fname).convert('RGB').filter(ImageFilter.SHARPEN)
+        w, h = target_pil.size
+        s = min(w, h)
+        target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+        target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
+        target_uint8 = np.array(target_pil, dtype=np.uint8)
+        target_image = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device)
+
+    if target_text:
+        import clip
+        target_text = torch.cat([clip.tokenize(target_text)]).to(device)
+
+    if initial_latent is not None:
+        initial_latent = np.load(initial_latent)
+        initial_latent = initial_latent[initial_latent.files[0]]
+
+    # Optimize projection.
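+    # project() returns the W vector from every optimization step; the last entry is
+    # saved as the result, and the full trajectory optionally drives the progress video.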
+    start_time = perf_counter()
+    projected_w_steps = project(
+        G,
+        target_image=target_image,
+        target_text=target_text,
+        initial_latent=initial_latent,
+        initial_learning_rate=lr,
+        num_steps=num_steps,
+        use_vgg=use_vgg,
+        use_clip=use_clip,
+        use_pixel=use_pixel,
+        use_penalty=use_penalty,
+        use_center=use_center,
+        kmeans=use_kmeans,
+        device=device,
+        verbose=True
+    )
+    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')
+
+    os.makedirs(outdir, exist_ok=True)
+
+    # Save final projected frame and W vector.
+    if target_fname:
+        target_pil.save(f'{outdir}/target.png')
+    projected_w = projected_w_steps[-1]
+    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
+    synth_image = (synth_image + 1) * (255/2)
+    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
+    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
+
+    # Render debug output: optional video of optimization progress.
+    if save_video:
+        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
+        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
+        for projected_w in projected_w_steps:
+            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
+            synth_image = (synth_image + 1) * (255/2)
+            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+            if target_fname:
+                video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
+            else:
+                video.append_data(synth_image)
+        video.close()
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    run_projection() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
\ No newline at end of file