2022-09-23 08:20:41 +00:00
|
|
|
import torch
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
def build_ema_only_checkpoint(checkpoint: dict) -> dict:
    """Return a slimmed checkpoint containing a filtered ``state_dict``.

    Every entry of ``checkpoint["state_dict"]`` whose key starts with
    ``"model.diffusion_model"`` is dropped; all other entries are kept
    unchanged (values are shared, not copied). Top-level keys other than
    ``"state_dict"`` are discarded, so the result is always
    ``{"state_dict": {...}}``. The input mapping is not mutated.

    NOTE(review): the stated goal is "only ema weights"; this assumes the
    ``model.diffusion_model*`` entries are the non-EMA copy and that the EMA
    copy is stored under a different prefix — confirm for your checkpoint
    format.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` mapping of name -> value.

    Returns:
        A new dict of the form ``{"state_dict": filtered_state_dict}``.

    Raises:
        KeyError: if ``checkpoint`` has no ``"state_dict"`` entry.
    """
    state_dict = checkpoint["state_dict"]
    ema_state = {
        k: v
        for k, v in state_dict.items()
        if not k.startswith("model.diffusion_model")
    }
    return {"state_dict": ema_state}


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options (reads ``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(
        description="Make a version of the checkpoint with only ema weights (around 4GB)"
    )
    # Both paths are mandatory: without required=True a missing flag became
    # None and only failed later inside torch.load / torch.save.
    parser.add_argument(
        "--original_ckpt", required=True, help="full size checkpoint file"
    )
    parser.add_argument(
        "--output_path", required=True, help="filename for ema only checkpoint"
    )
    return parser.parse_args(argv)


def main(argv=None) -> None:
    """Load the full checkpoint, strip it with build_ema_only_checkpoint, save it."""
    args = parse_args(argv)

    print(f"loading from {args.original_ckpt}")
    # SECURITY: torch.load unpickles arbitrary objects — only run this on
    # checkpoint files from a trusted source.
    d = torch.load(args.original_ckpt, map_location="cpu")

    new_d = build_ema_only_checkpoint(d)
    # Release the full checkpoint so the filtered-out entries can be freed
    # before serialization (the kept values are still referenced by new_d).
    del d

    print(f"saving to {args.output_path}")
    torch.save(new_d, args.output_path)


if __name__ == "__main__":
    main()