# stable-diffusion-finetune/scripts/printckpt.py
#
# Source-listing metadata: 19 lines, 464 B, Python.

import os
import torch
import fire
def printit(p):
    """Print bookkeeping info stored in a checkpoint file.

    Always prints the path. Then, when present, prints the checkpoint's
    top-level ``global_step`` value and the ``model_ema.num_updates`` entry
    of its nested ``state_dict``. Missing keys are skipped silently.

    Args:
        p: Path to a file readable by ``torch.load``.
    """
    print(f"printin' in path: {p}")
    # NOTE(review): torch.load unpickles the file; only point this script at
    # trusted checkpoints.
    # map_location="cpu" lets a GPU-saved checkpoint be inspected without a GPU.
    sd = torch.load(p, map_location="cpu")
    if "global_step" in sd:
        print(f"This is global step {sd['global_step']}.")
    # A bare weights file may have no nested "state_dict"; skip instead of
    # raising KeyError.
    state_dict = sd.get("state_dict")
    if isinstance(state_dict, dict) and "model_ema.num_updates" in state_dict:
        num_updates = state_dict["model_ema.num_updates"]
        # The counter may be saved as a one-element tensor; show the plain
        # number rather than "tensor(...)".
        if isinstance(num_updates, torch.Tensor) and num_updates.numel() == 1:
            num_updates = num_updates.item()
        print(f"And we got {num_updates} EMA updates.")
if __name__ == "__main__":
    # Hand printit to python-fire so its parameter becomes the CLI argument,
    # e.g. `python printckpt.py path/to/model.ckpt`.
    fire.Fire(component=printit)