diff --git a/detection/data/__init__.py b/detection/data/__init__.py
index c8349fc..ac935fe 100644
--- a/detection/data/__init__.py
+++ b/detection/data/__init__.py
@@ -18,4 +18,9 @@ def get_dataset(split="train"):
     # meta['image_path'] = [f'data/image/{panoid}_{heading}.jpg' for panoid, heading in zip(meta['panoid'], meta['heading'])]
-    return BaseDataset(info, meta)[split]
+    ds = BaseDataset(info, meta)
+    if split:
+        return ds[split]
+    else:
+        return ds
+
diff --git a/detection/data/base.py b/detection/data/base.py
index 261b086..3f7370b 100644
--- a/detection/data/base.py
+++ b/detection/data/base.py
@@ -28,6 +28,8 @@ class BaseDataset(Dataset,
         if not _is_path(file_path):
             return None
         image_pil = Image.open(file_path).convert('RGB')
+        # scale larger images to fit
+        # image_pil.thumbnail((640, 640), Image.Resampling.LANCZOS)
         image_np = np.array(image_pil)
         return image_np
 
diff --git a/detection/lightning/detection.py b/detection/lightning/detection.py
index 71875d3..0947b30 100644
--- a/detection/lightning/detection.py
+++ b/detection/lightning/detection.py
@@ -64,6 +64,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
         padding = self.hparams.get("padding", 10)
         if self.hparams.get('visualize', False) or self.hparams.get("deploy", False):
             for i, (sample, pred) in enumerate(zip(batch, preds)):
+                # print(i, sample, pred)
                 instances = pred['instances']
                 boxes = instances.get('pred_boxes').tensor
                 class_id = instances.get('pred_classes')
@@ -94,6 +95,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
                 boxes_nms = torch.clip(boxes_nms, 0, 640)
 
                 for j in range(len(scores_nms)):
+                    # print("- ", j)
                     instances = Instances((640, 640))
                     class_id_numpy = class_id_nms.to("cpu").numpy()[j]
                     box_numpy = boxes_nms.to("cpu").numpy()[j]
@@ -130,6 +132,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
                     with open(json_save_path, 'w') as fp:
                         json.dump(data, fp)
                 else:
+                    # print("save img", f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
                     img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
 
         self.evaluator.process(batch, preds)
@@ -162,8 +165,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
 
     def test_dataloader(self):
         if self.hparams.get('deploy', False):
-            dataset = load_dataset(self.hparams['dataset_name'])
-            df = pd.read_csv(self.hparams['deploy_meta_path']).query("downloaded == True")
+            dataset = get_dataset(None)  # load_dataset(self.hparams['deploy_meta_path'])
+            # TODO: apparently much of the loaded dataset is ignored;
+            # can we just remove/restructure the 'load_dataset' call?
+            df = pd.read_csv(self.hparams['deploy_meta_path'])
+            print(df)
+            df = df.query("downloaded == True")
             df["image_id"] = df['save_path']
             df["gsv_image_path"] = df['save_path']
             df['annotations'] = "[]"
diff --git a/detection/main.py b/detection/main.py
index 5bbaf97..4f5d88e 100644
--- a/detection/main.py
+++ b/detection/main.py
@@ -1,6 +1,7 @@
 import os
 import fire
 from pytorch_lightning import Trainer
+import torch
 
 from util.nni import run_nni
 from util import init_exp_folder, Args
@@ -110,6 +111,7 @@ def test(ckpt_path,
     trainer.test(task)
 
+
 def nni():
     run_nni(train, test)
diff --git a/requirements.txt b/requirements.txt
index 6fb90df..f68f999 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,5 +13,5 @@ pytorch-ignite
 scikit-learn==0.23.2
 seaborn==0.10.1
 segmentation-models-pytorch
-torch==1.8.1+cu102
-torchvision==0.9.1+cu102
+torch
+torchvision