From a133096ddf7fd81485775298e149489915d755a1 Mon Sep 17 00:00:00 2001
From: rromb
Date: Thu, 9 Jun 2022 10:56:41 +0200
Subject: [PATCH] coco

---
 ldm/data/coco.py | 248 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 248 insertions(+)
 create mode 100644 ldm/data/coco.py

diff --git a/ldm/data/coco.py b/ldm/data/coco.py
new file mode 100644
index 0000000..7654069
--- /dev/null
+++ b/ldm/data/coco.py
@@ -0,0 +1,248 @@
+import os
+import json
+import albumentations
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from abc import abstractmethod
+
+
+class CocoBase(Dataset):
+    """needed for (image, caption, segmentation) pairs"""
+    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
+                 crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True, crop_type=None):
+        self.split = self.get_split()
+        self.size = size
+        if crop_size is None:
+            self.crop_size = size
+        else:
+            self.crop_size = crop_size
+
+        assert crop_type in [None, 'random', 'center']
+        self.crop_type = crop_type
+        self.use_segmentation = use_segmentation
+        self.onehot = onehot_segmentation  # return segmentation as rgb or one-hot
+        self.stuffthing = use_stuffthing  # include thing in segmentation
+        if self.onehot and not self.stuffthing:
+            raise NotImplementedError("One-hot mode is only supported for the "
+                                      "stuffthings version because labels are stored "
+                                      "a bit differently.")
+
+        data_json = datajson
+        with open(data_json) as json_file:
+            self.json_data = json.load(json_file)
+            self.img_id_to_captions = dict()
+            self.img_id_to_filepath = dict()
+            self.img_id_to_segmentation_filepath = dict()
+
+        assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
+                                            f"captions_val{self.year()}.json"]
+        # TODO currently hardcoded paths, would be better to follow logic in
+        # cocostuff pixelmaps
+        if self.use_segmentation:
+            if self.stuffthing:
+                self.segmentation_prefix = (
+                    f"data/cocostuffthings/val{self.year()}" if
+                    data_json.endswith(f"captions_val{self.year()}.json") else
+                    f"data/cocostuffthings/train{self.year()}")
+            else:
+                self.segmentation_prefix = (
+                    f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
+                    data_json.endswith(f"captions_val{self.year()}.json") else
+                    f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")
+
+        imagedirs = self.json_data["images"]
+        self.labels = {"image_ids": list()}
+        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
+            self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
+            self.img_id_to_captions[imgdir["id"]] = list()
+            pngfilename = imgdir["file_name"].replace("jpg", "png")
+            if self.use_segmentation:
+                self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
+                    self.segmentation_prefix, pngfilename)
+            # if a list of filenames is given, restrict the dataset to those files
+            if given_files is not None:
+                if pngfilename in given_files:
+                    self.labels["image_ids"].append(imgdir["id"])
+            else:
+                self.labels["image_ids"].append(imgdir["id"])
+
+        capdirs = self.json_data["annotations"]
+        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
+            # there are on average 5 captions per image
+            self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])
+
+        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+        if self.split == "validation":
+            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+        else:
+            # default option for train is random crop
+            if self.crop_type in [None, 'random']:
+                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+            else:
+                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+        self.preprocessor = albumentations.Compose(
+            [self.rescaler, self.cropper],
+            additional_targets={"segmentation": "image"})
+        if force_no_crop:
+            self.rescaler = albumentations.Resize(height=self.size, width=self.size)
+            self.preprocessor = albumentations.Compose(
+                [self.rescaler],
+                additional_targets={"segmentation": "image"})
+
+    @abstractmethod
+    def year(self):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_split(self):
+        raise NotImplementedError()
+
+    def __len__(self):
+        return len(self.labels["image_ids"])
+
+    def preprocess_image(self, image_path, segmentation_path=None):
+        image = Image.open(image_path)
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+        image = np.array(image).astype(np.uint8)
+        if segmentation_path:
+            segmentation = Image.open(segmentation_path)
+            if not self.onehot and not segmentation.mode == "RGB":
+                segmentation = segmentation.convert("RGB")
+            segmentation = np.array(segmentation).astype(np.uint8)
+            if self.onehot:
+                assert self.stuffthing
+                # stored in caffe format: unlabeled==255. stuff and thing from
+                # 0-181. to be compatible with the labels in
+                # https://github.com/nightrome/cocostuff/blob/master/labels.txt
+                # we shift stuffthing one to the right and put unlabeled in zero.
+                # as long as segmentation is uint8, shifting to the right handles
+                # the latter too
+                assert segmentation.dtype == np.uint8
+                segmentation = segmentation + 1
+
+            processed = self.preprocessor(image=image, segmentation=segmentation)
+
+            image, segmentation = processed["image"], processed["segmentation"]
+        else:
+            image = self.preprocessor(image=image)["image"]
+
+        image = (image / 127.5 - 1.0).astype(np.float32)
+        if segmentation_path:
+            if self.onehot:
+                assert segmentation.dtype == np.uint8
+                # make it one-hot
+                n_labels = 183
+                flatseg = np.ravel(segmentation)
+                onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
+                onehot[np.arange(flatseg.size), flatseg] = True
+                onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
+                segmentation = onehot
+            else:
+                segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
+            return image, segmentation
+        else:
+            return image
+
+    def __getitem__(self, i):
+        img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
+        if self.use_segmentation:
+            seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
+            image, segmentation = self.preprocess_image(img_path, seg_path)
+        else:
+            image = self.preprocess_image(img_path)
+        captions = self.img_id_to_captions[self.labels["image_ids"][i]]
+        # randomly draw one of all available captions per image
+        caption = captions[np.random.randint(0, len(captions))]
+        example = {"image": image,
+                   "caption": caption,
+                   "img_path": img_path,
+                   "filename_": img_path.split(os.sep)[-1]
+                   }
+        if self.use_segmentation:
+            example.update({"seg_path": seg_path, "segmentation": segmentation})
+        return example
+
+
+class CocoImagesAndCaptionsTrain2017(CocoBase):
+    """returns a pair of (image, caption)"""
+    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
+        super().__init__(size=size,
+                         dataroot="data/coco/train2017",
+                         datajson="data/coco/annotations/captions_train2017.json",
+                         onehot_segmentation=onehot_segmentation,
+                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
+
+    def get_split(self):
+        return "train"
+
+    def year(self):
+        return '2017'
+
+
+class CocoImagesAndCaptionsValidation2017(CocoBase):
+    """returns a pair of (image, caption)"""
+    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+                 given_files=None):
+        super().__init__(size=size,
+                         dataroot="data/coco/val2017",
+                         datajson="data/coco/annotations/captions_val2017.json",
+                         onehot_segmentation=onehot_segmentation,
+                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+                         given_files=given_files)
+
+    def get_split(self):
+        return "validation"
+
+    def year(self):
+        return '2017'
+
+
+class CocoImagesAndCaptionsTrain2014(CocoBase):
+    """returns a pair of (image, caption)"""
+    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+                 crop_type='random'):
+        super().__init__(size=size,
+                         dataroot="data/coco/train2014",
+                         datajson="data/coco/annotations2014/annotations/captions_train2014.json",
+                         onehot_segmentation=onehot_segmentation,
+                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+                         use_segmentation=False,
+                         crop_type=crop_type)
+
+    def get_split(self):
+        return "train"
+
+    def year(self):
+        return '2014'
+
+
+class CocoImagesAndCaptionsValidation2014(CocoBase):
+    """returns a pair of (image, caption)"""
+    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+                 given_files=None, crop_type='center', **kwargs):
+        super().__init__(size=size,
+                         dataroot="data/coco/val2014",
+                         datajson="data/coco/annotations2014/annotations/captions_val2014.json",
+                         onehot_segmentation=onehot_segmentation,
+                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+                         given_files=given_files,
+                         use_segmentation=False,
+                         crop_type=crop_type)
+
+    def get_split(self):
+        return "validation"
+
+    def year(self):
+        return '2014'
+
+
+if __name__ == '__main__':
+    d2 = CocoImagesAndCaptionsValidation2014(size=256)
+    print("constructed val set.")
+    print(f"length of val split: {len(d2)}")
+
+    ex2 = d2[0]
+    print(ex2["image"].shape)
+    print(ex2["caption"].__class__.__name__)
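
A minimal usage sketch for review (not part of the patch): it assumes the hardcoded
data/coco/train2017 images and captions_train2017.json annotations from above are in
place, and that PyTorch's default collate is used, which stacks the returned HWC
float32 numpy arrays into a (B, H, W, C) tensor and collects the caption strings
into a list.

    from torch.utils.data import DataLoader
    from ldm.data.coco import CocoImagesAndCaptionsTrain2017

    # smoke test; relies on the data/coco paths hardcoded in CocoBase
    train_ds = CocoImagesAndCaptionsTrain2017(size=256)
    loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch["image"].shape)  # torch.Size([4, 256, 256, 3]), values in [-1, 1]
    print(batch["caption"][0])   # one caption string per sample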