diff --git a/README.md b/README.md index 1923bcb..1f3ab1b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,36 @@ -Successfull on naam-simple: +Turn numbered svgs into usable arrays: ``` -sketch_rnn_train --log_root=models/naam-simple --data_dir=datasets/naam-simple --hparams="data_set=[diede.npz],dec_model=layer_norm,dec_rnn_size=200,enc_model=layer_norm,enc_rnn_size=200,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=False,num_steps=1000" +python create_dataset.py --dataset_dir datasets/naam6/ ``` + +Train algorithm: (save often, as we'll use the intermediate steps) +``` +sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000" +``` + +Generate a card: + +``` +python create_card.py --data_dir datasets/naam6 --model_dir models/naam6 --max_checkpoint_factor .8 --columns 5 --rows 13 --split_paths --last_is_target --last_in_group +``` + +max_checkpoint_factor +: set was trained for too many iterations in order to generate a nice card (~half of the card looks already smooth), by lowering this factor, we use eg. only the first 80% (.8) iteration + +split_paths +: Drawings that consist of mulitple strokes are split over paths, which are split over a given number of groups (see nr_of_paths) + +last_is_target +: Last item (bottom right) is not generated but hand picked from the dataset (see target_sample) + + +last_in_group +: Puts the last drawing in a separate group + + diff --git a/Sketch_RNN.ipynb b/Sketch_RNN.ipynb index b17cc9a..b789e66 100644 --- a/Sketch_RNN.ipynb +++ b/Sketch_RNN.ipynb @@ -155,13 +155,13 @@ "output_type": "stream", "text": [ "WARNING: Logging before flag parsing goes to stderr.\n", - "W0825 15:58:50.188074 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", + "W0826 11:57:16.532441 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", "\n", "/home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/numba/errors.py:131: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9\n", " warnings.warn(msg)\n", - "W0825 15:58:50.811670 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", + "W0826 11:57:17.155105 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", "\n", - "W0825 15:58:51.114169 140470926227264 lazy_loader.py:50] \n", + "W0826 11:57:17.477846 140006013421376 lazy_loader.py:50] \n", "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", "For more information, please see:\n", " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", @@ -169,9 +169,9 @@ " * https://github.com/tensorflow/io (for I/O related ops)\n", "If you depend on functionality not listed there, please file an issue.\n", "\n", - "W0825 15:58:51.115457 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n", + "W0826 11:57:17.478699 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n", "\n", - "W0825 15:58:51.116125 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n", + "W0826 11:57:17.479223 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n", "\n" ] } @@ -270,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": 14, "metadata": { "colab": {}, "colab_type": "code", @@ -278,13 +278,13 @@ }, "outputs": [], "source": [ - "data_dir = 'datasets/naam4'\n", - "model_dir = 'models/naam4'" + "data_dir = 'datasets/naam5'\n", + "model_dir = 'models/naam5'" ] }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -301,7 +301,7 @@ }, { "cell_type": "code", - "execution_count": 153, + "execution_count": 16, "metadata": { "colab": {}, "colab_type": "code", @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 154, + "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -362,33 +362,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "I0825 17:10:28.492715 140470926227264 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n", - "I0825 17:10:28.666448 140470926227264 sketch_rnn_train.py:142] Loaded 100/100/100 from lijn.npz\n", - "I0825 17:10:29.123052 140470926227264 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n", - "I0825 17:10:29.189439 140470926227264 sketch_rnn_train.py:159] Dataset combined: 1083 (361/361/361), avg len 234\n", - "I0825 17:10:29.190502 140470926227264 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n" + "I0826 11:57:46.998881 140006013421376 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n", + "I0826 11:57:47.254209 140006013421376 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n", + "I0826 11:57:47.277675 140006013421376 sketch_rnn_train.py:159] Dataset combined: 783 (261/261/261), avg len 313\n", + "I0826 11:57:47.278856 140006013421376 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n", + "I0826 11:57:47.425982 140006013421376 sketch_rnn_train.py:209] normalizing_scale_factor 34.8942.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "total images <= max_seq_len is 361\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "I0825 17:10:29.655039 140470926227264 sketch_rnn_train.py:209] normalizing_scale_factor 55.2581.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "total images <= max_seq_len is 361\n", - "total images <= max_seq_len is 361\n" + "total images <= max_seq_len is 261\n", + "total images <= max_seq_len is 261\n", + "total images <= max_seq_len is 261\n" ] } ], @@ -398,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 155, + "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -413,18 +400,52 @@ "name": "stderr", "output_type": "stream", "text": [ - "I0825 17:10:34.129225 140470926227264 model.py:87] Model using gpu.\n", - "I0825 17:10:34.139610 140470926227264 model.py:175] Input dropout mode = False.\n", - "I0825 17:10:34.141004 140470926227264 model.py:176] Output dropout mode = False.\n", - "I0825 17:10:34.143371 140470926227264 model.py:177] Recurrent dropout mode = False.\n", - "I0825 17:10:42.551079 140470926227264 model.py:87] Model using gpu.\n", - "I0825 17:10:42.553035 140470926227264 model.py:175] Input dropout mode = 0.\n", - "I0825 17:10:42.554712 140470926227264 model.py:176] Output dropout mode = 0.\n", - "I0825 17:10:42.556115 140470926227264 model.py:177] Recurrent dropout mode = 0.\n", - "I0825 17:10:43.679191 140470926227264 model.py:87] Model using gpu.\n", - "I0825 17:10:43.681183 140470926227264 model.py:175] Input dropout mode = 0.\n", - "I0825 17:10:43.682615 140470926227264 model.py:176] Output dropout mode = 0.\n", - "I0825 17:10:43.683944 140470926227264 model.py:177] Recurrent dropout mode = 0.\n" + "W0826 11:57:49.141304 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:62: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n", + "\n", + "W0826 11:57:49.142381 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:65: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.\n", + "\n", + "W0826 11:57:49.143638 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:81: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n", + "\n", + "I0826 11:57:49.144913 140006013421376 model.py:87] Model using gpu.\n", + "I0826 11:57:49.148379 140006013421376 model.py:175] Input dropout mode = False.\n", + "I0826 11:57:49.148915 140006013421376 model.py:176] Output dropout mode = False.\n", + "I0826 11:57:49.149318 140006013421376 model.py:177] Recurrent dropout mode = False.\n", + "W0826 11:57:49.149732 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:190: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", + "\n", + "W0826 11:57:49.161299 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:242: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n", + "\n", + "W0826 11:57:49.161976 140006013421376 deprecation.py:506] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Call initializer instance with the dtype argument instead of passing it to the constructor\n", + "W0826 11:57:49.173728 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:253: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n", + "W0826 11:57:49.513816 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:256: The name tf.nn.xw_plus_b is deprecated. Please use tf.compat.v1.nn.xw_plus_b instead.\n", + "\n", + "W0826 11:57:49.529942 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:266: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Deprecated in favor of operator or tf.math.divide.\n", + "W0826 11:57:49.544464 140006013421376 deprecation.py:506] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:285: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "keep_dims is deprecated, use keepdims instead\n", + "W0826 11:57:49.554083 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:295: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "\n", + "Future major versions of TensorFlow will allow gradients to flow\n", + "into the labels input on backprop by default.\n", + "\n", + "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n", + "\n", + "W0826 11:57:49.582807 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:351: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n", + "\n", + "I0826 11:57:50.112523 140006013421376 model.py:87] Model using gpu.\n", + "I0826 11:57:50.113037 140006013421376 model.py:175] Input dropout mode = 0.\n", + "I0826 11:57:50.113482 140006013421376 model.py:176] Output dropout mode = 0.\n", + "I0826 11:57:50.113979 140006013421376 model.py:177] Recurrent dropout mode = 0.\n", + "I0826 11:57:50.264065 140006013421376 model.py:87] Model using gpu.\n", + "I0826 11:57:50.264573 140006013421376 model.py:175] Input dropout mode = 0.\n", + "I0826 11:57:50.264990 140006013421376 model.py:176] Output dropout mode = 0.\n", + "I0826 11:57:50.265478 140006013421376 model.py:177] Recurrent dropout mode = 0.\n" ] } ], @@ -438,7 +459,7 @@ }, { "cell_type": "code", - "execution_count": 156, + "execution_count": 19, "metadata": { "colab": {}, "colab_type": "code", @@ -452,7 +473,7 @@ }, { "cell_type": "code", - "execution_count": 199, + "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -463,11 +484,20 @@ "outputId": "fb41ce20-4c7f-4991-e9f6-559ea9b34a31" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0826 11:57:52.587875 140006013421376 deprecation.py:323] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use standard file APIs to check for files with this prefix.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "48.0\n" + "29.0\n" ] }, { @@ -477,7 +507,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_checkpoint_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0msaver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"models/naam4/vector-4100\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_checkpoint_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0msaver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"models/naam4/vector-4100\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py\u001b[0m in \u001b[0;36mrestore\u001b[0;34m(self, sess, save_path)\u001b[0m\n\u001b[1;32m 1276\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheckpoint_management\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1277\u001b[0m raise ValueError(\"The passed save_path is not a valid checkpoint: \" +\n\u001b[0;32m-> 1278\u001b[0;31m compat.as_text(save_path))\n\u001b[0m\u001b[1;32m 1279\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1280\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Restoring parameters from %s\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: The passed save_path is not a valid checkpoint: models/naam4/vector-4100" ] @@ -485,12 +515,12 @@ ], "source": [ "# loads the weights from checkpoint into our model\n", - "# load_checkpoint(sess, model_dir)\n", - "saver = tf.train.Saver(tf.global_variables())\n", - "ckpt = tf.train.get_checkpoint_state(model_dir)\n", - "print(int(ckpt.model_checkpoint_path.split(\"-\")[-1])/100)\n", - "# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\n", - "saver.restore(sess, \"models/naam4/vector-4100\")" + "load_checkpoint(sess, model_dir)\n", + "# saver = tf.train.Saver(tf.global_variables())\n", + "# ckpt = tf.train.get_checkpoint_state(model_dir)\n", + "# print(int(ckpt.model_checkpoint_path.split(\"-\")[-1])/100)\n", + "# # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\n", + "# saver.restore(sess, \"models/naam4/vector-4100\")" ] }, { diff --git a/create_card.py b/create_card.py index 29e14f6..530e991 100644 --- a/create_card.py +++ b/create_card.py @@ -17,6 +17,7 @@ from tqdm import tqdm import re import glob import math +from svgwrite.extensions import Inkscape # import our command line tools @@ -108,6 +109,23 @@ argParser.add_argument( action='store_true', help='If set, put the last rendition into a separate group' ) +argParser.add_argument( + '--create_grid', + action='store_true', + help='Create a grid with cutting lines' + ) +argParser.add_argument( + '--grid_width', + type=int, + default=3, + help='Grid items x' + ) +argParser.add_argument( + '--grid_height', + type=int, + default=2, + help='Grid items y' + ) argParser.add_argument( '--verbose', '-v', @@ -293,18 +311,27 @@ def loadCheckpoint(model_dir, nr): # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path) saver.restore(sess, os.path.join(model_dir, f"vector-{nr}")) -dims = (args.width, args.height) width = int(re.findall('\d+',args.width)[0])*10 height = int(re.findall('\d+',args.height)[0])*10 +grid_height = args.grid_height if args.create_grid else 1 +grid_width = args.grid_width if args.create_grid else 1 + +# Override given dimension with grid info +page_height = width/10*grid_width +page_width = height/10*grid_height +dims = (f"{page_height}mm", f"{page_width}mm") + # padding = 20 -dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}") +dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width*grid_width} {height*grid_height}") +inkscapeDwg = Inkscape(dwg) requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0) -dwgGroups = [svgwrite.container.Group(id=f"g{i}") for i in range(requiredGroups)] +dwgGroups = [inkscapeDwg.layer(label=f"g{i}") for i in range(requiredGroups)] for group in dwgGroups: dwg.add(group) checkpoints = getCheckpoints(args.model_dir) -item_count = args.rows*args.columns + +item_count = args.rows*args.columns*grid_width*grid_height # factor = dataset_baseheight/ (height/args.rows) @@ -316,36 +343,68 @@ max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1)) max_height = (height - args.page_margin*2 -(args.column_padding*(args.rows-1))) / args.rows with tqdm(total=item_count) as pbar: - for row in range(args.rows): - #find the top left point for the strokes - min_y = row * (max_height + args.column_padding) + args.page_margin - for column in range(args.columns): - min_x = column * (max_width + args.column_padding) + args.page_margin - item = row*args.columns + column - checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints)) - checkpoint = checkpoints[checkpoint_idx] - loadCheckpoint(args.model_dir, checkpoint) + for grid_pos_x in range(grid_width): + grid_x = grid_pos_x * width + for grid_pos_y in range(grid_height): + grid_y = grid_pos_y * height - isLast = (row == args.rows-1 and column == args.columns-1) + for row in range(args.rows): + #find the top left point for the strokes + min_y = grid_y + row * (max_height + args.column_padding) + args.page_margin + for column in range(args.columns): + min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin + item = row*args.columns + column + checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints)) + checkpoint = checkpoints[checkpoint_idx] + loadCheckpoint(args.model_dir, checkpoint) - if isLast and args.last_is_target: - strokes = target_stroke - else: - strokes = decode(target_z, temperature=1) + isLast = (row == args.rows-1 and column == args.columns-1) - if args.last_in_group and isLast: - path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) - dwgGroups[-1].add(path) - elif args.split_paths: - paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height) - i = 0 - for path in paths: - group = dwgGroups[i % args.nr_of_paths] - i+=1 - group.add(path) - else: - path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) - dwgGroups[0].add(path) + if isLast and args.last_is_target: + strokes = target_stroke + else: + strokes = decode(target_z, temperature=1) + # strokes = target_stroke - pbar.update() + if args.last_in_group and isLast: + path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) + dwgGroups[-1].add(path) + elif args.split_paths: + paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height) + i = 0 + for path in paths: + group = dwgGroups[i % args.nr_of_paths] + i+=1 + group.add(path) + else: + path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) + dwgGroups[0].add(path) + + pbar.update() + +if args.create_grid: + logger.info("Create grid") + + grid_length = 50 + grid_group = inkscapeDwg.layer(label='grid') + with tqdm(total=(grid_width+1)*(grid_height+1)) as pbar: + for i in range(grid_width + 1): + for j in range(grid_height + 1): + sx = i * width + sy = j * height - grid_length + sx2 = sx + sy2 = sy+2*grid_length + p = f"M{sx},{sy} L{sx2},{sy2}" + path = dwg.path(p).stroke('black',1).fill("none") + dwg.add(path) + sx = i * width - grid_length + sy = j * height + sx2 = sx+ 2*grid_length + sy2 = sy + p = f"M{sx},{sy} L{sx2},{sy2}" + path = dwg.path(p).stroke('black',1).fill("none") + grid_group.add(path) + + pbar.update() + dwg.add(grid_group) dwg.save()