wait until checkpoint is finished

This commit is contained in:
Patrick Esser 2022-07-13 07:29:43 +00:00 committed by pesser
parent 0948a3f89c
commit 7f8c423450

View file

@ -148,21 +148,33 @@ class Sampler(object):
class Checker(object): class Checker(object):
def __init__(self, ckpt_path, callback, interval=60): def __init__(self, ckpt_path, callback, wait_for_file=5, interval=60):
self._cached_stamp = 0 self._cached_stamp = 0
self.filename = ckpt_path self.filename = ckpt_path
self.callback = callback self.callback = callback
self.interval = interval self.interval = interval
self.wait_for_file = wait_for_file
def check(self): def check(self):
while True: while True:
stamp = os.stat(self.filename).st_mtime stamp = os.stat(self.filename).st_mtime
if stamp != self._cached_stamp: if stamp != self._cached_stamp:
while True:
# try to wait until checkpoint is fully written
previous_stamp = stamp
time.sleep(self.wait_for_file)
stamp = os.stat(self.filename).st_mtime
if stamp != previous_stamp:
print(f"File is still changing. Waiting {self.wait_for_file} seconds.")
else:
break
self._cached_stamp = stamp self._cached_stamp = stamp
# file has changed, so do something... # file has changed, so do something...
print(f"{self.__class__.__name__}: Detected a new file at " print(f"{self.__class__.__name__}: Detected a new file at "
f"{self.filename}, calling back.") f"{self.filename}, calling back.")
self.callback() self.callback()
else: else:
time.sleep(self.interval) time.sleep(self.interval)
@ -173,6 +185,7 @@ def run(prompts_path="scripts/prompts/prompts-with-wings.txt",
W=None, W=None,
C=4, C=4,
F=8, F=8,
wait_for_file=5,
interval=60): interval=60):
if out_dir is None: if out_dir is None:
@ -197,7 +210,7 @@ def run(prompts_path="scripts/prompts/prompts-with-wings.txt",
shape = [C, H//F, W//F] shape = [C, H//F, W//F]
sampler = Sampler(out_dir, ckpt_path, cfg_path, prompts_path, shape=shape) sampler = Sampler(out_dir, ckpt_path, cfg_path, prompts_path, shape=shape)
checker = Checker(ckpt_path, sampler, interval=interval) checker = Checker(ckpt_path, sampler, wait_for_file=wait_for_file, interval=interval)
checker.check() checker.check()