From e5b276bcf91a1d3694d276db159efcbdb1d885ea Mon Sep 17 00:00:00 2001 From: Patrick Esser Date: Fri, 22 Jul 2022 09:50:39 +0000 Subject: [PATCH] give me that global step --- scripts/printckpt.py | 16 ++++++++++++++++ scripts/prune-ckpt.py | 4 +++- 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 scripts/printckpt.py diff --git a/scripts/printckpt.py b/scripts/printckpt.py new file mode 100644 index 0000000..b9824d7 --- /dev/null +++ b/scripts/printckpt.py @@ -0,0 +1,16 @@ +import os +import torch +import fire + + +def printit(p): + print(f"printin' in path: {p}") + size_initial = os.path.getsize(p) + nsd = dict() + sd = torch.load(p, map_location="cpu") + if "global_step" in sd: + print(f"This is global step {sd['global_step']}.") + + +if __name__ == "__main__": + fire.Fire(printit) diff --git a/scripts/prune-ckpt.py b/scripts/prune-ckpt.py index 26d237b..e79c137 100644 --- a/scripts/prune-ckpt.py +++ b/scripts/prune-ckpt.py @@ -14,6 +14,8 @@ def prune_it(p): nsd[k] = sd[k] else: print(f"removing optimizer states for path {p}") + if "global_step" in sd: + print(f"This is global step {sd['global_step']}.") fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" print(f"saving pruned checkpoint at: {fn}") torch.save(nsd, fn) @@ -24,4 +26,4 @@ def prune_it(p): if __name__ == "__main__": fire.Fire(prune_it) - print("done.") \ No newline at end of file + print("done.")