Enable running train.py from project root, fix nested output dirs

This commit is contained in:
Ruben van de Ven 2024-02-29 14:39:11 +01:00
parent ebf52bfa07
commit 6bfb80c88e
3 changed files with 21 additions and 7 deletions

View file

@ -1,3 +1,4 @@
from pathlib import Path
import pandas as pd import pandas as pd
import os import os
@ -8,6 +9,13 @@ from . import constants as C
def get_dataset(split="train"): def get_dataset(split="train"):
meta = pd.read_csv("../data/meta.csv") data_dir = Path(__file__).parent.parent.parent / "data"
info = DatasetInfo.load("../data/info.yaml") # here = os.path.dirname(__path__)
# data_dir = here + '/../../data/'
meta = pd.read_csv(data_dir / "non-empty-meta.csv")
info = DatasetInfo.load(data_dir / "info.yaml")
# meta['image_path'] = [f'data/image/{panoid}_{heading}.jpg' for panoid, heading in zip(meta['panoid'], meta['heading'])]
return BaseDataset(info, meta)[split] return BaseDataset(info, meta)[split]

View file

@ -108,7 +108,7 @@ class DetectionEvaluator(DatasetEvaluator):
format=detection.BBFormat.XYX2Y2) format=detection.BBFormat.XYX2Y2)
self._bbox.addBoundingBox(bb) self._bbox.addBoundingBox(bb)
def evaluate(self): def evaluate(self, save_dir):
results = self._evaluator.GetPascalVOCMetrics(self._bbox, self._iou_thresh) results = self._evaluator.GetPascalVOCMetrics(self._bbox, self._iou_thresh)
if isinstance(results, dict): if isinstance(results, dict):
results = [results] results = [results]
@ -118,7 +118,7 @@ class DetectionEvaluator(DatasetEvaluator):
metrics[f'AP_{result["class"]}'] = result['AP'] metrics[f'AP_{result["class"]}'] = result['AP']
APs.append(result['AP']) APs.append(result['AP'])
metrics['mAP'] = np.nanmean(APs) metrics['mAP'] = np.nanmean(APs)
self._evaluator.PlotPrecisionRecallCurve(self._bbox, savePath="./plots/", showGraphic=False) self._evaluator.PlotPrecisionRecallCurve(self._bbox, savePath=save_dir, showGraphic=False)
return metrics return metrics
class DatasetEvaluators(DatasetEvaluator): class DatasetEvaluators(DatasetEvaluator):

View file

@ -30,6 +30,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
self.save_hyperparameters(params) self.save_hyperparameters(params)
self.model = get_model(params) self.model = get_model(params)
self.evaluator = DetectionEvaluator() self.evaluator = DetectionEvaluator()
self._plot_save_dir = params['save_dir'] + f"/{params['exp_name']}/plots"
self._output_save_dir = params['save_dir'] + f"/{params['exp_name']}/outputs"
if not os.path.exists(self._plot_save_dir):
os.mkdir(self._plot_save_dir)
if not os.path.exists(self._output_save_dir):
os.mkdir(self._output_save_dir)
def training_step(self, batch, batch_nb): def training_step(self, batch, batch_nb):
losses = self.model.forward(batch) losses = self.model.forward(batch)
@ -46,7 +52,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs):
avg_loss = torch.stack(outputs).mean() avg_loss = torch.stack(outputs).mean()
self.log("val_loss", avg_loss) self.log("val_loss", avg_loss)
metrics = self.evaluator.evaluate() metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
nni.report_intermediate_result(metrics['mAP']) nni.report_intermediate_result(metrics['mAP'])
self.evaluator.reset() self.evaluator.reset()
self.log_dict(metrics, prog_bar=True) self.log_dict(metrics, prog_bar=True)
@ -124,12 +130,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
with open(json_save_path, 'w') as fp: with open(json_save_path, 'w') as fp:
json.dump(data, fp) json.dump(data, fp)
else: else:
img_box.save(f"outputs/{batch_nb}_{i}.jpg") img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
self.evaluator.process(batch, preds) self.evaluator.process(batch, preds)
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
metrics = self.evaluator.evaluate() metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
nni.report_final_result(metrics['mAP']) nni.report_final_result(metrics['mAP'])
self.log_dict(metrics) self.log_dict(metrics)