Enable running train.py from project root, fix nested output dirs
This commit is contained in:
parent
ebf52bfa07
commit
6bfb80c88e
3 changed files with 21 additions and 7 deletions
|
@ -1,3 +1,4 @@
|
|||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
|
@ -8,6 +9,13 @@ from . import constants as C
|
|||
|
||||
|
||||
def get_dataset(split="train"):
|
||||
meta = pd.read_csv("../data/meta.csv")
|
||||
info = DatasetInfo.load("../data/info.yaml")
|
||||
data_dir = Path(__file__).parent.parent.parent / "data"
|
||||
# here = os.path.dirname(__path__)
|
||||
# data_dir = here + '/../../data/'
|
||||
|
||||
meta = pd.read_csv(data_dir / "non-empty-meta.csv")
|
||||
info = DatasetInfo.load(data_dir / "info.yaml")
|
||||
|
||||
# meta['image_path'] = [f'data/image/{panoid}_{heading}.jpg' for panoid, heading in zip(meta['panoid'], meta['heading'])]
|
||||
|
||||
return BaseDataset(info, meta)[split]
|
||||
|
|
|
@ -108,7 +108,7 @@ class DetectionEvaluator(DatasetEvaluator):
|
|||
format=detection.BBFormat.XYX2Y2)
|
||||
self._bbox.addBoundingBox(bb)
|
||||
|
||||
def evaluate(self):
|
||||
def evaluate(self, save_dir):
|
||||
results = self._evaluator.GetPascalVOCMetrics(self._bbox, self._iou_thresh)
|
||||
if isinstance(results, dict):
|
||||
results = [results]
|
||||
|
@ -118,7 +118,7 @@ class DetectionEvaluator(DatasetEvaluator):
|
|||
metrics[f'AP_{result["class"]}'] = result['AP']
|
||||
APs.append(result['AP'])
|
||||
metrics['mAP'] = np.nanmean(APs)
|
||||
self._evaluator.PlotPrecisionRecallCurve(self._bbox, savePath="./plots/", showGraphic=False)
|
||||
self._evaluator.PlotPrecisionRecallCurve(self._bbox, savePath=save_dir, showGraphic=False)
|
||||
return metrics
|
||||
|
||||
class DatasetEvaluators(DatasetEvaluator):
|
||||
|
|
|
@ -30,6 +30,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
|
|||
self.save_hyperparameters(params)
|
||||
self.model = get_model(params)
|
||||
self.evaluator = DetectionEvaluator()
|
||||
self._plot_save_dir = params['save_dir'] + f"/{params['exp_name']}/plots"
|
||||
self._output_save_dir = params['save_dir'] + f"/{params['exp_name']}/outputs"
|
||||
if not os.path.exists(self._plot_save_dir):
|
||||
os.mkdir(self._plot_save_dir)
|
||||
if not os.path.exists(self._output_save_dir):
|
||||
os.mkdir(self._output_save_dir)
|
||||
|
||||
def training_step(self, batch, batch_nb):
|
||||
losses = self.model.forward(batch)
|
||||
|
@ -46,7 +52,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
|
|||
def validation_epoch_end(self, outputs):
|
||||
avg_loss = torch.stack(outputs).mean()
|
||||
self.log("val_loss", avg_loss)
|
||||
metrics = self.evaluator.evaluate()
|
||||
metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
|
||||
nni.report_intermediate_result(metrics['mAP'])
|
||||
self.evaluator.reset()
|
||||
self.log_dict(metrics, prog_bar=True)
|
||||
|
@ -124,12 +130,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
|
|||
with open(json_save_path, 'w') as fp:
|
||||
json.dump(data, fp)
|
||||
else:
|
||||
img_box.save(f"outputs/{batch_nb}_{i}.jpg")
|
||||
img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
|
||||
|
||||
self.evaluator.process(batch, preds)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
metrics = self.evaluator.evaluate()
|
||||
metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
|
||||
nni.report_final_result(metrics['mAP'])
|
||||
self.log_dict(metrics)
|
||||
|
||||
|
|
Loading…
Reference in a new issue