diff --git a/scripts/printckpt.py b/scripts/printckpt.py new file mode 100644 index 0000000..b9824d7 --- /dev/null +++ b/scripts/printckpt.py @@ -0,0 +1,16 @@ +import os +import torch +import fire + + +def printit(p): + print(f"printin' in path: {p}") + size_initial = os.path.getsize(p) + nsd = dict() + sd = torch.load(p, map_location="cpu") + if "global_step" in sd: + print(f"This is global step {sd['global_step']}.") + + +if __name__ == "__main__": + fire.Fire(printit) diff --git a/scripts/prune-ckpt.py b/scripts/prune-ckpt.py index 26d237b..e79c137 100644 --- a/scripts/prune-ckpt.py +++ b/scripts/prune-ckpt.py @@ -14,6 +14,8 @@ def prune_it(p): nsd[k] = sd[k] else: print(f"removing optimizer states for path {p}") + if "global_step" in sd: + print(f"This is global step {sd['global_step']}.") fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" print(f"saving pruned checkpoint at: {fn}") torch.save(nsd, fn) @@ -24,4 +26,4 @@ def prune_it(p): if __name__ == "__main__": fire.Fire(prune_it) - print("done.") \ No newline at end of file + print("done.")