disable ckpt on error in debug mode

This commit is contained in:
Patrick Esser 2022-06-11 18:34:28 -04:00
parent 34f9f3867e
commit d06c2277b0

10
main.py
View file

@ -238,7 +238,8 @@ class DataModuleFromConfig(pl.LightningDataModule):
class SetupCallback(Callback): class SetupCallback(Callback):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): def __init__(self, resume, now, logdir, ckptdir, cfgdir, config,
lightning_config, debug):
super().__init__() super().__init__()
self.resume = resume self.resume = resume
self.now = now self.now = now
@ -247,9 +248,10 @@ class SetupCallback(Callback):
self.cfgdir = cfgdir self.cfgdir = cfgdir
self.config = config self.config = config
self.lightning_config = lightning_config self.lightning_config = lightning_config
self.debug = debug
def on_keyboard_interrupt(self, trainer, pl_module): def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0: if not self.debug and trainer.global_rank == 0:
print("Summoning checkpoint.") print("Summoning checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt") ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path) trainer.save_checkpoint(ckpt_path)
@ -702,6 +704,7 @@ if __name__ == "__main__":
"cfgdir": cfgdir, "cfgdir": cfgdir,
"config": config, "config": config,
"lightning_config": lightning_config, "lightning_config": lightning_config,
"debug": opt.debug,
} }
}, },
"image_logger": { "image_logger": {
@ -822,7 +825,8 @@ if __name__ == "__main__":
try: try:
trainer.fit(model, data) trainer.fit(model, data)
except Exception: except Exception:
melk() if not opt.debug:
melk()
raise raise
if not opt.no_test and not trainer.interrupted: if not opt.no_test and not trainer.interrupted:
trainer.test(model, data) trainer.test(model, data)