import glob import math import os import random import time from collections import OrderedDict import cv2 import numpy as np import torch from torch.utils.data import Dataset from utils.utils import xyxy2xywh class LoadImages: # for inference def __init__(self, path, img_size=(1088, 608)): if os.path.isdir(path): image_format = ['.jpg', '.jpeg', '.png', '.tif'] self.files = sorted(glob.glob('%s/*.*' % path)) self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in image_format, self.files)) elif os.path.isfile(path): self.files = [path] self.nF = len(self.files) # number of image files self.width = img_size[0] self.height = img_size[1] self.count = 0 assert self.nF > 0, 'No images found in ' + path def __iter__(self): self.count = -1 return self def __next__(self): self.count += 1 if self.count == self.nF: raise StopIteration img_path = self.files[self.count] # Read image img0 = cv2.imread(img_path) # BGR assert img0 is not None, 'Failed to load ' + img_path # Padded resize img, _, _, _ = letterbox(img0, height=self.height, width=self.width) # Normalize RGB img = img[:, :, ::-1].transpose(2, 0, 1) img = np.ascontiguousarray(img, dtype=np.float32) img /= 255.0 # cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image return img_path, img, img0 def __getitem__(self, idx): idx = idx % self.nF img_path = self.files[idx] # Read image img0 = cv2.imread(img_path) # BGR assert img0 is not None, 'Failed to load ' + img_path # Padded resize img, _, _, _ = letterbox(img0, height=self.height, width=self.width) # Normalize RGB img = img[:, :, ::-1].transpose(2, 0, 1) img = np.ascontiguousarray(img, dtype=np.float32) img /= 255.0 return img_path, img, img0 def __len__(self): return self.nF # number of files class LoadImagesAndLabels: # for training def __init__(self, path, img_size=(1088,608), augment=False, transforms=None): with open(path, 'r') as file: self.img_files = file.readlines() self.img_files = [x.replace('\n', '') for x in self.img_files] self.img_files = list(filter(lambda x: len(x) > 0, self.img_files)) self.label_files = [x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt') for x in self.img_files] self.nF = len(self.img_files) # number of image files self.width = img_size[0] self.height = img_size[1] self.augment = augment self.transforms = transforms def __getitem__(self, files_index): img_path = self.img_files[files_index] label_path = self.label_files[files_index] return self.get_data(img_path, label_path) def get_data(self, img_path, label_path): height = self.height width = self.width img = cv2.imread(img_path) # BGR if img is None: raise ValueError('File corrupt {}'.format(img_path)) augment_hsv = True if self.augment and augment_hsv: # SV augmentation by 50% fraction = 0.50 img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) S = img_hsv[:, :, 1].astype(np.float32) V = img_hsv[:, :, 2].astype(np.float32) a = (random.random() * 2 - 1) * fraction + 1 S *= a if a > 1: np.clip(S, a_min=0, a_max=255, out=S) a = (random.random() * 2 - 1) * fraction + 1 V *= a if a > 1: np.clip(V, a_min=0, a_max=255, out=V) img_hsv[:, :, 1] = S.astype(np.uint8) img_hsv[:, :, 2] = V.astype(np.uint8) cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) h, w, _ = img.shape img, ratio, padw, padh = letterbox(img, height=height, width=width) # Load labels if os.path.isfile(label_path): labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6) # Normalized xywh to pixel xyxy format labels = labels0.copy() labels[:, 2] = ratio * w * (labels0[:, 2] - labels0[:, 4] / 2) + padw labels[:, 3] = ratio * h * (labels0[:, 3] - labels0[:, 5] / 2) + padh labels[:, 4] = ratio * w * (labels0[:, 2] + labels0[:, 4] / 2) + padw labels[:, 5] = ratio * h * (labels0[:, 3] + labels0[:, 5] / 2) + padh else: labels = np.array([]) # Augment image and labels if self.augment: img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.50, 1.20)) plotFlag = False if plotFlag: import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt plt.figure(figsize=(50, 50)) plt.imshow(img[:, :, ::-1]) plt.plot(labels[:, [1, 3, 3, 1, 1]].T, labels[:, [2, 2, 4, 4, 2]].T, '.-') plt.axis('off') plt.savefig('test.jpg') time.sleep(10) nL = len(labels) if nL > 0: # convert xyxy to xywh labels[:, 2:6] = xyxy2xywh(labels[:, 2:6].copy()) #/ height labels[:, 2] /= width labels[:, 3] /= height labels[:, 4] /= width labels[:, 5] /= height if self.augment: # random left-right flip lr_flip = True if lr_flip & (random.random() > 0.5): img = np.fliplr(img) if nL > 0: labels[:, 2] = 1 - labels[:, 2] img = np.ascontiguousarray(img[ :, :, ::-1]) # BGR to RGB if self.transforms is not None: img = self.transforms(img) return img, labels, img_path, (h, w) def __len__(self): return self.nF # number of batches def letterbox(img, height=608, width=1088, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded rectangular shape = img.shape[:2] # shape = [height, width] ratio = min(float(height)/shape[0], float(width)/shape[1]) new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # new_shape = [width, height] dw = (width - new_shape[0]) / 2 # width padding dh = (height - new_shape[1]) / 2 # height padding top, bottom = round(dh - 0.1), round(dh + 0.1) left, right = round(dw - 0.1), round(dw + 0.1) img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded rectangular return img, ratio, dw, dh def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2), borderValue=(127.5, 127.5, 127.5)): # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4 border = 0 # width of added border (optional) height = img.shape[0] width = img.shape[1] # Rotation and Scale R = np.eye(3) a = random.random() * (degrees[1] - degrees[0]) + degrees[0] # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations s = random.random() * (scale[1] - scale[0]) + scale[0] R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) # Translation T = np.eye(3) T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border # x translation (pixels) T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border # y translation (pixels) # Shear S = np.eye(3) S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg) S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg) M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!! imw = cv2.warpPerspective(img, M, dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=borderValue) # BGR order borderValue # Return warped points also if targets is not None: if len(targets) > 0: n = targets.shape[0] points = targets[:, 2:6].copy() area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1]) # warp points xy = np.ones((n * 4, 3)) xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 xy = (xy @ M.T)[:, :2].reshape(n, 8) # create new boxes x = xy[:, [0, 2, 4, 6]] y = xy[:, [1, 3, 5, 7]] xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T # apply angle-based reduction radians = a * math.pi / 180 reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5 x = (xy[:, 2] + xy[:, 0]) / 2 y = (xy[:, 3] + xy[:, 1]) / 2 w = (xy[:, 2] - xy[:, 0]) * reduction h = (xy[:, 3] - xy[:, 1]) * reduction xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T # reject warped points outside of image np.clip(xy[:, 0], 0, width, out=xy[:, 0]) np.clip(xy[:, 2], 0, width, out=xy[:, 2]) np.clip(xy[:, 1], 0, height, out=xy[:, 1]) np.clip(xy[:, 3], 0, height, out=xy[:, 3]) w = xy[:, 2] - xy[:, 0] h = xy[:, 3] - xy[:, 1] area = w * h ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10) targets = targets[i] targets[:, 2:6] = xy[i] return imw, targets, M else: return imw def collate_fn(batch): imgs, labels, paths, sizes = zip(*batch) batch_size = len(labels) imgs = torch.stack(imgs, 0) max_box_len = max([l.shape[0] for l in labels]) labels = [torch.from_numpy(l) for l in labels] filled_labels = torch.zeros(batch_size, max_box_len, 6) labels_len = torch.zeros(batch_size) for i in range(batch_size): isize = labels[i].shape[0] if len(labels[i])>0: filled_labels[i, :isize, :] = labels[i] labels_len[i] = isize return imgs, filled_labels, paths, sizes, labels_len.unsqueeze(1) class JointDataset(LoadImagesAndLabels): # for training def __init__(self, paths, img_size=(1088,608), augment=False, transforms=None): dataset_names = paths.keys() self.img_files = OrderedDict() self.label_files = OrderedDict() self.tid_num = OrderedDict() self.tid_start_index = OrderedDict() for ds, path in paths.items(): with open(path, 'r') as file: self.img_files[ds] = file.readlines() self.img_files[ds] = [x.strip() for x in self.img_files[ds]] self.img_files[ds] = list(filter(lambda x: len(x) > 0, self.img_files[ds])) self.label_files[ds] = [x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt') for x in self.img_files[ds]] for ds, label_paths in self.label_files.items(): max_index = -1 for lp in label_paths: lb = np.loadtxt(lp) if len(lb) < 1: continue if len(lb.shape) < 2: img_max = lb[1] else: img_max = np.max(lb[:,1]) if img_max >max_index: max_index = img_max self.tid_num[ds] = max_index + 1 last_index = 0 for i, (k, v) in enumerate(self.tid_num.items()): self.tid_start_index[k] = last_index last_index += v self.nID = int(last_index+1) self.nds = [len(x) for x in self.img_files.values()] self.cds = [sum(self.nds[:i]) for i in range(len(self.nds))] self.nF = sum(self.nds) self.width = img_size[0] self.height = img_size[1] self.augment = augment self.transforms = transforms print('='*80) print('dataset summary') print(self.tid_num) print('total # identities:', self.nID) print('start index') print(self.tid_start_index) print('='*80) def __getitem__(self, files_index): for i, c in enumerate(self.cds): if files_index >= c: ds = list(self.label_files.keys())[i] start_index = c img_path = self.img_files[ds][files_index - start_index] label_path = self.label_files[ds][files_index - start_index] imgs, labels, img_path, (h, w) = self.get_data(img_path, label_path) for i, _ in enumerate(labels): if labels[i,1] > -1: labels[i,1] += self.tid_start_index[ds] return imgs, labels, img_path, (h, w)