Small fixes and changed defaults

This commit is contained in:
Ruben van de Ven 2019-08-27 20:41:06 +02:00
parent 40e6753ded
commit b95f7b941b
3 changed files with 94 additions and 99 deletions

View file

@ -5,7 +5,8 @@ python create_dataset.py --dataset_dir datasets/naam6/
Train algorithm: (save often, as we'll use the intermediate steps)
```
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
#sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000
```
Generate a card:
@ -34,3 +35,8 @@ Successfull on naam4:
sketch_rnn_train --log_root=models/naam4 --data_dir=datasets/naam4 --hparams="data_set=[diede.npz,lijn.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
```
-->
naam4:
strokes & blokletters: 101-360
strokes only: 101-259

File diff suppressed because one or more lines are too long

View file

@ -55,7 +55,7 @@ argParser.add_argument(
argParser.add_argument(
'--width',
type=str,
default='100mm',
default='99mm',
)
argParser.add_argument(
'--height',
@ -330,8 +330,12 @@ for group in dwgGroups:
dwg.add(group)
checkpoints = getCheckpoints(args.model_dir)
items_per_page = args.rows*args.columns
item_count = items_per_page*grid_width*grid_height
item_count = args.rows*args.columns*grid_width*grid_height
logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}")
max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints)-1)
logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}")
# factor = dataset_baseheight/ (height/args.rows)
@ -354,7 +358,8 @@ with tqdm(total=item_count) as pbar:
for column in range(args.columns):
min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
item = row*args.columns + column
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/items_per_page * len(checkpoints))
logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}")
checkpoint = checkpoints[checkpoint_idx]
loadCheckpoint(args.model_dir, checkpoint)
@ -396,7 +401,7 @@ if args.create_grid:
sy2 = sy+2*grid_length
p = f"M{sx},{sy} L{sx2},{sy2}"
path = dwg.path(p).stroke('black',1).fill("none")
dwg.add(path)
grid_group.add(path)
sx = i * width - grid_length
sy = j * height
sx2 = sx+ 2*grid_length