Use poetry, make compatible newer lightning version
This commit is contained in:
parent
bbb9a14792
commit
ebf52bfa07
5 changed files with 3849 additions and 9 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -139,3 +139,5 @@ notebook/
|
|||
|
||||
# Images
|
||||
*.jpg
|
||||
sandbox/
|
||||
lightning_logs/
|
|
@ -2,7 +2,8 @@
|
|||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||
# from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||
from pytorch_lightning.loggers import CSVLogger
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
|
||||
|
||||
|
@ -12,12 +13,13 @@ def get_ckpt_dir(save_path, exp_name):
|
|||
|
||||
def get_ckpt_callback(save_path, exp_name, monitor="val_loss", mode="min"):
|
||||
ckpt_dir = os.path.join(save_path, exp_name, "ckpts")
|
||||
return ModelCheckpoint(filepath=ckpt_dir,
|
||||
return ModelCheckpoint(dirpath=ckpt_dir,
|
||||
save_top_k=1,
|
||||
verbose=True,
|
||||
monitor=monitor,
|
||||
mode=mode,
|
||||
prefix='')
|
||||
# prefix=''
|
||||
)
|
||||
|
||||
|
||||
def get_early_stop_callback(patience=10):
|
||||
|
@ -29,6 +31,6 @@ def get_early_stop_callback(patience=10):
|
|||
|
||||
def get_logger(save_path, exp_name):
|
||||
exp_dir = os.path.join(save_path, exp_name)
|
||||
return TestTubeLogger(save_dir=exp_dir,
|
||||
return CSVLogger(save_dir=exp_dir,
|
||||
name='lightning_logs',
|
||||
version="0")
|
||||
|
|
|
@ -20,7 +20,8 @@ def train(save_dir=C.SANDBOX_PATH,
|
|||
gpus=1,
|
||||
pretrained=True,
|
||||
batch_size=8,
|
||||
accelerator="ddp",
|
||||
accelerator="gpu",
|
||||
strategy="ddp",
|
||||
gradient_clip_val=0.5,
|
||||
max_epochs=100,
|
||||
learning_rate=1e-5,
|
||||
|
@ -36,13 +37,14 @@ def train(save_dir=C.SANDBOX_PATH,
|
|||
Args:
|
||||
save_dir: Path to save the checkpoints and logs
|
||||
exp_name: Name of the experiment
|
||||
model: Model name
|
||||
model: Model name ("mask_rcnn","faster_rcnn","retinanet","rpn","fast_rcnn", see 'detection/models/detection/detectron.py')
|
||||
gpus: int. (ie: 2 gpus)
|
||||
OR list to specify which GPUs [0, 1] OR '0,1'
|
||||
OR '-1' / -1 to use all available gpus
|
||||
pretrained: Whether or not to use the pretrained model
|
||||
num_classes: Number of classes
|
||||
accelerator: Distributed computing mode
|
||||
accelerator: Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”)
|
||||
strategy: Supports different training strategies with aliases as well custom strategies (e.g. "ddp")
|
||||
gradient_clip_val: Clip value of gradient norm
|
||||
limit_train_batches: Proportion of training data to use
|
||||
max_epochs: Max number of epochs
|
||||
|
@ -63,15 +65,16 @@ def train(save_dir=C.SANDBOX_PATH,
|
|||
task = get_task(args)
|
||||
trainer = Trainer(gpus=gpus,
|
||||
accelerator=accelerator,
|
||||
strategy=strategy,
|
||||
logger=get_logger(save_dir, exp_name),
|
||||
callbacks=[get_early_stop_callback(patience),
|
||||
get_ckpt_callback(save_dir, exp_name, monitor="mAP", mode="max")],
|
||||
weights_save_path=os.path.join(save_dir, exp_name),
|
||||
default_root_dir=os.path.join(save_dir, exp_name),
|
||||
gradient_clip_val=gradient_clip_val,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=limit_val_batches,
|
||||
limit_test_batches=limit_test_batches,
|
||||
weights_summary=weights_summary,
|
||||
# weights_summary=weights_summary,
|
||||
max_epochs=max_epochs)
|
||||
trainer.fit(task)
|
||||
return save_dir, exp_name
|
||||
|
|
3781
poetry.lock
generated
Normal file
3781
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
52
pyproject.toml
Normal file
52
pyproject.toml
Normal file
|
@ -0,0 +1,52 @@
|
|||
[tool.poetry]
|
||||
name = "Surveilling-surveillance"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Stanford Computational Policy Lab"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
# python can be ^3.7 if we describe the right toch dependencies
|
||||
#python = "~3.9"
|
||||
#this is basically the same, but != 3.9.7 because of dep 'streamlit'
|
||||
python = ">=3.9,<3.9.7 || >3.9.7,<3.10"
|
||||
# last version supporting np.object
|
||||
numpy = "1.23.4"
|
||||
#numpy = "^1.27.2"
|
||||
torch = [
|
||||
{ version="1.12.1" },
|
||||
{ url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp39-cp39-linux_x86_64.whl", markers = "python_version ~= '3.9' and sys_platform == 'linux'" }
|
||||
]
|
||||
torchvision = [
|
||||
{ version="0.13.1" },
|
||||
{ url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp39-cp39-linux_x86_64.whl", markers = "python_version ~= '3.9' and sys_platform == 'linux'" }
|
||||
]
|
||||
scikit-image = "^0.20.0"
|
||||
|
||||
#PyYAML = "5.2"
|
||||
#scipy = "^1.10.0"
|
||||
#timm = "0.1.20"
|
||||
tensorboardx = "1.9"
|
||||
scikit-learn = "^1.3.0"
|
||||
#pytorch-lightning = "1.1.4"
|
||||
# work around https://lightning.ai/forums/t/attribute-error-no-sync-params/2258
|
||||
pytorch-lightning = "~1.8"
|
||||
test-tube = "0.7.1"
|
||||
tqdm = ">=4.36.1"
|
||||
pretrainedmodels = "^0.7.4"
|
||||
fire = "^0.5.0"
|
||||
streamlit = "^1.25.0"
|
||||
albumentations = "^1.3.1"
|
||||
imgaug = "^0.4.0"
|
||||
pytorch-ignite = "^0.4.12"
|
||||
seaborn = "^0.12.2"
|
||||
segmentation-models-pytorch = "^0.3.3"
|
||||
osmnx = "^1.6.0"
|
||||
geopy = "^2.3.0"
|
||||
coloredlogs = "^15.0.1"
|
||||
nni = "^3.0"
|
||||
pillow="~9.5"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
Loading…
Reference in a new issue