Fix running with --deploy --deploy-meta-path

This commit is contained in:
Ruben van de Ven 2024-04-17 15:38:30 +02:00
parent cf0a761a73
commit d317689790
5 changed files with 21 additions and 5 deletions

View file

@ -18,4 +18,9 @@ def get_dataset(split="train"):
# meta['image_path'] = [f'data/image/{panoid}_{heading}.jpg' for panoid, heading in zip(meta['panoid'], meta['heading'])] # meta['image_path'] = [f'data/image/{panoid}_{heading}.jpg' for panoid, heading in zip(meta['panoid'], meta['heading'])]
return BaseDataset(info, meta)[split] ds = BaseDataset(info, meta)
if split:
return ds[split]
else:
return ds

View file

@ -28,6 +28,8 @@ class BaseDataset(Dataset,
if not _is_path(file_path): if not _is_path(file_path):
return None return None
image_pil = Image.open(file_path).convert('RGB') image_pil = Image.open(file_path).convert('RGB')
# scale larger images to fit
# image_pil.thumnail((640,640), Image.Resampling.LANCZOS)
image_np = np.array(image_pil) image_np = np.array(image_pil)
return image_np return image_np

View file

@ -64,6 +64,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
padding = self.hparams.get("padding", 10) padding = self.hparams.get("padding", 10)
if self.hparams.get('visualize', False) or self.hparams.get("deploy", False): if self.hparams.get('visualize', False) or self.hparams.get("deploy", False):
for i, (sample, pred) in enumerate(zip(batch, preds)): for i, (sample, pred) in enumerate(zip(batch, preds)):
# print(i, sample, pred)
instances = pred['instances'] instances = pred['instances']
boxes = instances.get('pred_boxes').tensor boxes = instances.get('pred_boxes').tensor
class_id = instances.get('pred_classes') class_id = instances.get('pred_classes')
@ -94,6 +95,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
boxes_nms = torch.clip(boxes_nms, 0, 640) boxes_nms = torch.clip(boxes_nms, 0, 640)
for j in range(len(scores_nms)): for j in range(len(scores_nms)):
# print("- ", j)
instances = Instances((640, 640)) instances = Instances((640, 640))
class_id_numpy = class_id_nms.to("cpu").numpy()[j] class_id_numpy = class_id_nms.to("cpu").numpy()[j]
box_numpy = boxes_nms.to("cpu").numpy()[j] box_numpy = boxes_nms.to("cpu").numpy()[j]
@ -130,6 +132,7 @@ class DetectionTask(pl.LightningModule, TFLogger):
with open(json_save_path, 'w') as fp: with open(json_save_path, 'w') as fp:
json.dump(data, fp) json.dump(data, fp)
else: else:
# print("save img", f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg") img_box.save(f"{self._output_save_dir}/{batch_nb}_{i}.jpg")
self.evaluator.process(batch, preds) self.evaluator.process(batch, preds)
@ -162,8 +165,12 @@ class DetectionTask(pl.LightningModule, TFLogger):
def test_dataloader(self): def test_dataloader(self):
if self.hparams.get('deploy', False): if self.hparams.get('deploy', False):
dataset = load_dataset(self.hparams['dataset_name']) dataset = get_dataset(None)#load_dataset(self.hparams['deploy_meta_path'])
df = pd.read_csv(self.hparams['deploy_meta_path']).query("downloaded == True") # TODO: apprently it does ignore much of the loaded dataset
# can we just remove/restructure the 'load_dataset' call
df = pd.read_csv(self.hparams['deploy_meta_path'])
print(df)
df = df.query("downloaded == True")
df["image_id"] = df['save_path'] df["image_id"] = df['save_path']
df["gsv_image_path"] = df['save_path'] df["gsv_image_path"] = df['save_path']
df['annotations'] = "[]" df['annotations'] = "[]"

View file

@ -1,6 +1,7 @@
import os import os
import fire import fire
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
import torch
from util.nni import run_nni from util.nni import run_nni
from util import init_exp_folder, Args from util import init_exp_folder, Args
@ -110,6 +111,7 @@ def test(ckpt_path,
trainer.test(task) trainer.test(task)
def nni(): def nni():
run_nni(train, test) run_nni(train, test)

View file

@ -13,5 +13,5 @@ pytorch-ignite
scikit-learn==0.23.2 scikit-learn==0.23.2
seaborn==0.10.1 seaborn==0.10.1
segmentation-models-pytorch segmentation-models-pytorch
torch==1.8.1+cu102 torch
torchvision==0.9.1+cu102 torchvision