diff --git a/README.md b/README.md index 1f3ab1b..cf1164c 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,8 @@ python create_dataset.py --dataset_dir datasets/naam6/ Train algorithm: (save often, as we'll use the intermediate steps) ``` -sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000" +#sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000" +sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000 ``` Generate a card: @@ -34,3 +35,8 @@ Successfull on naam4: sketch_rnn_train --log_root=models/naam4 --data_dir=datasets/naam4 --hparams="data_set=[diede.npz,lijn.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000" ``` --> + + +naam4: +strokes & blokletters: 101-360 +strokes only: 101-259 diff --git a/Sketch_RNN.ipynb b/Sketch_RNN.ipynb index b789e66..b91f044 100644 --- a/Sketch_RNN.ipynb +++ b/Sketch_RNN.ipynb @@ -155,13 +155,13 @@ "output_type": "stream", "text": [ "WARNING: Logging before flag parsing goes to stderr.\n", - "W0826 11:57:16.532441 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", + "W0827 17:43:31.091431 140714736174912 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", "\n", "/home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/numba/errors.py:131: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9\n", " warnings.warn(msg)\n", - "W0826 11:57:17.155105 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", + "W0827 17:43:31.889853 140714736174912 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", "\n", - "W0826 11:57:17.477846 140006013421376 lazy_loader.py:50] \n", + "W0827 17:43:32.257223 140714736174912 lazy_loader.py:50] \n", "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", "For more information, please see:\n", " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", @@ -169,9 +169,9 @@ " * https://github.com/tensorflow/io (for I/O related ops)\n", "If you depend on functionality not listed there, please file an issue.\n", "\n", - "W0826 11:57:17.478699 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n", + "W0827 17:43:32.258507 140714736174912 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n", "\n", - "W0826 11:57:17.479223 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n", + "W0827 17:43:32.259193 140714736174912 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n", "\n" ] } @@ -270,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 20, "metadata": { "colab": {}, "colab_type": "code", @@ -278,13 +278,13 @@ }, "outputs": [], "source": [ - "data_dir = 'datasets/naam5'\n", - "model_dir = 'models/naam5'" + "data_dir = 'datasets/naam4'\n", + "model_dir = 'models/naam4'" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -301,7 +301,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "metadata": { "colab": {}, "colab_type": "code", @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -362,20 +362,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "I0826 11:57:46.998881 140006013421376 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n", - "I0826 11:57:47.254209 140006013421376 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n", - "I0826 11:57:47.277675 140006013421376 sketch_rnn_train.py:159] Dataset combined: 783 (261/261/261), avg len 313\n", - "I0826 11:57:47.278856 140006013421376 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n", - "I0826 11:57:47.425982 140006013421376 sketch_rnn_train.py:209] normalizing_scale_factor 34.8942.\n" + "I0827 17:54:57.952397 140714736174912 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n", + "I0827 17:54:58.126416 140714736174912 sketch_rnn_train.py:142] Loaded 100/100/100 from lijn.npz\n", + "I0827 17:54:58.167372 140714736174912 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n", + "I0827 17:54:58.205663 140714736174912 sketch_rnn_train.py:159] Dataset combined: 1083 (361/361/361), avg len 234\n", + "I0827 17:54:58.208060 140714736174912 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n", + "I0827 17:54:58.410045 140714736174912 sketch_rnn_train.py:209] normalizing_scale_factor 55.2581.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "total images <= max_seq_len is 261\n", - "total images <= max_seq_len is 261\n", - "total images <= max_seq_len is 261\n" + "total images <= max_seq_len is 361\n", + "total images <= max_seq_len is 361\n", + "total images <= max_seq_len is 361\n" ] } ], @@ -385,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -400,52 +401,24 @@ "name": "stderr", "output_type": "stream", "text": [ - "W0826 11:57:49.141304 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:62: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n", - "\n", - "W0826 11:57:49.142381 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:65: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.\n", - "\n", - "W0826 11:57:49.143638 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:81: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n", - "\n", - "I0826 11:57:49.144913 140006013421376 model.py:87] Model using gpu.\n", - "I0826 11:57:49.148379 140006013421376 model.py:175] Input dropout mode = False.\n", - "I0826 11:57:49.148915 140006013421376 model.py:176] Output dropout mode = False.\n", - "I0826 11:57:49.149318 140006013421376 model.py:177] Recurrent dropout mode = False.\n", - "W0826 11:57:49.149732 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:190: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", - "\n", - "W0826 11:57:49.161299 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:242: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n", - "\n", - "W0826 11:57:49.161976 140006013421376 deprecation.py:506] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", + "I0827 17:54:58.430566 140714736174912 model.py:87] Model using gpu.\n", + "I0827 17:54:58.436558 140714736174912 model.py:175] Input dropout mode = False.\n", + "I0827 17:54:58.437577 140714736174912 model.py:176] Output dropout mode = False.\n", + "I0827 17:54:58.438252 140714736174912 model.py:177] Recurrent dropout mode = False.\n", + "W0827 17:54:58.446333 140714736174912 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:100: bidirectional_dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", - "Call initializer instance with the dtype argument instead of passing it to the constructor\n", - "W0826 11:57:49.173728 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:253: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", + "Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API\n", + "W0827 17:54:58.743351 140714736174912 deprecation.py:323] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/ops/rnn.py:244: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", - "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n", - "W0826 11:57:49.513816 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:256: The name tf.nn.xw_plus_b is deprecated. Please use tf.compat.v1.nn.xw_plus_b instead.\n", - "\n", - "W0826 11:57:49.529942 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:266: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Deprecated in favor of operator or tf.math.divide.\n", - "W0826 11:57:49.544464 140006013421376 deprecation.py:506] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:285: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "keep_dims is deprecated, use keepdims instead\n", - "W0826 11:57:49.554083 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:295: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "\n", - "Future major versions of TensorFlow will allow gradients to flow\n", - "into the labels input on backprop by default.\n", - "\n", - "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n", - "\n", - "W0826 11:57:49.582807 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:351: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n", - "\n", - "I0826 11:57:50.112523 140006013421376 model.py:87] Model using gpu.\n", - "I0826 11:57:50.113037 140006013421376 model.py:175] Input dropout mode = 0.\n", - "I0826 11:57:50.113482 140006013421376 model.py:176] Output dropout mode = 0.\n", - "I0826 11:57:50.113979 140006013421376 model.py:177] Recurrent dropout mode = 0.\n", - "I0826 11:57:50.264065 140006013421376 model.py:87] Model using gpu.\n", - "I0826 11:57:50.264573 140006013421376 model.py:175] Input dropout mode = 0.\n", - "I0826 11:57:50.264990 140006013421376 model.py:176] Output dropout mode = 0.\n", - "I0826 11:57:50.265478 140006013421376 model.py:177] Recurrent dropout mode = 0.\n" + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", + "I0827 17:55:01.114828 140714736174912 model.py:87] Model using gpu.\n", + "I0827 17:55:01.115479 140714736174912 model.py:175] Input dropout mode = 0.\n", + "I0827 17:55:01.116087 140714736174912 model.py:176] Output dropout mode = 0.\n", + "I0827 17:55:01.116714 140714736174912 model.py:177] Recurrent dropout mode = 0.\n", + "I0827 17:55:01.471915 140714736174912 model.py:87] Model using gpu.\n", + "I0827 17:55:01.472840 140714736174912 model.py:175] Input dropout mode = 0.\n", + "I0827 17:55:01.473559 140714736174912 model.py:176] Output dropout mode = 0.\n", + "I0827 17:55:01.474178 140714736174912 model.py:177] Recurrent dropout mode = 0.\n" ] } ], @@ -459,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, "metadata": { "colab": {}, "colab_type": "code", @@ -473,7 +446,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -488,28 +461,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "W0826 11:57:52.587875 140006013421376 deprecation.py:323] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Use standard file APIs to check for files with this prefix.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "29.0\n" - ] - }, - { - "ename": "ValueError", - "evalue": "The passed save_path is not a valid checkpoint: models/naam4/vector-4100", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_checkpoint_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0msaver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"models/naam4/vector-4100\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py\u001b[0m in \u001b[0;36mrestore\u001b[0;34m(self, sess, save_path)\u001b[0m\n\u001b[1;32m 1276\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheckpoint_management\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1277\u001b[0m raise ValueError(\"The passed save_path is not a valid checkpoint: \" +\n\u001b[0;32m-> 1278\u001b[0;31m compat.as_text(save_path))\n\u001b[0m\u001b[1;32m 1279\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1280\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Restoring parameters from %s\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: The passed save_path is not a valid checkpoint: models/naam4/vector-4100" + "I0827 17:55:02.746215 140714736174912 sketch_rnn_train.py:241] Loading model models/naam4/vector-4800.\n", + "I0827 17:55:02.749120 140714736174912 saver.py:1280] Restoring parameters from models/naam4/vector-4800\n" ] } ], @@ -535,7 +488,7 @@ }, { "cell_type": "code", - "execution_count": 159, + "execution_count": 27, "metadata": { "colab": {}, "colab_type": "code", @@ -555,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": 160, + "execution_count": 28, "metadata": { "colab": {}, "colab_type": "code", @@ -576,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 201, + "execution_count": 69, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -587,10 +540,17 @@ "outputId": "c8e9a1c3-28db-4263-ac67-62ffece1e1e0" }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "260\n" + ] + }, { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "" @@ -602,9 +562,12 @@ ], "source": [ "# get a sample drawing from the test set, and render it to .svg\n", + "i = random.randint(101,360)\n", + "i=260\n", + "print(i)\n", "stroke = test_set.random_sample()\n", - "stroke=test_set.strokes[252]\n", - "draw_strokes(stroke)" + "stroke=test_set.strokes[i]\n", + "draw_strokes(stroke)\n" ] }, { @@ -617,6 +580,27 @@ "Let's try to encode the sample stroke into latent vector $z$" ] }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help on method random_sample in module magenta.models.sketch_rnn.utils:\n", + "\n", + "random_sample() method of magenta.models.sketch_rnn.utils.DataLoader instance\n", + " Return a random sample, in stroke-3 format as used by draw_strokes.\n", + "\n" + ] + } + ], + "source": [ + "help(test_set.random_sample)" + ] + }, { "cell_type": "code", "execution_count": 190, diff --git a/create_card.py b/create_card.py index 530e991..c63de9b 100644 --- a/create_card.py +++ b/create_card.py @@ -55,7 +55,7 @@ argParser.add_argument( argParser.add_argument( '--width', type=str, - default='100mm', + default='99mm', ) argParser.add_argument( '--height', @@ -330,8 +330,12 @@ for group in dwgGroups: dwg.add(group) checkpoints = getCheckpoints(args.model_dir) +items_per_page = args.rows*args.columns +item_count = items_per_page*grid_width*grid_height -item_count = args.rows*args.columns*grid_width*grid_height +logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}") +max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints)-1) +logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}") # factor = dataset_baseheight/ (height/args.rows) @@ -354,7 +358,8 @@ with tqdm(total=item_count) as pbar: for column in range(args.columns): min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin item = row*args.columns + column - checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints)) + checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/items_per_page * len(checkpoints)) + logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}") checkpoint = checkpoints[checkpoint_idx] loadCheckpoint(args.model_dir, checkpoint) @@ -396,7 +401,7 @@ if args.create_grid: sy2 = sy+2*grid_length p = f"M{sx},{sy} L{sx2},{sy2}" path = dwg.path(p).stroke('black',1).fill("none") - dwg.add(path) + grid_group.add(path) sx = i * width - grid_length sy = j * height sx2 = sx+ 2*grid_length