give me that global step

This commit is contained in:
Patrick Esser 2022-07-22 09:50:39 +00:00
parent d9c9747122
commit e5b276bcf9
2 changed files with 19 additions and 1 deletions

16
scripts/printckpt.py Normal file
View file

@ -0,0 +1,16 @@
import os
import torch
import fire
def printit(p):
print(f"printin' in path: {p}")
size_initial = os.path.getsize(p)
nsd = dict()
sd = torch.load(p, map_location="cpu")
if "global_step" in sd:
print(f"This is global step {sd['global_step']}.")
if __name__ == "__main__":
fire.Fire(printit)

View file

@ -14,6 +14,8 @@ def prune_it(p):
nsd[k] = sd[k]
else:
print(f"removing optimizer states for path {p}")
if "global_step" in sd:
print(f"This is global step {sd['global_step']}.")
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt"
print(f"saving pruned checkpoint at: {fn}")
torch.save(nsd, fn)