surveilling-surveillance/detection/lightning/detection.py

182 lines
7.5 KiB
Python
Raw Normal View History

2021-05-20 22:20:48 +02:00
import nni
import pickle as pkl
import json
import pytorch_lightning as pl
import os
import numpy as np
import torch
import torchvision
from PIL import Image
import pandas as pd
from detectron2.data import transforms as T
from detectron2.structures import Instances, Boxes
from detectron2.utils.visualizer import Visualizer
from torch.utils.data import DataLoader
from ignite.metrics import Accuracy
from models import get_model
from eval import DetectionEvaluator
from data import get_dataset
from util import constants as C
from util import get_concat_h_cut
from .logger import TFLogger
class DetectionTask(pl.LightningModule, TFLogger):
"""Standard interface for the trainer to interact with the model."""
def __init__(self, params):
super().__init__()
self.save_hyperparameters(params)
self.model = get_model(params)
self.evaluator = DetectionEvaluator()
self._plot_save_dir = params['save_dir'] + f"/{params['exp_name']}/plots"
self._output_save_dir = params['save_dir'] + f"/{params['exp_name']}/outputs"
if not os.path.exists(self._plot_save_dir):
os.mkdir(self._plot_save_dir)
if not os.path.exists(self._output_save_dir):
os.mkdir(self._output_save_dir)
2021-05-20 22:20:48 +02:00
def training_step(self, batch, batch_nb):
losses = self.model.forward(batch)
loss = torch.stack(list(losses.values())).mean()
return loss
def validation_step(self, batch, batch_nb):
losses = self.model.forward(batch)
loss = torch.stack(list(losses.values())).mean()
preds = self.model.infer(batch)
self.evaluator.process(batch, preds)
return loss
def validation_epoch_end(self, outputs):
avg_loss = torch.stack(outputs).mean()
self.log("val_loss", avg_loss)
metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
2021-05-20 22:20:48 +02:00
nni.report_intermediate_result(metrics['mAP'])
self.evaluator.reset()
self.log_dict(metrics, prog_bar=True)
def test_step(self, batch, batch_nb):
preds = self.model.infer(batch)
conf_threshold = self.hparams.get("conf_threshold", 0)
iou_threshold = self.hparams.get("iou_threshold", 0.5)
padding = self.hparams.get("padding", 10)
if self.hparams.get('visualize', False) or self.hparams.get("deploy", False):
for i, (sample, pred) in enumerate(zip(batch, preds)):
instances = pred['instances']
boxes = instances.get('pred_boxes').tensor
class_id = instances.get('pred_classes')
# Filter by scores
scores = instances.scores
keep_id_conf = scores > conf_threshold
boxes_conf = boxes[keep_id_conf]
scores_conf = scores[keep_id_conf]
class_id_conf = class_id[keep_id_conf]
if boxes_conf.size(0) == 0:
continue
# Filter by nms
keep_id_nms = torchvision.ops.nms(boxes_conf,
scores_conf,
iou_threshold)
boxes_nms = boxes_conf[keep_id_nms]
scores_nms = scores_conf[keep_id_nms]
class_id_nms = class_id_conf[keep_id_nms]
# Pad box size
boxes_nms[:, 0] -= padding
boxes_nms[:, 1] -= padding
boxes_nms[:, 2] += padding
boxes_nms[:, 3] += padding
boxes_nms = torch.clip(boxes_nms, 0, 640)
for j in range(len(scores_nms)):
instances = Instances((640, 640))
class_id_numpy = class_id_nms.to("cpu").numpy()[j]
box_numpy = boxes_nms.to("cpu").numpy()[j]
score_numpy = scores_nms.to("cpu").numpy()[j]
instances.pred_classes = np.array([class_id_numpy])
instances.pred_boxes = Boxes(box_numpy[np.newaxis,:])
instances.scores = np.array([score_numpy])
v = Visualizer(np.transpose(sample['image'].to("cpu"), (1,2,0)),
instance_mode=1,
metadata=C.META)
out = v.draw_instance_predictions(instances)
img_box = Image.fromarray(out.get_image())
if self.hparams.get("deploy", False):
panoid = sample['panoid']
heading = sample['heading']
save_path = f".output/{panoid[:2]}/{panoid}_{heading}_{j}.jpg"
json_save_path = f".output/{panoid[:2]}/{panoid}_{heading}_{j}.json"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
img_org = Image.open(sample['save_path'])
img_out = get_concat_h_cut(img_org, img_box)
img_out.save(save_path)
data = {"panoid": panoid,
"heaidng": int(heading),
"detection_id": int(j),
"class_id": int(class_id_numpy),
"box": [int(x) for x in box_numpy],
"score": float(score_numpy),
"save_path": save_path}
with open(json_save_path, 'w') as fp:
json.dump(data, fp)
else:
img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
2021-05-20 22:20:48 +02:00
self.evaluator.process(batch, preds)
def test_epoch_end(self, outputs):
metrics = self.evaluator.evaluate(save_dir=self._plot_save_dir)
2021-05-20 22:20:48 +02:00
nni.report_final_result(metrics['mAP'])
self.log_dict(metrics)
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=self.hparams['learning_rate'])]
def train_dataloader(self):
dataset = get_dataset('train')
return dataset.detection_dataloader(
shuffle=True,
augmentations=[
T.RandomBrightness(0.9, 1.1),
T.RandomFlip(prob=0.5),
],
batch_size=self.hparams['batch_size'],
num_workers=8)
def val_dataloader(self):
dataset = get_dataset('valid')
return dataset.detection_dataloader(
shuffle=False,
batch_size=1,
num_workers=8)
def test_dataloader(self):
if self.hparams.get('deploy', False):
dataset = load_dataset(self.hparams['dataset_name'])
df = pd.read_csv(self.hparams['deploy_meta_path']).query("downloaded == True")
df["image_id"] = df['save_path']
df["gsv_image_path"] = df['save_path']
df['annotations'] = "[]"
dataset._meta = df
return dataset.detection_dataloader(
shuffle=False,
batch_size=self.hparams.get("test_batch_size", 1),
num_workers=8)
else:
test_split = self.hparams.get("test_split", "valid")
dataset = get_dataset(test_split)
return dataset.detection_dataloader(
shuffle=False,
batch_size=1,
num_workers=8)