Split paths, create grid, use inkscape layers

Ruben van de Ven 2019-08-26 13:45:04 +02:00
parent fa87808bb5
commit 40e6753ded
3 changed files with 210 additions and 89 deletions


@@ -1,4 +1,36 @@
Turn numbered SVGs into usable arrays:
```
python create_dataset.py --dataset_dir datasets/naam6/
```
Train the algorithm (save often, as we'll use the intermediate checkpoints):
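
To sanity-check the conversion, the resulting arrays can be inspected directly. This is a minimal sketch, assuming `create_dataset.py` writes sketch-rnn style `.npz` files (such as `diede.npz`) with `train`/`valid`/`test` keys holding stroke-3 sequences; the exact keys and layout are an assumption here:
```
import numpy as np

# Assumption: sketch-rnn style .npz with train/valid/test splits.
data = np.load('datasets/naam6/diede.npz', encoding='latin1', allow_pickle=True)
for split in ('train', 'valid', 'test'):
    print(split, len(data[split]), 'sketches')

# Each sketch is an (N, 3) array of [delta_x, delta_y, pen_lifted] rows.
print(data['train'][0][:5])
```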
```
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
```
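
Because the card interpolates over these intermediate checkpoints, it is useful to see which steps were actually written. A minimal sketch for listing them; it assumes TensorFlow saver files named `vector-<step>.index` in the model directory, which is the naming that `loadCheckpoint()` in `create_card.py` relies on:
```
import glob
import os
import re

def checkpoint_steps(model_dir='models/naam6'):
    # Collect the step numbers of all saved checkpoints (vector-<step>.index).
    steps = []
    for path in glob.glob(os.path.join(model_dir, 'vector-*.index')):
        match = re.search(r'vector-(\d+)\.index$', path)
        if match:
            steps.append(int(match.group(1)))
    return sorted(steps)

print(checkpoint_steps())  # e.g. [50, 100, 150, ...] with save_every=50
```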
Generate a card:
```
python create_card.py --data_dir datasets/naam6 --model_dir models/naam6 --max_checkpoint_factor .8 --columns 5 --rows 13 --split_paths --last_is_target --last_in_group
```
max_checkpoint_factor
: The model was trained for more iterations than needed for a nice card (roughly half of the card already looks smooth); lowering this factor uses e.g. only the first 80% (.8) of the checkpoints (see the sketch after this list)
split_paths
: Drawings that consist of multiple strokes are split into separate paths, which are distributed over a given number of groups (see nr_of_paths)
last_is_target
: The last item (bottom right) is not generated but hand-picked from the dataset (see target_sample and the sketch at the end of this section)
last_in_group
: Puts the last drawing into a separate group
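
How `max_checkpoint_factor` and `split_paths` behave follows from two small pieces of `create_card.py` further down in this commit; the sketch below restates them in isolation (the function names are only for illustration):
```
import math

def pick_checkpoint(item, item_count, checkpoints, max_checkpoint_factor=.8):
    # Map a card cell index to one of the saved checkpoints; with .8 only the
    # first 80% of the checkpoints is used, so the card stops before the model
    # gets too smooth.
    idx = math.floor(float(item) * max_checkpoint_factor / item_count * len(checkpoints))
    return checkpoints[idx]

def split_over_groups(paths, nr_of_paths):
    # Round-robin the individual stroke paths over nr_of_paths groups,
    # mirroring what --split_paths does with the Inkscape layers.
    groups = [[] for _ in range(nr_of_paths)]
    for i, path in enumerate(paths):
        groups[i % nr_of_paths].append(path)
    return groups
```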
<!--
Successful on naam4:
```
#sketch_rnn_train --log_root=models/naam-simple --data_dir=datasets/naam-simple --hparams="data_set=[diede.npz],dec_model=layer_norm,dec_rnn_size=200,enc_model=layer_norm,enc_rnn_size=200,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=False,num_steps=1000"
sketch_rnn_train --log_root=models/naam4 --data_dir=datasets/naam4 --hparams="data_set=[diede.npz,lijn.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
```
-->
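
For `last_is_target`, the bottom-right cell is the hand-picked drawing itself, while the other cells are decoded from that drawing's latent vector. A minimal sketch of that idea, assuming the `encode()`/`decode()` helpers from the sketch-rnn example notebook; those helpers are not part of this commit, so treat the names as assumptions:
```
# Assumption: encode()/decode() as defined in the sketch-rnn demo notebook.
target_stroke = test_set.random_sample()     # or a hand-picked sample (see target_sample)
target_z = encode(target_stroke)             # latent vector of the target drawing

generated = decode(target_z, temperature=1)  # fills the other cells of the card
bottom_right = target_stroke                 # what --last_is_target places last
```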


@@ -155,13 +155,13 @@
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0826 11:57:16.532441 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
"\n",
"/home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/numba/errors.py:131: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9\n",
" warnings.warn(msg)\n",
"W0826 11:57:17.155105 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n",
"\n",
"W0826 11:57:17.477846 140006013421376 lazy_loader.py:50] \n",
"The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
"For more information, please see:\n",
" * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
@@ -169,9 +169,9 @@
" * https://github.com/tensorflow/io (for I/O related ops)\n",
"If you depend on functionality not listed there, please file an issue.\n",
"\n",
"W0826 11:57:17.478699 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n",
"\n",
"W0826 11:57:17.479223 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n",
"\n"
]
}
@@ -270,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -278,13 +278,13 @@
},
"outputs": [],
"source": [
"data_dir = 'datasets/naam5'\n",
"model_dir = 'models/naam5'"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -301,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -347,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -362,33 +362,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
"I0826 11:57:46.998881 140006013421376 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n",
"I0826 11:57:47.254209 140006013421376 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n",
"I0826 11:57:47.277675 140006013421376 sketch_rnn_train.py:159] Dataset combined: 783 (261/261/261), avg len 313\n",
"I0826 11:57:47.278856 140006013421376 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n",
"I0826 11:57:47.425982 140006013421376 sketch_rnn_train.py:209] normalizing_scale_factor 34.8942.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"total images <= max_seq_len is 261\n",
"total images <= max_seq_len is 261\n",
"total images <= max_seq_len is 261\n"
]
}
],
@@ -398,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -413,18 +400,52 @@
"name": "stderr",
"output_type": "stream",
"text": [
"W0826 11:57:49.141304 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:62: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n",
"\n",
"W0826 11:57:49.142381 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:65: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.\n",
"\n",
"W0826 11:57:49.143638 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:81: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
"\n",
"I0826 11:57:49.144913 140006013421376 model.py:87] Model using gpu.\n",
"I0826 11:57:49.148379 140006013421376 model.py:175] Input dropout mode = False.\n",
"I0826 11:57:49.148915 140006013421376 model.py:176] Output dropout mode = False.\n",
"I0826 11:57:49.149318 140006013421376 model.py:177] Recurrent dropout mode = False.\n",
"W0826 11:57:49.149732 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:190: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
"\n",
"W0826 11:57:49.161299 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:242: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
"\n",
"W0826 11:57:49.161976 140006013421376 deprecation.py:506] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Call initializer instance with the dtype argument instead of passing it to the constructor\n",
"W0826 11:57:49.173728 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:253: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
"W0826 11:57:49.513816 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:256: The name tf.nn.xw_plus_b is deprecated. Please use tf.compat.v1.nn.xw_plus_b instead.\n",
"\n",
"W0826 11:57:49.529942 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:266: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Deprecated in favor of operator or tf.math.divide.\n",
"W0826 11:57:49.544464 140006013421376 deprecation.py:506] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:285: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"keep_dims is deprecated, use keepdims instead\n",
"W0826 11:57:49.554083 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:295: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"\n",
"Future major versions of TensorFlow will allow gradients to flow\n",
"into the labels input on backprop by default.\n",
"\n",
"See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n",
"\n",
"W0826 11:57:49.582807 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:351: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
"\n",
"I0826 11:57:50.112523 140006013421376 model.py:87] Model using gpu.\n",
"I0826 11:57:50.113037 140006013421376 model.py:175] Input dropout mode = 0.\n",
"I0826 11:57:50.113482 140006013421376 model.py:176] Output dropout mode = 0.\n",
"I0826 11:57:50.113979 140006013421376 model.py:177] Recurrent dropout mode = 0.\n",
"I0826 11:57:50.264065 140006013421376 model.py:87] Model using gpu.\n",
"I0826 11:57:50.264573 140006013421376 model.py:175] Input dropout mode = 0.\n",
"I0826 11:57:50.264990 140006013421376 model.py:176] Output dropout mode = 0.\n",
"I0826 11:57:50.265478 140006013421376 model.py:177] Recurrent dropout mode = 0.\n"
]
}
],
@@ -438,7 +459,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -452,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -463,11 +484,20 @@
"outputId": "fb41ce20-4c7f-4991-e9f6-559ea9b34a31"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"W0826 11:57:52.587875 140006013421376 deprecation.py:323] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use standard file APIs to check for files with this prefix.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"29.0\n"
]
},
{
@@ -477,7 +507,7 @@
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-20-787fc7b1ddc6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_checkpoint_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0msaver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"models/naam4/vector-4100\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py\u001b[0m in \u001b[0;36mrestore\u001b[0;34m(self, sess, save_path)\u001b[0m\n\u001b[1;32m 1276\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheckpoint_management\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1277\u001b[0m raise ValueError(\"The passed save_path is not a valid checkpoint: \" +\n\u001b[0;32m-> 1278\u001b[0;31m compat.as_text(save_path))\n\u001b[0m\u001b[1;32m 1279\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1280\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Restoring parameters from %s\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: The passed save_path is not a valid checkpoint: models/naam4/vector-4100"
]
@@ -485,12 +515,12 @@
],
"source": [
"# loads the weights from checkpoint into our model\n",
"load_checkpoint(sess, model_dir)\n",
"# saver = tf.train.Saver(tf.global_variables())\n",
"# ckpt = tf.train.get_checkpoint_state(model_dir)\n",
"# print(int(ckpt.model_checkpoint_path.split(\"-\")[-1])/100)\n",
"# # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\n",
"# saver.restore(sess, \"models/naam4/vector-4100\")"
]
},
{


@@ -17,6 +17,7 @@ from tqdm import tqdm
import re
import glob
import math
from svgwrite.extensions import Inkscape

# import our command line tools
@@ -108,6 +109,23 @@ argParser.add_argument(
    action='store_true',
    help='If set, put the last rendition into a separate group'
)
argParser.add_argument(
    '--create_grid',
    action='store_true',
    help='Create a grid with cutting lines'
)
argParser.add_argument(
    '--grid_width',
    type=int,
    default=3,
    help='Grid items x'
)
argParser.add_argument(
    '--grid_height',
    type=int,
    default=2,
    help='Grid items y'
)
argParser.add_argument(
    '--verbose',
    '-v',
@@ -293,18 +311,27 @@ def loadCheckpoint(model_dir, nr):
    # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))

width = int(re.findall('\d+',args.width)[0])*10
height = int(re.findall('\d+',args.height)[0])*10

grid_height = args.grid_height if args.create_grid else 1
grid_width = args.grid_width if args.create_grid else 1

# Override given dimension with grid info
page_height = width/10*grid_width
page_width = height/10*grid_height
dims = (f"{page_height}mm", f"{page_width}mm")

# padding = 20
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width*grid_width} {height*grid_height}")
inkscapeDwg = Inkscape(dwg)

requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
dwgGroups = [inkscapeDwg.layer(label=f"g{i}") for i in range(requiredGroups)]
for group in dwgGroups:
    dwg.add(group)

checkpoints = getCheckpoints(args.model_dir)

item_count = args.rows*args.columns*grid_width*grid_height

# factor = dataset_baseheight/ (height/args.rows)
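
The switch from plain SVG groups to Inkscape layers (the `inkscapeDwg.layer(...)` calls above) is what makes each group of paths selectable as a layer in Inkscape. Below is a minimal standalone sketch of the same pattern with svgwrite's Inkscape extension; the file name and labels are just for illustration:

```
# Minimal sketch: create an SVG whose groups show up as layers in Inkscape.
import svgwrite
from svgwrite.extensions import Inkscape

dwg = svgwrite.Drawing('layer_demo.svg', size=('90mm', '130mm'))
inkscape = Inkscape(dwg)  # registers the Inkscape XML namespaces

layer = inkscape.layer(label='g0')  # becomes an Inkscape layer, not a plain <g>
dwg.add(layer)

# Paths added to the layer can be toggled or plotted per layer in Inkscape.
layer.add(dwg.path('M10,10 L80,120').stroke('black', 1).fill('none'))
dwg.save()
```
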
@@ -316,36 +343,68 @@ max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))
max_height = (height - args.page_margin*2 -(args.column_padding*(args.rows-1))) / args.rows

with tqdm(total=item_count) as pbar:
    for grid_pos_x in range(grid_width):
        grid_x = grid_pos_x * width
        for grid_pos_y in range(grid_height):
            grid_y = grid_pos_y * height

            for row in range(args.rows):
                #find the top left point for the strokes
                min_y = grid_y + row * (max_height + args.column_padding) + args.page_margin
                for column in range(args.columns):
                    min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
                    item = row*args.columns + column
                    checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
                    checkpoint = checkpoints[checkpoint_idx]
                    loadCheckpoint(args.model_dir, checkpoint)

                    isLast = (row == args.rows-1 and column == args.columns-1)

                    if isLast and args.last_is_target:
                        strokes = target_stroke
                    else:
                        strokes = decode(target_z, temperature=1)
                        # strokes = target_stroke

                    if args.last_in_group and isLast:
                        path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                        dwgGroups[-1].add(path)
                    elif args.split_paths:
                        paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
                        i = 0
                        for path in paths:
                            group = dwgGroups[i % args.nr_of_paths]
                            i+=1
                            group.add(path)
                    else:
                        path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                        dwgGroups[0].add(path)

                    pbar.update()
if args.create_grid:
    logger.info("Create grid")
    grid_length = 50
    grid_group = inkscapeDwg.layer(label='grid')
    with tqdm(total=(grid_width+1)*(grid_height+1)) as pbar:
        for i in range(grid_width + 1):
            for j in range(grid_height + 1):
                sx = i * width
                sy = j * height - grid_length
                sx2 = sx
                sy2 = sy+2*grid_length
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black',1).fill("none")
                dwg.add(path)

                sx = i * width - grid_length
                sy = j * height
                sx2 = sx+ 2*grid_length
                sy2 = sy
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black',1).fill("none")
                grid_group.add(path)
                pbar.update()
    dwg.add(grid_group)

dwg.save()
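
With the default 3x2 grid the cutting marks land on the corners of each card cell, as a pair of short ticks per corner. A small worked example of the same geometry as the loop above; the cell size is illustrative (e.g. a 90mm x 130mm card gives width=900, height=1300 in viewBox units):

```
# Worked example of the crop-mark endpoints, mirroring the grid loop above.
width, height = 900, 1300        # one card cell in viewBox units (assumed values)
grid_width, grid_height = 3, 2   # argparse defaults for --grid_width / --grid_height
grid_length = 50                 # half-length of each cutting tick

for i in range(grid_width + 1):
    for j in range(grid_height + 1):
        corner_x, corner_y = i * width, j * height
        vertical_tick = ((corner_x, corner_y - grid_length), (corner_x, corner_y + grid_length))
        horizontal_tick = ((corner_x - grid_length, corner_y), (corner_x + grid_length, corner_y))
        print(vertical_tick, horizontal_tick)
```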