Small fixes and changed defaults
This commit is contained in:
parent
40e6753ded
commit
b95f7b941b
3 changed files with 94 additions and 99 deletions
|
@ -5,7 +5,8 @@ python create_dataset.py --dataset_dir datasets/naam6/
|
||||||
|
|
||||||
Train algorithm: (save often, as we'll use the intermediate steps)
|
Train algorithm: (save often, as we'll use the intermediate steps)
|
||||||
```
|
```
|
||||||
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
#sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
||||||
|
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000
|
||||||
```
|
```
|
||||||
|
|
||||||
Generate a card:
|
Generate a card:
|
||||||
|
@ -34,3 +35,8 @@ Successfull on naam4:
|
||||||
sketch_rnn_train --log_root=models/naam4 --data_dir=datasets/naam4 --hparams="data_set=[diede.npz,lijn.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
sketch_rnn_train --log_root=models/naam4 --data_dir=datasets/naam4 --hparams="data_set=[diede.npz,lijn.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
||||||
```
|
```
|
||||||
-->
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
naam4:
|
||||||
|
strokes & blokletters: 101-360
|
||||||
|
strokes only: 101-259
|
||||||
|
|
172
Sketch_RNN.ipynb
172
Sketch_RNN.ipynb
File diff suppressed because one or more lines are too long
|
@ -55,7 +55,7 @@ argParser.add_argument(
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--width',
|
'--width',
|
||||||
type=str,
|
type=str,
|
||||||
default='100mm',
|
default='99mm',
|
||||||
)
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--height',
|
'--height',
|
||||||
|
@ -330,8 +330,12 @@ for group in dwgGroups:
|
||||||
dwg.add(group)
|
dwg.add(group)
|
||||||
|
|
||||||
checkpoints = getCheckpoints(args.model_dir)
|
checkpoints = getCheckpoints(args.model_dir)
|
||||||
|
items_per_page = args.rows*args.columns
|
||||||
|
item_count = items_per_page*grid_width*grid_height
|
||||||
|
|
||||||
item_count = args.rows*args.columns*grid_width*grid_height
|
logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}")
|
||||||
|
max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints)-1)
|
||||||
|
logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}")
|
||||||
|
|
||||||
# factor = dataset_baseheight/ (height/args.rows)
|
# factor = dataset_baseheight/ (height/args.rows)
|
||||||
|
|
||||||
|
@ -354,7 +358,8 @@ with tqdm(total=item_count) as pbar:
|
||||||
for column in range(args.columns):
|
for column in range(args.columns):
|
||||||
min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
|
min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
|
||||||
item = row*args.columns + column
|
item = row*args.columns + column
|
||||||
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/items_per_page * len(checkpoints))
|
||||||
|
logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}")
|
||||||
checkpoint = checkpoints[checkpoint_idx]
|
checkpoint = checkpoints[checkpoint_idx]
|
||||||
loadCheckpoint(args.model_dir, checkpoint)
|
loadCheckpoint(args.model_dir, checkpoint)
|
||||||
|
|
||||||
|
@ -396,7 +401,7 @@ if args.create_grid:
|
||||||
sy2 = sy+2*grid_length
|
sy2 = sy+2*grid_length
|
||||||
p = f"M{sx},{sy} L{sx2},{sy2}"
|
p = f"M{sx},{sy} L{sx2},{sy2}"
|
||||||
path = dwg.path(p).stroke('black',1).fill("none")
|
path = dwg.path(p).stroke('black',1).fill("none")
|
||||||
dwg.add(path)
|
grid_group.add(path)
|
||||||
sx = i * width - grid_length
|
sx = i * width - grid_length
|
||||||
sy = j * height
|
sy = j * height
|
||||||
sx2 = sx+ 2*grid_length
|
sx2 = sx+ 2*grid_length
|
||||||
|
|
Loading…
Reference in a new issue