surveilling-surveillance/detection/lightning/util.py
2021-05-20 13:22:04 -07:00

34 lines
1.1 KiB
Python

"""Helpers that build PyTorch Lightning checkpoint/early-stopping callbacks and a TestTube logger."""
import json
import os
from os.path import join
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
def get_ckpt_dir(save_path, exp_name):
    """Return the checkpoint directory path for an experiment.

    Args:
        save_path: Root directory for experiment outputs.
        exp_name: Experiment name (a subdirectory of ``save_path``).

    Returns:
        The path ``<save_path>/<exp_name>/ckpts`` as a string. Only the path
        is computed; nothing is created on disk.
    """
    ckpt_subdir = "ckpts"
    return os.path.join(save_path, exp_name, ckpt_subdir)
def get_ckpt_callback(save_path, exp_name, monitor="val_loss", mode="min"):
    """Build a ModelCheckpoint callback for an experiment.

    Args:
        save_path: Root directory for experiment outputs.
        exp_name: Experiment name; checkpoints are written under
            ``<save_path>/<exp_name>/ckpts``.
        monitor: Name of the logged metric used to rank checkpoints.
        mode: Direction of improvement for ``monitor`` (e.g. ``"min"``).

    Returns:
        A ``ModelCheckpoint`` configured with ``save_top_k=1`` and verbose
        output.
    """
    # Reuse get_ckpt_dir so the checkpoint directory layout is defined in
    # exactly one place.
    ckpt_dir = get_ckpt_dir(save_path, exp_name)
    # NOTE(review): ``filepath`` and ``prefix`` belong to the older
    # pytorch_lightning ModelCheckpoint API; newer releases use
    # ``dirpath``/``filename`` instead. Confirm against the pinned version.
    return ModelCheckpoint(filepath=ckpt_dir,
                           save_top_k=1,
                           verbose=True,
                           monitor=monitor,
                           mode=mode,
                           prefix='')
def get_early_stop_callback(patience=10, monitor="val_loss", mode="min"):
    """Build an EarlyStopping callback.

    ``monitor`` and ``mode`` default to the values this function previously
    hard-coded, so existing callers are unaffected. They can now be set to
    match ``get_ckpt_callback`` when a different metric is tracked.

    Args:
        patience: Value passed to ``EarlyStopping(patience=...)``.
        monitor: Name of the logged metric to watch.
        mode: Direction of improvement for ``monitor`` (e.g. ``"min"``).

    Returns:
        An ``EarlyStopping`` instance with verbose output enabled.
    """
    return EarlyStopping(monitor=monitor,
                         patience=patience,
                         verbose=True,
                         mode=mode)
def get_logger(save_path, exp_name):
    """Build a TestTubeLogger rooted at the experiment directory.

    Args:
        save_path: Root directory for experiment outputs.
        exp_name: Experiment name; the logger's ``save_dir`` is
            ``<save_path>/<exp_name>``.

    Returns:
        A ``TestTubeLogger`` with ``name='lightning_logs'`` and a fixed
        ``version="0"``. Because the version is fixed, repeated runs of the
        same experiment target the same log version. NOTE(review): confirm
        that this reuse is intended.
    """
    return TestTubeLogger(
        save_dir=os.path.join(save_path, exp_name),
        name='lightning_logs',
        version="0",
    )