surveilling-surveillance/detection/data/detection.py

80 lines
3.1 KiB
Python
Raw Normal View History

2021-05-20 22:20:48 +02:00
import torch
from torch.utils.data import Dataset
from .info import DatasetInfoMixin
from . import constants as C
def trivial_batch_collator(batch):
    """Return the batch exactly as received.

    Intended as a ``collate_fn``: the list of samples is handed back
    untouched instead of being stacked or merged.
    """
    collated = batch
    return collated
class DetectionMixin:
    """Mixin adding a Detectron2-style dataloader factory.

    The host class is expected to expose ``info``, ``meta`` and ``split``
    attributes; they are forwarded unchanged to ``DetectionDataset``.
    """

    def detection_dataloader(self,
                             augmentations=None,
                             is_train=True,
                             use_instance_mask=False,
                             image_path_col=None,
                             **kwargs):
        """Create a dataloader over a ``DetectionDataset`` for this object.

        Args:
            augmentations: Passed to ``DatasetMapper``; ``None`` is replaced
                by an empty list.
            is_train: Passed to ``DatasetMapper``.
            use_instance_mask: Passed to ``DatasetMapper``.
            image_path_col: Passed to ``DetectionDataset``.
            **kwargs: Passed to ``DetectionDataset.dataloader``.

        Returns:
            Whatever ``DetectionDataset.dataloader(**kwargs)`` returns.
        """
        # Imported lazily so that detectron2 is only needed when this
        # method is actually called.
        from detectron2.data import DatasetMapper

        aug_list = [] if augmentations is None else augmentations
        mapper_options = dict(
            is_train=is_train,
            image_format="RGB",
            use_instance_mask=use_instance_mask,
            instance_mask_format="bitmask",
            augmentations=aug_list,
        )
        sample_mapper = DatasetMapper(**mapper_options)

        dataset = DetectionDataset(
            info=self.info,
            meta=self.meta,
            split=self.split,
            image_path_col=image_path_col,
            mapper=sample_mapper,
        )
        return dataset.dataloader(**kwargs)
class DetectionDataset(Dataset, DatasetInfoMixin):
    """
    Dataset class that provides standard Detectron2 model input format:
    https://detectron2.readthedocs.io/en/latest/tutorials/models.html?highlight=input%20format#model-input-format
    Notice the annotation column in the meta file needs to follow Detectron2's
    standard dataset dict format:
    https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html#standard-dataset-dicts
    """

    def __init__(self, info, meta, mapper, split=None, image_path_col=None):
        """Validate the metadata and set up the dataset.

        Args:
            info: Forwarded to ``DatasetInfoMixin.__init__``.
            meta: Table exposing ``.columns`` and ``.rename(columns=...)``
                (presumably a pandas DataFrame -- confirm against callers).
                Must contain the ``C.ANNOTATION_COLUMN`` column.
            mapper: Callable applied to every sample dict in
                ``__getitem__``.
            split: Forwarded to ``DatasetInfoMixin.__init__``.
            image_path_col: Name of the column holding image paths. When
                ``None``, the single column whose name ends with
                ``"image_path"`` is used.

        Raises:
            ValueError: If the annotation column is missing, or if
                ``image_path_col`` is ``None`` and zero or several
                ``*image_path`` columns exist.
        """
        if C.ANNOTATION_COLUMN not in meta.columns:
            raise ValueError(f"[{C.ANNOTATION_COLUMN}] column not found in the meta data.")
        if image_path_col is None:
            # Only string labels can match the suffix; the isinstance guard
            # keeps non-string column labels from raising AttributeError.
            image_path_cols = [
                c for c in meta.columns
                if isinstance(c, str) and c.endswith("image_path")]
            if not image_path_cols:
                raise ValueError(
                    "No image path column found in the meta data. Please check meta data and use `image_path_col` argument to specify the column.")
            if len(image_path_cols) > 1:
                raise ValueError(
                    "Multiple image path columns found in the meta data. Please use `image_path_col` argument to specify the column.")
            image_path_col = image_path_cols[0]
        # The image path column is exposed to the mapper as "file_name".
        meta = meta.rename(columns={image_path_col: "file_name"})
        self.mapper = mapper
        DatasetInfoMixin.__init__(self,
                                  info=info,
                                  meta=meta,
                                  split=split)

    def __getitem__(self, index):
        """Return the mapped sample at positional ``index``.

        The metadata row is converted to a dict, its annotation entry is
        decoded when necessary, and the dict is passed through
        ``self.mapper``.
        """
        # NOTE(review): ``self._meta`` is presumably set by
        # DatasetInfoMixin.__init__ -- confirm.
        sample = self._meta.iloc[index].to_dict()
        annotations = sample[C.ANNOTATION_COLUMN]
        # Annotations that are already a list/tuple (e.g. metadata built in
        # memory) are passed through as-is; feeding them to eval() would
        # raise TypeError. The mapper then receives the stored object itself.
        if not isinstance(annotations, (list, tuple)):
            # SECURITY: eval() executes arbitrary Python contained in the
            # metadata. Only load metadata from trusted sources. If the
            # serialized annotations are plain literals, consider switching
            # to ast.literal_eval.
            annotations = eval(annotations)
        sample[C.ANNOTATION_COLUMN] = annotations
        return self.mapper(sample)

    def dataloader(self, **kwargs):
        """Wrap this dataset in a ``torch.utils.data.DataLoader``.

        ``collate_fn`` is fixed to ``trivial_batch_collator``, so it must
        not be passed in ``kwargs``; all other keyword arguments are
        forwarded to ``DataLoader``.
        """
        return torch.utils.data.DataLoader(
            self,
            collate_fn=trivial_batch_collator,
            **kwargs)