2021-05-20 22:20:48 +02:00
|
|
|
import os
|
|
|
|
import fire
|
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
|
|
|
|
from util.nni import run_nni
|
|
|
|
from util import init_exp_folder, Args
|
|
|
|
from util import constants as C
|
|
|
|
from lightning import (get_task,
|
|
|
|
load_task,
|
|
|
|
get_ckpt_callback,
|
|
|
|
get_early_stop_callback,
|
|
|
|
get_logger)
|
|
|
|
|
|
|
|
|
|
|
|
def train(save_dir=C.SANDBOX_PATH,
          tb_path=C.TB_PATH,
          exp_name="DemoExperiment",
          model="FasterRCNN",
          task='detection',
          gpus=1,
          pretrained=True,
          batch_size=8,
          accelerator="gpu",
          strategy="ddp",
          gradient_clip_val=0.5,
          max_epochs=100,
          learning_rate=1e-5,
          patience=30,
          limit_train_batches=1.0,
          limit_val_batches=1.0,
          limit_test_batches=1.0,
          weights_summary=None,
          ):
    """
    Run the training experiment.

    Every argument (plus the fixed ``num_classes`` / ``dataset_name`` below)
    is captured via ``Args(locals())`` and handed to ``init_exp_folder`` and
    ``get_task``, so arguments not used directly here (e.g. ``model``,
    ``task``, ``batch_size``, ``learning_rate``, ``pretrained``, ``tb_path``)
    still reach the task.

    Args:
        save_dir: Path to save the checkpoints and logs
        tb_path: Path to global tb folder
        exp_name: Name of the experiment
        model: Model name (see 'detection/models/detection/detectron.py' for
            the accepted values, e.g. "mask_rcnn", "faster_rcnn", "retinanet",
            "rpn", "fast_rcnn"). NOTE(review): the default "FasterRCNN" uses a
            different spelling than that list — confirm which form get_task expects.
        task: Task name, forwarded to ``get_task`` (default 'detection')
        gpus: int. (ie: 2 gpus)
              OR list to specify which GPUs [0, 1] OR '0,1'
              OR '-1' / -1 to use all available gpus
              Passed to the Trainer as ``devices``.
        pretrained: Whether or not to use the pretrained model
        batch_size: Batch size, forwarded to the task
        accelerator: Accelerator type passed to the Trainer
            ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
        strategy: Training strategy passed to the Trainer (e.g. "ddp")
        gradient_clip_val: Gradient clipping value passed to the Trainer
        max_epochs: Max number of epochs
        learning_rate: Learning rate, forwarded to the task
        patience: Number of epochs with no improvement after which
            training will be stopped (early stopping).
        limit_train_batches: Proportion of training data to use
        limit_val_batches: Proportion of validation data to use
        limit_test_batches: Proportion of test data to use
        weights_summary: Ignored. The Trainer argument of this name no longer
            exists in current PyTorch Lightning; the parameter is kept only so
            existing command lines keep working.

    Returns:
        tuple: ``(save_dir, exp_name)`` identifying where this run's logs and
        checkpoints were written.
    """
    # Fixed for this dataset. They must be defined before ``Args(locals())``
    # so that they are captured together with the function arguments; do not
    # introduce other locals above that call.
    num_classes = 2
    dataset_name = "camera-detection-new"

    args = Args(locals())
    init_exp_folder(args)
    task = get_task(args)

    trainer = Trainer(accelerator=accelerator,
                      # ``gpus=`` was removed from the Trainer in PL 2.0;
                      # ``devices`` accepts the same int / list / str forms.
                      devices=gpus,
                      strategy=strategy,
                      logger=get_logger(save_dir, exp_name),
                      callbacks=[get_early_stop_callback(patience),
                                 get_ckpt_callback(save_dir, exp_name,
                                                   monitor="mAP", mode="max")],
                      default_root_dir=os.path.join(save_dir, exp_name),
                      gradient_clip_val=gradient_clip_val,
                      limit_train_batches=limit_train_batches,
                      limit_val_batches=limit_val_batches,
                      limit_test_batches=limit_test_batches,
                      max_epochs=max_epochs)
    trainer.fit(task)
    return save_dir, exp_name
|
|
|
|
|
|
|
|
|
|
|
|
def test(ckpt_path,
         visualize=False,
         deploy=False,
         limit_test_batches=1.0,
         gpus=1,
         # NOTE(review): user-specific absolute path as a default; consider
         # moving it into util.constants.
         deploy_meta_path="/home/haosheng/dataset/camera/deployment/16cityp1.csv",
         test_batch_size=1,
         **kwargs):
    """
    Run the testing experiment.

    Args:
        ckpt_path: Path for the experiment (checkpoint) to load
        visualize: Forwarded to ``load_task``
        deploy: Forwarded to ``load_task``
        limit_test_batches: Proportion of test data to use
        gpus: int. (ie: 2 gpus)
              OR list to specify which GPUs [0, 1] OR '0,1'
              OR '-1' / -1 to use all available gpus
              0 / '0' / None runs on CPU.
        deploy_meta_path: Forwarded to ``load_task``
        test_batch_size: Forwarded to ``load_task``
        **kwargs: Any extra keyword arguments, forwarded unchanged to
            ``load_task``.

    Returns: None
    """
    task = load_task(ckpt_path,
                     visualize=visualize,
                     deploy=deploy,
                     deploy_meta_path=deploy_meta_path,
                     test_batch_size=test_batch_size,
                     **kwargs)

    # ``Trainer(gpus=...)`` was removed in PL 2.0. Map it onto
    # accelerator/devices, keeping the old meaning of gpus=0/None (CPU).
    if gpus in (None, 0, "0"):
        accelerator, devices = "cpu", 1
    else:
        accelerator, devices = "gpu", gpus

    trainer = Trainer(accelerator=accelerator,
                      devices=devices,
                      limit_test_batches=limit_test_batches)
    trainer.test(task)
|
|
|
|
|
|
|
|
|
|
|
|
def nni():
    """Delegate to ``util.nni.run_nni`` with this module's train/test functions.

    Exists so that ``nni`` is available as a sub-command of the Fire CLI.
    Returns None.
    """
    entry_points = (train, test)
    run_nni(*entry_points)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose only the intended entry points. A bare ``fire.Fire()`` would also
    # turn every module-level name (``os``, ``Trainer``, ``Args``, ...) into a
    # CLI command.
    fire.Fire({"train": train, "test": test, "nni": nni})
|