Docker requirements for additional postprocessing of the StyleGANs

parent 35c61c0c5b
commit e2db2688e0
4 changed files with 2777 additions and 402 deletions

Dockerfile (+18)
@@ -12,6 +12,24 @@ ENV PYTHONDONTWRITEBYTECODE 1
 ENV PYTHONUNBUFFERED 1
 
 RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
 RUN pip install ipywidgets
+
+# When X11-forwarding matplotlib:
+#RUN pip install cairocffi
+
+RUN apt-get update -y
+ENV TZ=Europe/Amsterdam
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+#RUN apt-get install -y libcairo2 python3-gi python3-gi-cairo gir1.2-gtk-3.0
+RUN apt-get install -y libgirepository1.0-dev gcc libcairo2-dev pkg-config python3-dev gir1.2-gtk-3.0
+RUN pip install pycairo
+RUN pip install PyGObject
+RUN apt-get install -y mesa-utils
+
+# ffmpeg for cv2 video creation
+RUN apt-get install -y ffmpeg
+
+
 WORKDIR /workspace
 
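Not part of the commit: a minimal smoke test I would run inside the resulting image to confirm the new cairo/GTK/ffmpeg stack imports cleanly (the file name and image tag below are hypothetical):

# smoke_test.py -- verify the dependencies added above
import cairo                        # installed via `pip install pycairo`
import gi                           # installed via `pip install PyGObject`
gi.require_version('Gtk', '3.0')    # GTK 3 typelib from gir1.2-gtk-3.0
from gi.repository import Gtk
import imageio_ffmpeg               # pairs with the apt-installed ffmpeg

print('pycairo', cairo.version)
print('GTK', Gtk.get_major_version(), Gtk.get_minor_version())
print('ffmpeg', imageio_ffmpeg.get_ffmpeg_version())

Run it with something like `docker build -t stylegan3-post . && docker run --rm stylegan3-post python smoke_test.py`.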
Stylegan3.ipynb (+3102)
File diff suppressed because one or more lines are too long

dataset_tool.py
@@ -26,6 +26,9 @@ import numpy as np
 import PIL.Image
 from tqdm import tqdm
 
+
+PIL.Image.init() # required to initialise PIL.Image.EXTENSION
+
 #----------------------------------------------------------------------------
 
 def error(msg):
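Background on the added call, not from the diff itself: PIL registers its format plugins lazily, so code that reads the extension-to-format table up front needs an explicit init():

import PIL.Image

PIL.Image.init()                    # registers all format plugins
print(PIL.Image.EXTENSION['.png'])  # -> 'PNG'
print(PIL.Image.EXTENSION['.jpg'])  # -> 'JPEG'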
@@ -216,8 +219,11 @@ def open_mnist(images_gz: str, *, max_images: Optional[int]):
 def make_transform(
     transform: Optional[str],
     output_width: Optional[int],
-    output_height: Optional[int]
+    output_height: Optional[int],
+    crop_width: Optional[int],
+    crop_height: Optional[int]
 ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
 
     def scale(width, height, img):
         w = img.shape[1]
         h = img.shape[0]
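To make the widened signature concrete (example values mine, not from the commit):

# Old behaviour, no separate crop: the two new arguments are simply None,
# so the crop defaults to the output size further down.
tf = make_transform('center-crop', 512, 512, None, None)

# New behaviour: take a 600x600 centre crop first, then scale to 512x512.
tf = make_transform('center-crop', 512, 512, 600, 600)
out = tf(img)   # img: HxWxC uint8 ndarray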
@@ -249,17 +255,38 @@ def make_transform(
         canvas = np.zeros([width, width, 3], dtype=np.uint8)
         canvas[(width - height) // 2 : (width + height) // 2, :] = img
         return canvas
 
+    def scale_center_crop(width, height, crop_w, crop_h, img):
+        return scale(width, height, img[(img.shape[0] - crop_w) // 2 : (img.shape[0] + crop_w) // 2, (img.shape[1] - crop_h) // 2 : (img.shape[1] + crop_h) // 2])
+
+    def scale_center_crop_wide(width, height, crop_w, crop_h, img):
+        error('not implemented')
+
-    if transform is None:
-        return functools.partial(scale, output_width, output_height)
-    if transform == 'center-crop':
-        if (output_width is None) or (output_height is None):
-            error('must specify --resolution=WxH when using ' + transform + ' transform')
-        return functools.partial(center_crop, output_width, output_height)
-    if transform == 'center-crop-wide':
-        if (output_width is None) or (output_height is None):
-            error('must specify --resolution=WxH when using ' + transform + ' transform')
-        return functools.partial(center_crop_wide, output_width, output_height)
+    crop_width = output_width if crop_width is None else crop_width
+    crop_height = output_height if crop_height is None else crop_height
+
+    if crop_width != output_width or crop_height != output_height:
+        if transform is None:
+            error('must specify transform method (center-crop or center-crop-wide)')
+        if transform == 'center-crop':
+            if (output_width is None) or (output_height is None):
+                error('must specify --resolution=WxH when using ' + transform + ' transform')
+            return functools.partial(scale_center_crop, output_height, output_width, crop_width, crop_height)
+        if transform == 'center-crop-wide':
+            if (output_width is None) or (output_height is None):
+                error('must specify --resolution=WxH when using ' + transform + ' transform')
+            return functools.partial(scale_center_crop_wide, output_height, output_width, crop_width, crop_height)
+    else:
+        if transform is None:
+            return functools.partial(scale, output_width, output_height)
+        if transform == 'center-crop':
+            if (output_width is None) or (output_height is None):
+                error('must specify --resolution=WxH when using ' + transform + ' transform')
+            return functools.partial(center_crop, output_width, output_height)
+        if transform == 'center-crop-wide':
+            if (output_width is None) or (output_height is None):
+                error('must specify --resolution=WxH when using ' + transform + ' transform')
+            return functools.partial(center_crop_wide, output_width, output_height)
     assert False, 'unknown transform'
 
 #----------------------------------------------------------------------------
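A worked example of the crop arithmetic above (mine, not in the commit). Note that as committed, crop_w slices axis 0 (height) and crop_h axis 1 (width), and scale_center_crop is bound with output_height before output_width, so for non-square sizes the two dimensions are effectively swapped:

import numpy as np

img = np.zeros((768, 1024, 3), dtype=np.uint8)   # H x W x C
crop_w = crop_h = 600

# The centre window scale_center_crop extracts before calling scale():
crop = img[(img.shape[0] - crop_w) // 2 : (img.shape[0] + crop_w) // 2,
           (img.shape[1] - crop_h) // 2 : (img.shape[1] + crop_h) // 2]
print(crop.shape)   # (600, 600, 3); scale() then resizes this to the output size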
@@ -323,13 +350,15 @@ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None],
 @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
 @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
 @click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
+@click.option('--crop-resolution', help='Resolution of the crop (can be larger or smaller than the final resolution) (e.g., \'600x600\')', metavar='WxH', type=parse_tuple)
 def convert_dataset(
     ctx: click.Context,
     source: str,
     dest: str,
     max_images: Optional[int],
     transform: Optional[str],
-    resolution: Optional[Tuple[int, int]]
+    resolution: Optional[Tuple[int, int]],
+    crop_resolution: Optional[Tuple[int, int]],
 ):
     """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
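The new option reuses the existing parse_tuple converter; a sketch of its behaviour (consistent with its use here, not a verbatim copy of the repo's helper):

import re
from typing import Tuple

def parse_tuple(s: str) -> Tuple[int, int]:
    # 'MxN' or 'M,N' -> (M, N), e.g. '600x600' -> (600, 600)
    m = re.match(r'^(\d+)[x,](\d+)$', s)
    if m:
        return (int(m.group(1)), int(m.group(2)))
    raise ValueError(f'cannot parse tuple {s}')

assert parse_tuple('600x600') == (600, 600)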
@@ -387,7 +416,7 @@ def convert_dataset(
 
     \b
     python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
-        --transform=center-crop-wide --resolution=512x384
+        --transform=center-crop-wide --resolution=512x384 --crop-resolution=600x600
     """
 
     PIL.Image.init() # type: ignore
@@ -399,7 +428,8 @@ def convert_dataset(
     archive_root_dir, save_bytes, close_dest = open_dest(dest)
 
     if resolution is None: resolution = (None, None)
-    transform_image = make_transform(transform, *resolution)
+    if crop_resolution is None: crop_resolution = (None, None)
+    transform_image = make_transform(transform, *resolution, *crop_resolution)
 
     dataset_attrs = None
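Spelled out with the example values from the docstring above, the starred unpacking lines up with the new parameter order:

resolution = (512, 384)        # from --resolution=512x384
crop_resolution = (600, 600)   # from --crop-resolution=600x600

# make_transform(transform, *resolution, *crop_resolution) expands to
# make_transform(transform, 512, 384, 600, 600), i.e.
# (transform, output_width, output_height, crop_width, crop_height).
transform_image = make_transform(transform, *resolution, *crop_resolution)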

training/dataset.py
@@ -49,6 +49,7 @@ class Dataset(torch.utils.data.Dataset):
         if xflip:
             self._raw_idx = np.tile(self._raw_idx, 2)
             self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
+        # TODO: perform a similar trick, but then with random crops etc.
 
     def _get_raw_labels(self):
         if self._raw_labels is None:
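Pure speculation on how that TODO might be realised (none of these names exist in the repo): tile the index array once per augmentation variant, exactly as the xflip branch does, with a parallel flag array for __getitem__ to consult:

# Hypothetical sketch mirroring the xflip trick for random crops.
if random_crop:                                   # hypothetical constructor flag
    n = self._raw_idx.size
    self._raw_idx = np.tile(self._raw_idx, 2)     # every image appears twice
    self._xflip = np.tile(self._xflip, 2)         # keep per-index arrays aligned
    self._crop = np.concatenate([np.zeros(n, dtype=np.uint8),   # first copy: uncropped
                                 np.ones(n, dtype=np.uint8)])   # second copy: cropped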