surveilling-surveillance/detection/data/info.py
2021-05-20 13:22:04 -07:00

178 lines
5.1 KiB
Python

import yaml
import dataclasses
import pandas as pd
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union
from .version import Version
class BaseInfo:
    """Mixin giving dataclass subclasses a lenient ``from_dict`` constructor."""

    @classmethod
    def from_dict(cls, dataset_info_dict: dict) -> "BaseInfo":
        """Build an instance of ``cls`` from a mapping.

        Keys that are not dataclass fields of ``cls`` are silently dropped, so
        mappings carrying extra entries (e.g. parsed YAML) are accepted. Fields
        absent from the mapping take their dataclass defaults (a field without
        a default would make the constructor raise ``TypeError``).

        Args:
            dataset_info_dict: Mapping of field name to value.

        Returns:
            A new instance of ``cls`` -- the subclass this was called on, not
            necessarily ``DatasetInfo``.

        Raises:
            TypeError: If ``cls`` is not a dataclass (``dataclasses.fields``).
        """
        field_names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{key: value
                      for key, value in dataset_info_dict.items()
                      if key in field_names})
@dataclass
class ImageSourceInfo(BaseInfo):
    """Metadata for a single image source of a dataset.

    Every field has a default, so ``ImageSourceInfo()`` and partial
    ``from_dict`` mappings are valid; "required" fields simply start out
    empty/zero until populated.
    """

    # --- required fields ---
    name: str = ""
    height: int = 0
    width: int = 0
    date: str = ""
    # --- optional fields ---
    channels: Optional[list] = None
    # Annotated Optional (a YAML null is accepted) but defaults to "".
    resolution: Optional[str] = ""
@dataclass
class DatasetInfo(BaseInfo):
    """Dataset-level description that round-trips through YAML.

    ``version`` may be passed as a ``Version``, a string (``Version(str)``),
    or anything ``Version.from_dict`` accepts. ``sources`` may be passed as
    ``ImageSourceInfo`` instances and/or dicts, or as a single
    ``ImageSourceInfo``. Both are normalized in ``__post_init__``.
    """

    name: str = field(default_factory=str)
    description: str = field(default_factory=str)
    author: str = field(default_factory=str)
    version: Union[str, Version] = field(default_factory=Version)
    date: str = field(default_factory=str)
    task: List[str] = field(default_factory=list)
    class_names: List[str] = field(default_factory=list)
    # Must default to an empty *list*. The previous default_factory built a
    # single ImageSourceInfo, which is not iterable and made DatasetInfo()
    # raise TypeError in __post_init__.
    sources: List[ImageSourceInfo] = field(default_factory=list)

    def __post_init__(self):
        """Coerce ``version`` to ``Version`` and ``sources`` to ImageSourceInfo list.

        Raises:
            ValueError: If a source is neither an ImageSourceInfo nor a dict.
        """
        if self.version is not None and not isinstance(self.version, Version):
            if isinstance(self.version, str):
                self.version = Version(self.version)
            else:
                self.version = Version.from_dict(self.version)

        if isinstance(self.sources, ImageSourceInfo):
            # Convenience: a lone source is stored as a one-element list.
            self.sources = [self.sources]
        # Only rebuild when something needs converting; an already-clean
        # sequence is kept as the same object.
        if self.sources is not None and not all(
                isinstance(s, ImageSourceInfo) for s in self.sources):
            self.sources = [self._coerce_source(s) for s in self.sources]

    @staticmethod
    def _coerce_source(source):
        """Return *source* as an ImageSourceInfo, converting a dict if needed."""
        if isinstance(source, ImageSourceInfo):
            return source
        if isinstance(source, dict):
            return ImageSourceInfo.from_dict(source)
        raise ValueError(
            f"Unknown type for ImageSourceInfo: {type(source)}")

    @classmethod
    def load(cls, path):
        """Read a DatasetInfo from the YAML file at *path*.

        Parsing uses ``yaml.SafeLoader``; unknown top-level keys are ignored
        by ``from_dict``.
        """
        with open(path, "r", encoding="utf-8") as f:
            yaml_dict = yaml.load(f, Loader=yaml.SafeLoader)
        return cls.from_dict(yaml_dict)

    def save(self, path):
        """Write this info as YAML to the file at *path* (overwriting it)."""
        with open(path, "w", encoding="utf-8") as f:
            self.dump(f)

    def dump(self, fileobj):
        """Write this info as YAML to an already-open text stream.

        NOTE(review): ``asdict`` only recurses into dataclasses; if ``Version``
        is not one, yaml.dump emits a Python-specific tag that SafeLoader in
        ``load`` cannot read back -- confirm against Version's definition.
        """
        yaml.dump(asdict(self), fileobj)
