more modular gen_images and started with notebook

This commit is contained in:
Ruben van de Ven 2022-06-24 15:51:27 +02:00
parent 407db86e6f
commit 35c61c0c5b
3 changed files with 723 additions and 5 deletions

View file

@ -1,3 +1,9 @@
# Changes
* Included Peter Baylies' projector from [Schultz]'s StyleGAN repo
* adapted gen_images.py to allow use from other python code
## Alias-Free Generative Adversarial Networks (StyleGAN3)<br><sub>Official PyTorch implementation of the NeurIPS 2021 paper</sub> ## Alias-Free Generative Adversarial Networks (StyleGAN3)<br><sub>Official PyTorch implementation of the NeurIPS 2021 paper</sub>
![Teaser image](./docs/stylegan3-teaser-1920x1006.png) ![Teaser image](./docs/stylegan3-teaser-1920x1006.png)

688
Stylegan3.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -77,7 +77,7 @@ def make_transform(translate: Tuple[float,float], angle: float):
@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images( def generate_images_cmd(
network_pkl: str, network_pkl: str,
seeds: List[int], seeds: List[int],
truncation_psi: float, truncation_psi: float,
@ -101,13 +101,35 @@ def generate_images(
python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
""" """
for z, img, filename in generate_images(network_pkl,
seeds,
outdir,
truncation_psi,
noise_mode,
translate,
rotate,
class_idx):
img.save(filename)
def generate_images(
network_pkl: str,
seeds: List[int],
outdir: str,
truncation_psi: float = 1.0,
noise_mode: str = 'const',
translate: Tuple[float,float] = (0,0),
rotate: float = 0,
class_idx: Optional[int] = None
):
print('Loading networks from "%s"...' % network_pkl) print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda') device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f: with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
os.makedirs(outdir, exist_ok=True) network_name = os.path.split(os.path.dirname(network_pkl))[1]
network_iteration_nr = re.findall(r'\d+', os.path.basename(network_pkl))[0]
os.makedirs(os.path.join(outdir,network_name), exist_ok=True)
# Labels. # Labels.
label = torch.zeros([1, G.c_dim], device=device) label = torch.zeros([1, G.c_dim], device=device)
@ -115,9 +137,11 @@ def generate_images(
if class_idx is None: if class_idx is None:
raise click.ClickException('Must specify class label with --class when using a conditional network') raise click.ClickException('Must specify class label with --class when using a conditional network')
label[:, class_idx] = 1 label[:, class_idx] = 1
filename_class = f"class{class_idx}"
else: else:
if class_idx is not None: if class_idx is not None:
print ('warn: --class=lbl ignored when running on an unconditional network') print ('warn: --class=lbl ignored when running on an unconditional network')
filename_class = ""
# Generate images. # Generate images.
for seed_idx, seed in enumerate(seeds): for seed_idx, seed in enumerate(seeds):
@ -134,12 +158,12 @@ def generate_images(
img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') yield z, PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'), f'{outdir}/{network_name}/{network_iteration_nr}-{filename_class}-seed{seed:04d}-trunc{truncation_psi}.png'
#---------------------------------------------------------------------------- #----------------------------------------------------------------------------
if __name__ == "__main__": if __name__ == "__main__":
generate_images() # pylint: disable=no-value-for-parameter generate_images_cmd() # pylint: disable=no-value-for-parameter
#---------------------------------------------------------------------------- #----------------------------------------------------------------------------