80 lines
3.1 KiB
Python
80 lines
3.1 KiB
Python
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
|
||
|
from .info import DatasetInfoMixin
|
||
|
from . import constants as C
|
||
|
|
||
|
|
||
|
def trivial_batch_collator(batch):
    """Identity collate function: hand the batch back without merging it.

    Intended for use as a ``collate_fn``; the list of per-sample items is
    returned as the very same object that was passed in.
    """
    return batch
|
||
|
|
||
|
|
||
|
class DetectionMixin:
    """Mixin that adds a Detectron2 dataloader factory to a dataset object.

    The method below reads ``self.info``, ``self.meta`` and ``self.split``
    from the host object and forwards them to ``DetectionDataset``.
    """

    def detection_dataloader(self,
                             augmentations=None,
                             is_train=True,
                             use_instance_mask=False,
                             image_path_col=None,
                             **kwargs):
        """Build a dataloader that yields Detectron2 ``DatasetMapper`` outputs.

        Args:
            augmentations: Augmentations handed to ``DatasetMapper``; ``None``
                is replaced by an empty list.
            is_train: Forwarded to ``DatasetMapper``.
            use_instance_mask: Forwarded to ``DatasetMapper``.
            image_path_col: Name of the meta column holding image paths,
                forwarded to ``DetectionDataset``.
            **kwargs: Forwarded to ``DetectionDataset.dataloader``.

        Returns:
            Whatever ``DetectionDataset.dataloader(**kwargs)`` returns.
        """
        # Imported inside the method, so detectron2 is only needed when this
        # method is actually called.
        from detectron2.data import DatasetMapper

        mapper_options = dict(
            is_train=is_train,
            image_format="RGB",
            use_instance_mask=use_instance_mask,
            instance_mask_format="bitmask",
            augmentations=[] if augmentations is None else augmentations,
        )
        detection_mapper = DatasetMapper(**mapper_options)

        dataset = DetectionDataset(
            info=self.info,
            meta=self.meta,
            split=self.split,
            image_path_col=image_path_col,
            mapper=detection_mapper,
        )
        return dataset.dataloader(**kwargs)
|
||
|
|
||
|
|
||
|
class DetectionDataset(Dataset, DatasetInfoMixin):
    """
    Dataset class that provides standard Detectron2 model input format:
    https://detectron2.readthedocs.io/en/latest/tutorials/models.html?highlight=input%20format#model-input-format
    Notice the annotation column in the meta file needs to follow Detectron2's
    standard dataset dict format:
    https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html#standard-dataset-dicts
    """

    def __init__(self, info, meta, mapper, split=None, image_path_col=None):
        """Validate the meta table and prepare it for the mapper.

        Args:
            info: Dataset info object, passed through to ``DatasetInfoMixin``.
            meta: Table of samples (uses ``.columns`` and ``.rename``, so
                presumably a pandas DataFrame). Must contain the
                ``C.ANNOTATION_COLUMN`` column.
            mapper: Callable applied to each sample dict in ``__getitem__``.
            split: Optional split, passed through to ``DatasetInfoMixin``.
            image_path_col: Name of the column holding image paths. When
                ``None``, the single column whose name ends with
                ``"image_path"`` is used.

        Raises:
            ValueError: If the annotation column is missing, if the image
                path column cannot be determined unambiguously, or if an
                explicitly given ``image_path_col`` is not a column of
                ``meta``.
        """
        if C.ANNOTATION_COLUMN not in meta.columns:
            raise ValueError(f"[{C.ANNOTATION_COLUMN}] column not found in the meta data.")

        if image_path_col is None:
            # Non-string column labels cannot match, and would otherwise make
            # ``endswith`` raise AttributeError.
            image_path_cols = [
                c for c in meta.columns
                if isinstance(c, str) and c.endswith("image_path")]
            if not image_path_cols:
                raise ValueError(
                    "No image path column found in the meta data. Please check meta data and use `image_path_col` argument to specify the column.")
            if len(image_path_cols) > 1:
                raise ValueError(
                    "Multiple image path columns found in the meta data. Please use `image_path_col` argument to specify the column.")
            image_path_col = image_path_cols[0]
        elif image_path_col not in meta.columns:
            # Fail fast: ``rename`` below silently ignores unknown labels,
            # which would only surface later as a missing "file_name" key.
            raise ValueError(f"[{image_path_col}] column not found in the meta data.")

        # Expose the image path under "file_name" -- presumably the key the
        # Detectron2 mapper reads the image location from (see the dataset
        # dict format linked in the class docstring).
        meta = meta.rename(columns={image_path_col: "file_name"})

        self.mapper = mapper

        DatasetInfoMixin.__init__(self,
                                  info=info,
                                  meta=meta,
                                  split=split)

    def __getitem__(self, index):
        """Return the mapper output for the sample at position ``index``.

        The row is converted to a dict, its annotation entry is parsed when it
        is stored as text, and the dict is passed to ``self.mapper``.
        """
        # ``self._meta`` is not assigned in this class; presumably it is set
        # by ``DatasetInfoMixin.__init__`` -- confirm against that mixin.
        sample = self._meta.iloc[index].to_dict()

        annotations = sample[C.ANNOTATION_COLUMN]
        if isinstance(annotations, (str, bytes)):
            # NOTE(security): ``eval`` executes arbitrary Python found in the
            # meta data. Only load meta files from trusted sources; consider
            # ``ast.literal_eval`` if annotations are always plain literals.
            annotations = eval(annotations)
        # Already-parsed annotations (e.g. a list of dicts) are used as-is
        # instead of crashing inside ``eval``.
        sample[C.ANNOTATION_COLUMN] = annotations

        return self.mapper(sample)

    def dataloader(self, **kwargs):
        """Wrap this dataset in a ``torch.utils.data.DataLoader``.

        Args:
            **kwargs: Forwarded to ``DataLoader``. ``collate_fn`` defaults to
                ``trivial_batch_collator`` but may be overridden by the
                caller.

        Returns:
            The constructed ``DataLoader``.
        """
        # setdefault avoids a duplicate-keyword TypeError when the caller
        # supplies their own ``collate_fn``.
        kwargs.setdefault("collate_fn", trivial_batch_collator)
        return torch.utils.data.DataLoader(self, **kwargs)
|