Enable running train.py from project root, fix nested output dirs
parent ebf52bfa07
commit 6bfb80c88e

3 changed files with 21 additions and 7 deletions
@@ -1,3 +1,4 @@
+from pathlib import Path
 import pandas as pd
 import os

@@ -8,6 +9,13 @@ from . import constants as C


 def get_dataset(split="train"):
-    meta = pd.read_csv("../data/meta.csv")
-    info = DatasetInfo.load("../data/info.yaml")
+    data_dir = Path(__file__).parent.parent.parent / "data"
+    # here = os.path.dirname(__path__)
+    # data_dir = here + '/../../data/'
+
+    meta = pd.read_csv(data_dir / "non-empty-meta.csv")
+    info = DatasetInfo.load(data_dir / "info.yaml")
+
+    # meta['image_path'] = [f'data/image/{panoid}_{heading}.jpg' for panoid, heading in zip(meta['panoid'], meta['heading'])]
+
     return BaseDataset(info, meta)[split]
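The data-path change above is what lets train.py run from the project root: Path(__file__).parent.parent.parent / "data" resolves against the module's own location rather than the current working directory. A minimal standalone sketch of the same pattern (the src/<package>/<module>.py layout and the load_meta helper are assumptions for illustration, not code from this repo):

    import pandas as pd
    from pathlib import Path

    # With a layout like <root>/src/<package>/<module>.py, parents[2] is <root>.
    # Assumed layout for illustration; adjust the index to match the real tree.
    DATA_DIR = Path(__file__).resolve().parents[2] / "data"

    def load_meta(name="non-empty-meta.csv"):
        # Resolves against the module file, not os.getcwd(), so the call works
        # no matter which directory the script is launched from.
        return pd.read_csv(DATA_DIR / name)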
@@ -108,7 +108,7 @@ class DetectionEvaluator(DatasetEvaluator):
                                  format=detection.BBFormat.XYX2Y2)
         self._bbox.addBoundingBox(bb)

-    def evaluate(self):
+    def evaluate(self, save_dir):
         results = self._evaluator.GetPascalVOCMetrics(self._bbox, self._iou_thresh)
         if isinstance(results, dict):
             results = [results]
@@ -118,7 +118,7 @@ class DetectionEvaluator(DatasetEvaluator):
             metrics[f'AP_{result["class"]}'] = result['AP']
             APs.append(result['AP'])
         metrics['mAP'] = np.nanmean(APs)
-        self._evaluator.PlotPrecisionRecallCurve(self._bbox, savePath="./plots/", showGraphic=False)
+        self._evaluator.PlotPrecisionRecallCurve(self._bbox, savePath=save_dir, showGraphic=False)
         return metrics

 class DatasetEvaluators(DatasetEvaluator):
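With evaluate(self, save_dir), the precision-recall plots land wherever the caller points them instead of a hard-coded "./plots/" that only existed relative to the old working directory. A hedged usage sketch (the directory value is illustrative, and process() is assumed to have been fed batches already, as in the task code below):

    evaluator = DetectionEvaluator()
    # ... evaluator.process(batch, preds) called once per validation batch ...
    plot_dir = "/tmp/exp1/plots"   # illustrative path; must already exist
    metrics = evaluator.evaluate(save_dir=plot_dir)
    print(metrics["mAP"])          # plus one AP_<class> entry per class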
@@ -30,6 +30,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
         self.save_hyperparameters(params)
         self.model = get_model(params)
         self.evaluator = DetectionEvaluator()
+        self._plot_save_dir = params['save_dir'] + f"/{params['exp_name']}/plots"
+        self._output_save_dir = params['save_dir'] + f"/{params['exp_name']}/outputs"
+        if not os.path.exists(self._plot_save_dir):
+            os.mkdir(self._plot_save_dir)
+        if not os.path.exists(self._output_save_dir):
+            os.mkdir(self._output_save_dir)

     def training_step(self, batch, batch_nb):
         losses = self.model.forward(batch)
@@ -46,7 +52,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
     def validation_epoch_end(self, outputs):
         avg_loss = torch.stack(outputs).mean()
         self.log("val_loss", avg_loss)
-        metrics = self.evaluator.evaluate()
+        metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
         nni.report_intermediate_result(metrics['mAP'])
         self.evaluator.reset()
         self.log_dict(metrics, prog_bar=True)
@@ -124,12 +130,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
                 with open(json_save_path, 'w') as fp:
                     json.dump(data, fp)
             else:
-                img_box.save(f"outputs/{batch_nb}_{i}.jpg")
+                img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg")

         self.evaluator.process(batch, preds)

     def test_epoch_end(self, outputs):
-        metrics = self.evaluator.evaluate()
+        metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
         nni.report_final_result(metrics['mAP'])
         self.log_dict(metrics)

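One caveat on the constructor change above: os.mkdir creates only the last path component, so it raises FileNotFoundError if params['save_dir']/<exp_name> does not exist yet when the task is built. If that can happen, a sketch of a more defensive variant of the same block (standard library only, same attribute names as in the diff) would be:

    import os

    # In DetectionTask.__init__, replacing the exists()/mkdir pairs.
    self._plot_save_dir = os.path.join(params['save_dir'], params['exp_name'], "plots")
    self._output_save_dir = os.path.join(params['save_dir'], params['exp_name'], "outputs")
    # os.makedirs creates intermediate directories too, and exist_ok=True avoids
    # the race between an exists() check and the mkdir call.
    os.makedirs(self._plot_save_dir, exist_ok=True)
    os.makedirs(self._output_save_dir, exist_ok=True)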