surveilling-surveillance/detection/lightning/util.py

37 lines
1.2 KiB
Python
Raw Normal View History

2021-05-20 22:20:48 +02:00
"""Define Logger class for logging information to stdout and disk."""
import json
import os
from os.path import join
# from pytorch_lightning.loggers.test_tube import TestTubeLogger
from pytorch_lightning.loggers import CSVLogger
2021-05-20 22:20:48 +02:00
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
def get_ckpt_dir(save_path, exp_name):
    """Return the checkpoint directory path for an experiment.

    The result is ``save_path``, ``exp_name`` and the literal ``"ckpts"``
    joined with the platform path separator. The path is only computed;
    nothing is created on disk here.

    Args:
        save_path: Root directory under which experiments are stored.
        exp_name: Name of the experiment (used as a sub-directory).

    Returns:
        The joined path as a string.
    """
    path_parts = (save_path, exp_name, "ckpts")
    return os.path.join(*path_parts)
def get_ckpt_callback(save_path, exp_name, monitor="val_loss", mode="min"):
    """Build the ``ModelCheckpoint`` callback for an experiment.

    Checkpoints are written to the directory returned by
    ``get_ckpt_dir(save_path, exp_name)`` so that the callback and any code
    that later looks up checkpoints agree on a single location.

    Args:
        save_path: Root directory under which experiments are stored.
        exp_name: Name of the experiment (used as a sub-directory).
        monitor: Name of the logged metric used to rank checkpoints.
        mode: ``"min"`` if a lower ``monitor`` value is better, ``"max"``
            if a higher one is.

    Returns:
        A ``ModelCheckpoint`` configured with ``save_top_k=1`` and
        ``verbose=True``.
    """
    return ModelCheckpoint(
        # Reuse the shared helper instead of re-deriving the path here.
        dirpath=get_ckpt_dir(save_path, exp_name),
        save_top_k=1,
        verbose=True,
        monitor=monitor,
        mode=mode,
    )
2021-05-20 22:20:48 +02:00
def get_early_stop_callback(patience=10, monitor="val_loss", mode="min"):
    """Build the ``EarlyStopping`` callback.

    ``monitor`` and ``mode`` default to the values that were previously
    hard-coded, so existing calls behave exactly as before. They are exposed
    so early stopping can track the same metric that is passed to
    ``get_ckpt_callback``.

    Args:
        patience: Passed through to ``EarlyStopping`` as ``patience``.
        monitor: Name of the logged metric to watch.
        mode: ``"min"`` if a lower ``monitor`` value is better, ``"max"``
            if a higher one is.

    Returns:
        An ``EarlyStopping`` instance with ``verbose=True``.
    """
    return EarlyStopping(
        monitor=monitor,
        patience=patience,
        verbose=True,
        mode=mode,
    )
def get_logger(save_path, exp_name):
    """Build the ``CSVLogger`` for an experiment.

    The logger's ``save_dir`` is ``save_path`` joined with ``exp_name``.
    Its ``name`` is fixed to ``"lightning_logs"`` and its ``version`` is
    pinned to the string ``"0"``, so every call for the same experiment
    is configured with the same three values.

    Args:
        save_path: Root directory under which experiments are stored.
        exp_name: Name of the experiment (used as a sub-directory).

    Returns:
        A configured ``CSVLogger`` instance.
    """
    logger_kwargs = {
        "save_dir": os.path.join(save_path, exp_name),
        "name": "lightning_logs",
        "version": "0",
    }
    return CSVLogger(**logger_kwargs)