class DatasetInfoMixin:
    """Pair a dataset's info object with a per-sample metadata DataFrame.

    Offers split selection, row subsetting (``slice``/``query``/``filter``),
    and read-only accessors for the info fields. The subsetting methods work
    on a deep copy and never modify ``self``.
    """

    def __init__(self,
                 info: "DatasetInfo",
                 meta: pd.DataFrame,
                 split: Optional[str] = None):
        """
        Args:
            info: Dataset-level description; read by the ``name``, ``version``,
                ``description``, ``author`` and ``sources`` properties.
            meta: One row per sample. Needs a ``split`` column when ``split``
                is given (and for ``splits`` / ``get_split``).
            split: If given and not ``"all"``, keep only rows whose ``split``
                column equals it. The caller's ``meta`` is left unmodified.
        """
        self._info = info
        self._meta = meta
        self._split = split
        # Stored by set_format/reset_format; not read in this class --
        # presumably consumed by subclasses (verify).
        self._format = None
        if self._split is not None and self._split != 'all':
            # Boolean mask rather than ``query(..., inplace=True)``: the
            # in-place query silently filtered the caller's DataFrame, and
            # the interpolated query string broke on quotes / non-str values.
            # .copy() keeps later column writes free of SettingWithCopy noise.
            self._meta = self._meta[self._meta["split"] == self._split].copy()

    def __len__(self):
        """Number of samples, i.e. rows of the metadata frame."""
        return len(self._meta)

    def __repr__(self):
        """Summarize split, version, feature columns and sample count.

        More than four feature columns are abbreviated to the first three,
        "...", and the last one.
        """
        features = self.features
        if len(features) < 5:
            shown = features
        else:
            shown = features[:3] + ["...", features[-1]]
        features_repr = "[" + ", ".join(shown) + "]"
        return (f"{type(self).__name__}(split: {self.split}, "
                f"version: {self.version}, "
                f"features[{len(features)}]: {features_repr}, "
                f"samples: {self.__len__()})")

    def get_split(self, split):
        """Return the subset whose ``split`` column equals *split*.

        ``"all"`` returns ``self`` unchanged (not a copy).

        Raises:
            ValueError: If *split* does not occur in the current metadata.
        """
        if split == "all":
            return self
        available = self.splits
        if split not in available:
            raise ValueError(
                f"Unknown split {split}. Split has to be one of {list(available.keys())}")
        result = deepcopy(self)
        # Mask instead of an interpolated query string (see __init__).
        result._meta = result._meta[result._meta["split"] == split]
        result._split = split
        return result

    def slice(self, expr):
        """Return a deep copy keeping the rows selected positionally by ``iloc[expr]``."""
        result = deepcopy(self)
        result._meta = result._meta.iloc[expr]
        return result

    def query(self, expr):
        """Return a deep copy keeping the rows matching pandas query string *expr*."""
        result = deepcopy(self)
        result._meta = result._meta.query(expr)
        return result

    def filter(self, func):
        """Return a deep copy keeping rows for which ``func(row)`` is truthy.

        *func* receives each row (via ``DataFrame.apply(axis=1)``). The
        result's index is renumbered from 0.
        """
        result = deepcopy(self)
        keep = result._meta.apply(func, axis=1)
        # drop=True: without it the old index became an extra "index"
        # column (polluting ``features``), and a third chained filter
        # raised "cannot insert level_0, already exists".
        result._meta = result._meta[keep].reset_index(drop=True)
        return result

    def set_format(self, columns: Optional[Union[dict, list]]):
        """Store an output format spec; ``None`` clears it."""
        self._format = columns

    def reset_format(self):
        """Clear any format set with ``set_format``."""
        self.set_format(None)

    def value_counts(self, value):
        """Return ``{value: count}`` for metadata column *value*."""
        return self._meta[value].value_counts().to_dict()

    @property
    def info(self):
        """The underlying info object (not a copy)."""
        return self._info

    @property
    def meta(self):
        """A copy of the metadata DataFrame, safe for callers to mutate."""
        return self._meta.copy()

    @property
    def name(self):
        return self._info.name

    @property
    def version(self):
        return self._info.version

    @property
    def description(self):
        return self._info.description

    @property
    def author(self):
        return self._info.author

    @property
    def sources(self):
        """Names of the info's image sources."""
        return [s.name for s in self._info.sources]

    @property
    def split(self):
        """The active split name; ``"all"`` when none was selected."""
        if self._split is None:
            return "all"
        return self._split

    @property
    def splits(self):
        """Sample count per value of the ``split`` column."""
        return self.value_counts("split")

    @property
    def features(self):
        """Metadata column names, in order."""
        return list(self._meta.columns)