Use Poetry; make compatible with a newer Lightning version

Ruben van de Ven 2024-02-29 14:38:18 +01:00
parent bbb9a14792
commit ebf52bfa07
5 changed files with 3849 additions and 9 deletions

.gitignore (vendored): 2 additions

@@ -139,3 +139,5 @@ notebook/
 # Images
 *.jpg
+sandbox/
+lightning_logs/

(file name not shown): logger and checkpoint-callback helpers

@@ -2,7 +2,8 @@
 import json
 import os
 from os.path import join
-from pytorch_lightning.loggers.test_tube import TestTubeLogger
+# from pytorch_lightning.loggers.test_tube import TestTubeLogger
+from pytorch_lightning.loggers import CSVLogger
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
@@ -12,12 +13,13 @@ def get_ckpt_dir(save_path, exp_name):
 def get_ckpt_callback(save_path, exp_name, monitor="val_loss", mode="min"):
     ckpt_dir = os.path.join(save_path, exp_name, "ckpts")
-    return ModelCheckpoint(filepath=ckpt_dir,
+    return ModelCheckpoint(dirpath=ckpt_dir,
                            save_top_k=1,
                            verbose=True,
                            monitor=monitor,
                            mode=mode,
-                           prefix='')
+                           # prefix=''
+                           )

 def get_early_stop_callback(patience=10):
@@ -29,6 +31,6 @@ def get_early_stop_callback(patience=10):
 def get_logger(save_path, exp_name):
     exp_dir = os.path.join(save_path, exp_name)
-    return TestTubeLogger(save_dir=exp_dir,
+    return CSVLogger(save_dir=exp_dir,
                           name='lightning_logs',
                           version="0")

(file name not shown): training entry point

@@ -20,7 +20,8 @@ def train(save_dir=C.SANDBOX_PATH,
           gpus=1,
           pretrained=True,
           batch_size=8,
-          accelerator="ddp",
+          accelerator="gpu",
+          strategy="ddp",
           gradient_clip_val=0.5,
           max_epochs=100,
           learning_rate=1e-5,
@@ -36,13 +37,14 @@ def train(save_dir=C.SANDBOX_PATH,
     Args:
         save_dir: Path to save the checkpoints and logs
         exp_name: Name of the experiment
-        model: Model name
+        model: Model name ("mask_rcnn", "faster_rcnn", "retinanet", "rpn", "fast_rcnn"; see detection/models/detection/detectron.py)
         gpus: int. (ie: 2 gpus)
             OR list to specify which GPUs [0, 1] OR '0,1'
             OR '-1' / -1 to use all available gpus
         pretrained: Whether or not to use the pretrained model
         num_classes: Number of classes
-        accelerator: Distributed computing mode
+        accelerator: Accelerator type to use (cpu, gpu, tpu, ipu, hpu, mps, auto)
+        strategy: Training strategy, given as a built-in alias or a custom strategy (e.g. "ddp")
         gradient_clip_val: Clip value of gradient norm
         limit_train_batches: Proportion of training data to use
         max_epochs: Max number of epochs
@@ -63,15 +65,16 @@ def train(save_dir=C.SANDBOX_PATH,
     task = get_task(args)
     trainer = Trainer(gpus=gpus,
                       accelerator=accelerator,
+                      strategy=strategy,
                       logger=get_logger(save_dir, exp_name),
                       callbacks=[get_early_stop_callback(patience),
                                  get_ckpt_callback(save_dir, exp_name, monitor="mAP", mode="max")],
-                      weights_save_path=os.path.join(save_dir, exp_name),
+                      default_root_dir=os.path.join(save_dir, exp_name),
                       gradient_clip_val=gradient_clip_val,
                       limit_train_batches=limit_train_batches,
                       limit_val_batches=limit_val_batches,
                       limit_test_batches=limit_test_batches,
-                      weights_summary=weights_summary,
+                      # weights_summary=weights_summary,
                       max_epochs=max_epochs)
     trainer.fit(task)
     return save_dir, exp_name
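A minimal sketch of what the new Trainer call amounts to under pytorch-lightning ~1.8, with placeholder values; `task` stands in for the LightningModule returned by the repo's get_task(args):

from pytorch_lightning import Trainer

# From 1.x onward the hardware ("gpu") and the distribution scheme ("ddp")
# are separate arguments; the old accelerator="ddp" spelling is invalid.
trainer = Trainer(gpus=1,               # deprecated spelling, still accepted in 1.8
                  accelerator="gpu",
                  strategy="ddp",
                  default_root_dir="sandbox/demo",  # stands in for the removed weights_save_path
                  gradient_clip_val=0.5,
                  max_epochs=100)
# trainer.fit(task)  # task = get_task(args) elsewhere in the repo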

poetry.lock (generated, new file): 3781 additions

File diff suppressed because it is too large.

pyproject.toml (new file): 52 additions

@@ -0,0 +1,52 @@
[tool.poetry]
name = "Surveilling-surveillance"
version = "0.1.0"
description = ""
authors = ["Stanford Computational Policy Lab"]
readme = "README.md"

[tool.poetry.dependencies]
# python can be ^3.7 if we describe the right torch dependencies
#python = "~3.9"
# this is basically the same, but != 3.9.7 because of dep 'streamlit'
python = ">=3.9,<3.9.7 || >3.9.7,<3.10"
# last version supporting np.object
numpy = "1.23.4"
#numpy = "^1.27.2"
torch = [
    { version="1.12.1" },
    { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp39-cp39-linux_x86_64.whl", markers = "python_version ~= '3.9' and sys_platform == 'linux'" }
]
torchvision = [
    { version="0.13.1" },
    { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp39-cp39-linux_x86_64.whl", markers = "python_version ~= '3.9' and sys_platform == 'linux'" }
]
scikit-image = "^0.20.0"
#PyYAML = "5.2"
#scipy = "^1.10.0"
#timm = "0.1.20"
tensorboardx = "1.9"
scikit-learn = "^1.3.0"
#pytorch-lightning = "1.1.4"
# work around https://lightning.ai/forums/t/attribute-error-no-sync-params/2258
pytorch-lightning = "~1.8"
test-tube = "0.7.1"
tqdm = ">=4.36.1"
pretrainedmodels = "^0.7.4"
fire = "^0.5.0"
streamlit = "^1.25.0"
albumentations = "^1.3.1"
imgaug = "^0.4.0"
pytorch-ignite = "^0.4.12"
seaborn = "^0.12.2"
segmentation-models-pytorch = "^0.3.3"
osmnx = "^1.6.0"
geopy = "^2.3.0"
coloredlogs = "^15.0.1"
nni = "^3.0"
pillow = "~9.5"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
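After `poetry install`, a quick sanity check can confirm that the CUDA 11.3 wheels and the Lightning pin resolved as intended; this is a hypothetical snippet, not part of the repo:

# check_env.py (hypothetical file name); run with: poetry run python check_env.py
import torch
import torchvision
import pytorch_lightning as pl

print(torch.__version__)         # expected 1.12.1+cu113 on Linux, 1.12.1 elsewhere
print(torchvision.__version__)   # expected 0.13.1+cu113 on Linux
print(pl.__version__)            # expected a 1.8.x release
print(torch.cuda.is_available()) # True if the cu113 wheel matches the driver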