188 lines
7.9 KiB
Python
188 lines
7.9 KiB
Python
import nni
|
|
import pickle as pkl
|
|
import json
|
|
import pytorch_lightning as pl
|
|
import os
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
from PIL import Image
|
|
import pandas as pd
|
|
from detectron2.data import transforms as T
|
|
from detectron2.structures import Instances, Boxes
|
|
from detectron2.utils.visualizer import Visualizer
|
|
from torch.utils.data import DataLoader
|
|
from ignite.metrics import Accuracy
|
|
|
|
from models import get_model
|
|
from eval import DetectionEvaluator
|
|
from data import get_dataset
|
|
from util import constants as C
|
|
from util import get_concat_h_cut
|
|
from .logger import TFLogger
|
|
|
|
|
|
class DetectionTask(pl.LightningModule, TFLogger):
    """Standard interface for the trainer to interact with the model.

    Wraps the model built by ``get_model(params)`` in a PyTorch Lightning
    module. It:

    * computes training and validation losses;
    * accumulates predictions in a ``DetectionEvaluator`` and reports its
      ``mAP`` to NNI;
    * optionally exports per-detection images at test time;
    * builds the train/val/test dataloaders.

    Hyperparameters read via ``self.hparams`` (defaults in parentheses):
    ``save_dir``, ``exp_name``, ``learning_rate``, ``batch_size``,
    ``num_workers`` (8), ``visualize`` (False), ``deploy`` (False),
    ``deploy_meta_path``, ``conf_threshold`` (0), ``iou_threshold`` (0.5),
    ``padding`` (10), ``image_size`` (640), ``test_batch_size`` (1),
    ``test_split`` ("valid").
    """

    def __init__(self, params):
        """Build the model and evaluator, and create the output directories.

        Args:
            params: Hyperparameter dict. It must contain ``save_dir`` and
                ``exp_name``. It is also passed to ``get_model`` and stored
                via ``save_hyperparameters``.
        """
        super().__init__()
        self.save_hyperparameters(params)
        self.model = get_model(params)
        self.evaluator = DetectionEvaluator()

        exp_dir = os.path.join(params['save_dir'], params['exp_name'])
        self._plot_save_dir = os.path.join(exp_dir, "plots")
        self._output_save_dir = os.path.join(exp_dir, "outputs")
        # makedirs (unlike mkdir) also creates missing parents such as
        # save_dir/exp_name. exist_ok removes the check-then-create race.
        os.makedirs(self._plot_save_dir, exist_ok=True)
        os.makedirs(self._output_save_dir, exist_ok=True)

    @staticmethod
    def _mean_loss(losses):
        """Average the values of a ``{name: loss tensor}`` dict into one tensor."""
        return torch.stack(list(losses.values())).mean()

    def training_step(self, batch, batch_nb):
        """Return the mean of the loss terms produced by the model's forward pass."""
        return self._mean_loss(self.model.forward(batch))

    def validation_step(self, batch, batch_nb):
        """Compute the mean loss for ``batch`` and feed its predictions to the evaluator."""
        loss = self._mean_loss(self.model.forward(batch))
        preds = self.model.infer(batch)
        self.evaluator.process(batch, preds)
        return loss

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and evaluator metrics, then reset the evaluator.

        Args:
            outputs: Per-step losses returned by ``validation_step``.

        The evaluator's ``mAP`` is also reported to NNI as an intermediate
        result.
        """
        avg_loss = torch.stack(outputs).mean()
        self.log("val_loss", avg_loss)
        metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
        nni.report_intermediate_result(metrics['mAP'])
        # Start the next validation epoch from an empty accumulator.
        self.evaluator.reset()
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch, batch_nb):
        """Run inference on a test batch, optionally export detections, and feed the evaluator.

        When ``visualize`` or ``deploy`` is truthy, each sample's predictions
        are processed in this order:

        1. Filter by confidence (``conf_threshold``).
        2. De-duplicate with class-agnostic NMS (``iou_threshold``).
        3. Grow each box by ``padding`` pixels and clip to
           ``[0, image_size]``.

        Every surviving detection is then drawn alone on a copy of the input
        image and saved:

        * deploy mode: side by side with the original image under
          ``.output/<panoid[:2]>/``, plus a JSON record per detection;
        * visualize mode: as ``<batch_nb>_<i>_<j>.jpg`` in the outputs dir.
        """
        preds = self.model.infer(batch)
        deploy = self.hparams.get("deploy", False)

        if self.hparams.get('visualize', False) or deploy:
            conf_threshold = self.hparams.get("conf_threshold", 0)
            iou_threshold = self.hparams.get("iou_threshold", 0.5)
            padding = self.hparams.get("padding", 10)
            image_size = self.hparams.get("image_size", 640)

            for i, (sample, pred) in enumerate(zip(batch, preds)):
                detections = self._filter_detections(
                    pred['instances'], conf_threshold, iou_threshold,
                    padding, image_size)
                if detections is None:
                    continue
                boxes, scores, class_ids = detections

                # Move dim 0 to the end, (C, H, W) -> (H, W, C) assumed, for
                # the Visualizer. Done once per sample, not per detection.
                image = sample['image'].to("cpu").permute(1, 2, 0).numpy()

                for j, (box, score, class_id) in enumerate(zip(boxes, scores, class_ids)):
                    img_box = self._draw_detection(image, box, score, class_id, image_size)
                    if deploy:
                        self._save_deploy_output(sample, j, img_box, box, score, class_id)
                    else:
                        # j is part of the name. Without it, every detection
                        # of sample i overwrote the same file.
                        img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}_{j}.jpg")

        self.evaluator.process(batch, preds)

    @staticmethod
    def _filter_detections(instances, conf_threshold, iou_threshold, padding, image_size):
        """Confidence-filter, NMS, pad and clip one image's predictions.

        Returns:
            ``(boxes, scores, class_ids)`` as CPU numpy arrays, or ``None``
            when no prediction scores above ``conf_threshold``.
        """
        boxes = instances.get('pred_boxes').tensor
        class_ids = instances.get('pred_classes')
        scores = instances.scores

        keep = scores > conf_threshold
        boxes, scores, class_ids = boxes[keep], scores[keep], class_ids[keep]
        if boxes.size(0) == 0:
            return None

        # torchvision's nms ignores classes: overlapping boxes of different
        # classes also suppress each other.
        keep = torchvision.ops.nms(boxes, scores, iou_threshold)
        # Index-based selection returns copies, so the in-place padding below
        # does not modify the model's output tensors.
        boxes, scores, class_ids = boxes[keep], scores[keep], class_ids[keep]

        boxes[:, :2] -= padding  # x1, y1
        boxes[:, 2:] += padding  # x2, y2
        boxes = torch.clip(boxes, 0, image_size)
        return boxes.to("cpu").numpy(), scores.to("cpu").numpy(), class_ids.to("cpu").numpy()

    @staticmethod
    def _draw_detection(image, box, score, class_id, image_size):
        """Render one detection on a fresh Visualizer canvas and return it as a PIL image."""
        single = Instances((image_size, image_size))
        single.pred_classes = np.array([class_id])
        single.pred_boxes = Boxes(box[np.newaxis, :])
        single.scores = np.array([score])
        visualizer = Visualizer(image, instance_mode=1, metadata=C.META)
        drawn = visualizer.draw_instance_predictions(single)
        return Image.fromarray(drawn.get_image())

    @staticmethod
    def _save_deploy_output(sample, j, img_box, box, score, class_id):
        """Save the original|annotated side-by-side image and a JSON record for detection ``j``."""
        panoid = sample['panoid']
        heading = sample['heading']
        save_path = f".output/{panoid[:2]}/{panoid}_{heading}_{j}.jpg"
        json_save_path = f".output/{panoid[:2]}/{panoid}_{heading}_{j}.json"

        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # The context manager closes the source image's file handle, which
        # was previously leaked once per detection.
        with Image.open(sample['save_path']) as img_org:
            img_out = get_concat_h_cut(img_org, img_box)
            img_out.save(save_path)

        data = {"panoid": panoid,
                # Legacy misspelled key, kept for existing consumers.
                "heaidng": int(heading),
                "heading": int(heading),
                "detection_id": int(j),
                "class_id": int(class_id),
                "box": [int(x) for x in box],
                "score": float(score),
                "save_path": save_path}
        with open(json_save_path, 'w') as fp:
            json.dump(data, fp)

    def test_epoch_end(self, outputs):
        """Evaluate the accumulated test predictions, report ``mAP`` to NNI as final result, and log the metrics."""
        metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
        nni.report_final_result(metrics['mAP'])
        self.log_dict(metrics)

    def configure_optimizers(self):
        """Use a single Adam optimizer over all parameters with ``learning_rate``."""
        return [torch.optim.Adam(self.parameters(), lr=self.hparams['learning_rate'])]

    def train_dataloader(self):
        """Shuffled training loader with random-brightness and random-flip augmentation."""
        dataset = get_dataset('train')
        return dataset.detection_dataloader(
            shuffle=True,
            augmentations=[
                T.RandomBrightness(0.9, 1.1),
                T.RandomFlip(prob=0.5),
            ],
            batch_size=self.hparams['batch_size'],
            num_workers=self.hparams.get("num_workers", 8))

    def val_dataloader(self):
        """Unshuffled validation loader with batch size 1."""
        dataset = get_dataset('valid')
        return dataset.detection_dataloader(
            shuffle=False,
            batch_size=1,
            num_workers=self.hparams.get("num_workers", 8))

    def test_dataloader(self):
        """Test loader.

        In deploy mode, the dataset's metadata is replaced by the downloaded
        rows of the CSV at ``deploy_meta_path``. Otherwise the
        ``test_split`` split (default "valid") is used.
        """
        num_workers = self.hparams.get("num_workers", 8)
        if self.hparams.get('deploy', False):
            # TODO: get_dataset(None) is only used for its loader machinery;
            # its metadata is discarded and replaced below. Consider a
            # constructor that takes the deploy CSV directly.
            dataset = get_dataset(None)
            df = pd.read_csv(self.hparams['deploy_meta_path'])
            # .copy() gives an independent frame, so the column assignments
            # below do not act on a filtered view (SettingWithCopyWarning).
            df = df.query("downloaded == True").copy()
            df["image_id"] = df['save_path']
            df["gsv_image_path"] = df['save_path']
            # Deployment images carry no labels. The string "[]" is presumably
            # parsed downstream as an empty annotation list; verify in data/.
            df['annotations'] = "[]"
            dataset._meta = df
            return dataset.detection_dataloader(
                shuffle=False,
                batch_size=self.hparams.get("test_batch_size", 1),
                num_workers=num_workers)

        test_split = self.hparams.get("test_split", "valid")
        dataset = get_dataset(test_split)
        return dataset.detection_dataloader(
            shuffle=False,
            batch_size=1,
            num_workers=num_workers)
|