2021-05-20 22:20:48 +02:00
|
|
|
"""Factory helpers for PyTorch Lightning logging: a CSV logger plus checkpoint and early-stopping callbacks."""
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
from os.path import join
|
2024-02-29 14:38:18 +01:00
|
|
|
# from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
|
|
|
from pytorch_lightning.loggers import CSVLogger
|
2021-05-20 22:20:48 +02:00
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
|
|
|
|
|
|
|
|
|
|
|
def get_ckpt_dir(save_path, exp_name):
    """Return the checkpoint directory for an experiment.

    The result is ``<save_path>/<exp_name>/ckpts``. This only builds the
    path string; nothing is created on disk.
    """
    experiment_dir = os.path.join(save_path, exp_name)
    return os.path.join(experiment_dir, "ckpts")
|
|
|
|
|
|
|
|
|
|
|
|
def get_ckpt_callback(save_path, exp_name, monitor="val_loss", mode="min",
                      save_top_k=1):
    """Build a ModelCheckpoint callback that keeps the best checkpoint(s).

    Args:
        save_path: Root directory holding all experiments.
        exp_name: Experiment name; checkpoints are written to the directory
            returned by ``get_ckpt_dir(save_path, exp_name)``.
        monitor: Name of the logged metric used to rank checkpoints.
        mode: Passed to ``ModelCheckpoint``; ``"min"`` keeps the lowest
            ``monitor`` values, ``"max"`` the highest.
        save_top_k: How many of the best checkpoints to keep. Defaults to 1,
            matching the previous hard-coded behaviour.

    Returns:
        A configured ``ModelCheckpoint`` instance (``verbose=True``).
    """
    # Reuse get_ckpt_dir so the checkpoint layout is defined in one place.
    return ModelCheckpoint(dirpath=get_ckpt_dir(save_path, exp_name),
                           save_top_k=save_top_k,
                           verbose=True,
                           monitor=monitor,
                           mode=mode)
|
2021-05-20 22:20:48 +02:00
|
|
|
|
|
|
|
|
|
|
|
def get_early_stop_callback(patience=10, monitor='val_loss', mode='min'):
    """Build an EarlyStopping callback.

    Args:
        patience: Passed to ``EarlyStopping``; number of checks with no
            improvement in ``monitor`` before training stops.
        monitor: Name of the logged metric to watch. Defaults to
            ``'val_loss'``, the previously hard-coded value.
        mode: ``'min'`` if lower ``monitor`` values are better, ``'max'``
            otherwise. Defaults to ``'min'``, the previously hard-coded value.

    Returns:
        A configured ``EarlyStopping`` instance (``verbose=True``).
    """
    # monitor/mode are parameters (with the old values as defaults) so this
    # helper can be configured the same way as get_ckpt_callback.
    return EarlyStopping(monitor=monitor,
                         patience=patience,
                         verbose=True,
                         mode=mode)
|
|
|
|
|
|
|
|
|
|
|
|
def get_logger(save_path, exp_name):
    """Create a CSVLogger rooted in the experiment directory.

    The logger's ``save_dir`` is ``<save_path>/<exp_name>``, with run name
    ``'lightning_logs'`` and the fixed version string ``"0"``. Because the
    version is a string, Lightning uses it directly as the run subdirectory
    name. Repeated runs of the same experiment therefore share that
    directory rather than getting new ``version_N`` folders.
    """
    logger_kwargs = dict(
        save_dir=os.path.join(save_path, exp_name),
        name='lightning_logs',
        version="0",
    )
    return CSVLogger(**logger_kwargs)
|