Docker requirements for additional postprocessing of the StyleGANs

Ruben van de Ven 2022-11-25 19:51:59 +01:00
parent 35c61c0c5b
commit e2db2688e0
4 changed files with 2777 additions and 402 deletions

Dockerfile

@@ -12,6 +12,24 @@ ENV PYTHONDONTWRITEBYTECODE 1
 ENV PYTHONUNBUFFERED 1
 RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
+RUN pip install ipywidgets
+#When X11 forwarding matplotlib
+#RUN pip install cairocffi
+RUN apt-get update -y
+ENV TZ=Europe/Amsterdam
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+#RUN apt-get install -y libcairo2 python3-gi python3-gi-cairo gir1.2-gtk-3.0
+RUN apt-get install -y libgirepository1.0-dev gcc libcairo2-dev pkg-config python3-dev gir1.2-gtk-3.0
+RUN pip install pycairo
+RUN pip install PyGObject
+RUN apt-get install -y mesa-utils
+# ffmpeg for cv2 video creation
+RUN apt-get install -y ffmpeg
 WORKDIR /workspace
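As a quick smoke test of the stack this hunk adds (pycairo, PyGObject/GTK via gir1.2-gtk-3.0, and the ffmpeg-backed imageio), something like the following can be run inside the built container. This is only a sketch and assumes the image builds exactly as above.

import cairo                       # from pycairo
import gi
gi.require_version('Gtk', '3.0')   # provided by gir1.2-gtk-3.0
from gi.repository import Gtk
import imageio_ffmpeg

print('cairo', cairo.version)
print('Gtk', Gtk.get_major_version(), Gtk.get_minor_version())
print('ffmpeg', imageio_ffmpeg.get_ffmpeg_version())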

File diff suppressed because one or more lines are too long

dataset_tool.py

@@ -26,6 +26,9 @@ import numpy as np
 import PIL.Image
 from tqdm import tqdm
+PIL.Image.init() # required to initialise PIL.Image.EXTENSION
+
 #----------------------------------------------------------------------------
 def error(msg):
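The init() call matters because PIL.Image.EXTENSION is an empty dict until Pillow's format plugins have been registered; a minimal illustration:

import PIL.Image

print(len(PIL.Image.EXTENSION))     # typically 0 in a fresh interpreter
PIL.Image.init()                    # registers the format plugins
print(PIL.Image.EXTENSION['.png'])  # 'PNG'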
@@ -216,8 +219,11 @@ def open_mnist(images_gz: str, *, max_images: Optional[int]):
 def make_transform(
     transform: Optional[str],
     output_width: Optional[int],
-    output_height: Optional[int]
+    output_height: Optional[int],
+    crop_width: Optional[int],
+    crop_height: Optional[int]
 ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
     def scale(width, height, img):
         w = img.shape[1]
         h = img.shape[0]
@@ -250,6 +256,27 @@ def make_transform
         canvas[(width - height) // 2 : (width + height) // 2, :] = img
         return canvas
 
+    def scale_center_crop(width, height, crop_w, crop_h, img):
+        return scale(width, height, img[(img.shape[0] - crop_w) // 2 : (img.shape[0] + crop_w) // 2, (img.shape[1] - crop_h) // 2 : (img.shape[1] + crop_h) // 2])
+
+    def scale_center_crop_wide(width, height, crop_w, crop_h, img):
+        error('not implemented')
+
+    crop_width = output_width if crop_width is None else crop_width
+    crop_height = output_height if crop_height is None else crop_height
+
+    if crop_width != output_width or crop_height != output_height:
+        if transform is None:
+            error('must specify transform method (center-crop or center-crop-wide)')
+        if transform == 'center-crop':
+            if (output_width is None) or (output_height is None):
+                error('must specify --resolution=WxH when using ' + transform + ' transform')
+            return functools.partial(scale_center_crop, output_height, output_width, crop_width, crop_height)
+        if transform == 'center-crop-wide':
+            if (output_width is None) or (output_height is None):
+                error('must specify --resolution=WxH when using ' + transform + ' transform')
+            return functools.partial(scale_center_crop_wide, output_height, output_width, crop_width, crop_height)
+    else:
         if transform is None:
             return functools.partial(scale, output_width, output_height)
         if transform == 'center-crop':
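Outside the diff, the new center-crop path boils down to slicing a central window before resizing. A standalone sketch (the helper name is hypothetical; note that, following the slice order above, crop_w runs along axis 0 and crop_h along axis 1 of the HWC array):

import numpy as np
import PIL.Image

def center_crop_then_scale(img, crop_w, crop_h, out_w, out_h):
    # Take the central crop_w x crop_h window, then resize to out_w x out_h.
    crop = img[(img.shape[0] - crop_w) // 2 : (img.shape[0] + crop_w) // 2,
               (img.shape[1] - crop_h) // 2 : (img.shape[1] + crop_h) // 2]
    return np.array(PIL.Image.fromarray(crop).resize((out_w, out_h), PIL.Image.LANCZOS))

img = np.zeros((1024, 1024, 3), dtype=np.uint8)
print(center_crop_then_scale(img, 600, 600, 512, 512).shape)  # (512, 512, 3)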
@@ -323,13 +350,15 @@ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None],
 @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
 @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
 @click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
+@click.option('--crop-resolution', help='Resolution of the crop (can be larger or smaller than the final resolution) (e.g., \'600x600\')', metavar='WxH', type=parse_tuple)
 def convert_dataset(
     ctx: click.Context,
     source: str,
     dest: str,
     max_images: Optional[int],
     transform: Optional[str],
-    resolution: Optional[Tuple[int, int]]
+    resolution: Optional[Tuple[int, int]],
+    crop_resolution: Optional[Tuple[int, int]],
 ):
     """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
@@ -387,7 +416,7 @@ def convert_dataset(
     \b
     python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
-        --transform=center-crop-wide --resolution=512x384
+        --transform=center-crop-wide --resolution=512x384 --crop-resolution=600x600
     """
 
     PIL.Image.init() # type: ignore
@@ -399,7 +428,8 @@ def convert_dataset(
     archive_root_dir, save_bytes, close_dest = open_dest(dest)
 
     if resolution is None: resolution = (None, None)
-    transform_image = make_transform(transform, *resolution)
+    if crop_resolution is None: crop_resolution = (None, None)
+    transform_image = make_transform(transform, *resolution, *crop_resolution)
     dataset_attrs = None
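Read together, the hunks above thread the new flag through to make_transform, whose signature is now (transform, output_width, output_height, crop_width, crop_height). An illustrative call, assuming it is run from the repo root so dataset_tool is importable:

import numpy as np
from dataset_tool import make_transform

resolution = (512, 512)        # from --resolution=512x512
crop_resolution = (600, 600)   # from --crop-resolution=600x600
transform_image = make_transform('center-crop', *resolution, *crop_resolution)

img_in = np.zeros((1024, 1024, 3), dtype=np.uint8)  # dummy HWC uint8 image
print(transform_image(img_in).shape)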

training/dataset.py

@@ -49,6 +49,7 @@ class Dataset(torch.utils.data.Dataset):
         if xflip:
             self._raw_idx = np.tile(self._raw_idx, 2)
             self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
+        # TODO: perform a similar trick, but with random crops etc.
 
     def _get_raw_labels(self):
         if self._raw_labels is None:
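For context on the TODO: the xflip trick doubles the index array and flags the second half, so every image appears once unflipped and once flipped. A toy sketch of that mechanism (a random-crop variant would presumably tile again and store a crop offset per copy instead of a flip flag):

import numpy as np

raw_idx = np.arange(4, dtype=np.int64)   # four images in the dataset
xflip = np.zeros(4, dtype=np.uint8)

raw_idx = np.tile(raw_idx, 2)                          # [0 1 2 3 0 1 2 3]
xflip = np.concatenate([xflip, np.ones_like(xflip)])   # [0 0 0 0 1 1 1 1]

print(raw_idx, xflip)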