wait until checkpoint is finished
This commit is contained in:
parent
0948a3f89c
commit
7f8c423450
1 changed files with 15 additions and 2 deletions
|
@ -148,21 +148,33 @@ class Sampler(object):
|
||||||
|
|
||||||
|
|
||||||
class Checker(object):
|
class Checker(object):
|
||||||
def __init__(self, ckpt_path, callback, interval=60):
|
def __init__(self, ckpt_path, callback, wait_for_file=5, interval=60):
|
||||||
self._cached_stamp = 0
|
self._cached_stamp = 0
|
||||||
self.filename = ckpt_path
|
self.filename = ckpt_path
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
|
self.wait_for_file = wait_for_file
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
while True:
|
while True:
|
||||||
stamp = os.stat(self.filename).st_mtime
|
stamp = os.stat(self.filename).st_mtime
|
||||||
if stamp != self._cached_stamp:
|
if stamp != self._cached_stamp:
|
||||||
|
while True:
|
||||||
|
# try to wait until checkpoint is fully written
|
||||||
|
previous_stamp = stamp
|
||||||
|
time.sleep(self.wait_for_file)
|
||||||
|
stamp = os.stat(self.filename).st_mtime
|
||||||
|
if stamp != previous_stamp:
|
||||||
|
print(f"File is still changing. Waiting {self.wait_for_file} seconds.")
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
self._cached_stamp = stamp
|
self._cached_stamp = stamp
|
||||||
# file has changed, so do something...
|
# file has changed, so do something...
|
||||||
print(f"{self.__class__.__name__}: Detected a new file at "
|
print(f"{self.__class__.__name__}: Detected a new file at "
|
||||||
f"{self.filename}, calling back.")
|
f"{self.filename}, calling back.")
|
||||||
self.callback()
|
self.callback()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
time.sleep(self.interval)
|
time.sleep(self.interval)
|
||||||
|
|
||||||
|
@ -173,6 +185,7 @@ def run(prompts_path="scripts/prompts/prompts-with-wings.txt",
|
||||||
W=None,
|
W=None,
|
||||||
C=4,
|
C=4,
|
||||||
F=8,
|
F=8,
|
||||||
|
wait_for_file=5,
|
||||||
interval=60):
|
interval=60):
|
||||||
|
|
||||||
if out_dir is None:
|
if out_dir is None:
|
||||||
|
@ -197,7 +210,7 @@ def run(prompts_path="scripts/prompts/prompts-with-wings.txt",
|
||||||
shape = [C, H//F, W//F]
|
shape = [C, H//F, W//F]
|
||||||
sampler = Sampler(out_dir, ckpt_path, cfg_path, prompts_path, shape=shape)
|
sampler = Sampler(out_dir, ckpt_path, cfg_path, prompts_path, shape=shape)
|
||||||
|
|
||||||
checker = Checker(ckpt_path, sampler, interval=interval)
|
checker = Checker(ckpt_path, sampler, wait_for_file=wait_for_file, interval=interval)
|
||||||
checker.check()
|
checker.check()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue