Small fixes and changed defaults
This commit is contained in:
parent
40e6753ded
commit
b95f7b941b
3 changed files with 94 additions and 99 deletions
|
@ -5,7 +5,8 @@ python create_dataset.py --dataset_dir datasets/naam6/
|
|||
|
||||
Train algorithm: (save often, as we'll use the intermediate steps)
|
||||
```
|
||||
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
||||
#sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
||||
sketch_rnn_train --log_root=models/naam6 --data_dir=datasets/naam6 --hparams="data_set=[diede.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=50,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000
|
||||
```
|
||||
|
||||
Generate a card:
|
||||
|
@ -34,3 +35,8 @@ Successfull on naam4:
|
|||
sketch_rnn_train --log_root=models/naam4 --data_dir=datasets/naam4 --hparams="data_set=[diede.npz,lijn.npz,blokletters.npz],dec_model=layer_norm,dec_rnn_size=450,enc_model=layer_norm,enc_rnn_size=300,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=True,num_steps=5000"
|
||||
```
|
||||
-->
|
||||
|
||||
|
||||
naam4:
|
||||
strokes & blokletters: 101-360
|
||||
strokes only: 101-259
|
||||
|
|
172
Sketch_RNN.ipynb
172
Sketch_RNN.ipynb
File diff suppressed because one or more lines are too long
|
@ -55,7 +55,7 @@ argParser.add_argument(
|
|||
argParser.add_argument(
|
||||
'--width',
|
||||
type=str,
|
||||
default='100mm',
|
||||
default='99mm',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--height',
|
||||
|
@ -330,8 +330,12 @@ for group in dwgGroups:
|
|||
dwg.add(group)
|
||||
|
||||
checkpoints = getCheckpoints(args.model_dir)
|
||||
items_per_page = args.rows*args.columns
|
||||
item_count = items_per_page*grid_width*grid_height
|
||||
|
||||
item_count = args.rows*args.columns*grid_width*grid_height
|
||||
logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}")
|
||||
max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints)-1)
|
||||
logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}")
|
||||
|
||||
# factor = dataset_baseheight/ (height/args.rows)
|
||||
|
||||
|
@ -354,7 +358,8 @@ with tqdm(total=item_count) as pbar:
|
|||
for column in range(args.columns):
|
||||
min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
|
||||
item = row*args.columns + column
|
||||
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
||||
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/items_per_page * len(checkpoints))
|
||||
logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}")
|
||||
checkpoint = checkpoints[checkpoint_idx]
|
||||
loadCheckpoint(args.model_dir, checkpoint)
|
||||
|
||||
|
@ -396,7 +401,7 @@ if args.create_grid:
|
|||
sy2 = sy+2*grid_length
|
||||
p = f"M{sx},{sy} L{sx2},{sy2}"
|
||||
path = dwg.path(p).stroke('black',1).fill("none")
|
||||
dwg.add(path)
|
||||
grid_group.add(path)
|
||||
sx = i * width - grid_length
|
||||
sy = j * height
|
||||
sx2 = sx+ 2*grid_length
|
||||
|
|
Loading…
Reference in a new issue