Extract just the execution code for publishing
A minimal set of changes was made to allow the code to function after
the changes to ComfyUI's execution model. The changes unrelated to
workflow orchestration have been pulled out so that the code is viable
for general use once again.
AustinMroz committed Oct 17, 2024
1 parent e03d06d commit 535ea76
Showing 2 changed files with 195 additions and 76 deletions.
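The core of the change set is re-attaching the checkpointing wrappers to ComfyUI's new execution entry points: the removed hook on execution.recursive_execute becomes hooks on the per-node execution.execute function and on execution.PromptExecutor.execute (see the end of the workflowcheckpointing.py diff below). A minimal sketch of the new hook shape, assuming ComfyUI's post-refactor execution module; the wrapper names and bodies here are illustrative placeholders, not the commit's code:

import execution

# Keep references to the originals so the wrappers can delegate to them.
original_execute = execution.execute                         # per-node execution
original_prompt_execute = execution.PromptExecutor.execute   # whole-prompt run

def execute_node_injection(*args):
    # Placeholder: the real wrapper in this commit restores checkpointed
    # sampler latents or skips nodes whose outputs were already saved.
    return original_execute(*args)

def prompt_execute_injection(*args, **kwargs):
    # Placeholder: the real wrapper resets stale checkpoints when the queued
    # prompt differs from the checkpointed one, then runs the original.
    return original_prompt_execute(*args, **kwargs)

execution.execute = execute_node_injection
execution.PromptExecutor.execute = prompt_execute_injection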
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
[project]
name = "comfyui-workflowcheckpointing"
description = "Automatically creates checkpoints during workflow execution. If If an workflow is canceled or ComfyUI crashes mid-execution, then these checkpoints are used when the workflow is re-queued to resume execution with minimal progress loss."
version = "1.0.1"
version = "1.1.0"
license = { file = "LICENSE" }

[project.urls]
workflowcheckpointing.py: 269 changes (194 additions, 75 deletions)
@@ -13,6 +13,7 @@
import comfy.samplers
import execution
import server
import heapq

SAMPLER_NODES = ["SamplerCustom", "KSampler", "KSamplerAdvanced", "SamplerCustomAdvanced"]

@@ -23,6 +24,7 @@ async def get_header():
return {'Salad-Api-Key': os.environ['SALAD_API_KEY']}
global SALAD_TOKEN
if SALAD_TOKEN is None:
assert 'SALAD_MACHINE_ID' in os.environ, "SALAD_API_KEY must be provided if not deployed"
async with aiohttp.ClientSession() as session:
async with session.get('http://169.254.169.254:80/v1/token') as r:
SALAD_TOKEN =(await r.json())['jwt']
@@ -66,10 +68,8 @@ async def delete_file(self, s, url, semaphore):
async with s.delete(url, headers=await get_header()) as r:
await r.text()
async def _reset(self, s, uid):
base_url = '/organizations/' + ORGANIZATION +'/files'
checkpoint_base = '/'.join([base_url, uid, 'checkpoint'])
checkpoint_base = 'https://storage-api.salad.com' + checkpoint_base
async with s.get(base_url, headers=await get_header()) as r:
async with s.get(base_url_path, headers=await get_header()) as r:
js = await r.json()
files = js['files']
checkpoints = list(filter(lambda x: x['url'].startswith(checkpoint_base), files))
@@ -82,30 +82,147 @@ async def _reset(self, s, uid):
async def process_requests(self):
headers = await get_header()
async with aiohttp.ClientSession('https://storage-api.salad.com') as session:
while True:
if self.do_reset != False:
await self._reset(session, self.do_reset)
self.do_reset = False
if self.active_request is None:
await asyncio.sleep(.1)
else:
req = self.active_request
fd = aiohttp.FormData({'file': req[1]})
async with session.put(req[0], headers=headers, data=fd) as r:

#We don't care about result, but must still await it
await r.text()
with self.mutex:
if not self.queue_high.empty():
self.active_request = self.queue_high.get()
try:
while True:
if self.do_reset != False:
await self._reset(session, self.do_reset)
self.do_reset = False
if self.active_request is None:
await asyncio.sleep(.1)
else:
if self.low is not None:
self.active_request = self.low
self.low = None
req = self.active_request
fd = aiohttp.FormData({'file': req[1]})
async with session.put(req[0], headers=headers, data=fd) as r:

#We don't care about result, but must still await it
await r.text()
with self.mutex:
if not self.queue_high.empty():
self.active_request = self.queue_high.get()
else:
self.active_request = None
if self.low is not None:
self.active_request = self.low
self.low = None
else:
self.active_request = None
except:
#Exceptions raised in tasks on the event loop are otherwise swallowed silently; print the traceback before re-raising
import traceback
traceback.print_exc()
raise
class FetchQueue:
"""Modified priority queue implementation that tracks inflight and allows priority modification"""
def __init__(self):
self.lock = threading.RLock()
self.queue = []  # heap entries are (priority, count, item, future) tuples
self.count = 0
self.consumed = {}
self.new_items = asyncio.Event()
def update_priority(self, i, priority):
    #lock must already be acquired
    item = self.queue[i][2]
    future = self.queue[i][3]
    if priority < self.queue[i][0]:
        #new priority is higher (numerically lower); invalidate the old entry and re-push
        self.queue[i] = (self.queue[i][0], self.queue[i][1], None, None)
        heapq.heappush(self.queue, (priority, self.count, item, future))
        self.count += 1
def requeue(self, future, item, dec_priority=1):
with self.lock:
priority = self.consumed[item][1] - dec_priority
heapq.heappush(self.queue, (priority, self.count, future, None))
self.count += 1
self.new_items.set()
def enqueue_checked(self, item, priority):
with self.lock:
if item in self.consumed:
#TODO: Also update in queue
#TODO: if complete check etag?
self.consumed[item][1] = min(self.consumed[item][1], priority)
return self.consumed[item][0]
for i in range(len(self.queue)):
if self.queue[i][2] == item:
future = self.queue[i][3]
self.update_priority(i, priority)
return future
future = asyncio.Future()
heapq.heappush(self.queue, (priority, self.count, item, future))
self.count += 1
self.new_items.set()
return future
async def get(self):
while True:
await self.new_items.wait()
with self.lock:
priority, _, item, future = heapq.heappop(self.queue)
if len(self.queue) == 0:
self.new_items.clear()
if item is None:
    #entry was invalidated by update_priority; skip it
    continue
if isinstance(item, str):
    self.consumed[item] = [future, priority]
    return priority, item, future
#item is an awaken future pushed by requeue; resume the paused fetch
item.set_result(True)

class FetchLoop:
def __init__(self):
self.queue = FetchQueue()
self.semaphore = asyncio.Semaphore(5)
self.cs = aiohttp.ClientSession()
event_loop = server.PromptServer.instance.loop
self.process_loop = event_loop.create_task(self.loop())
os.makedirs("fetches", exist_ok=True)
async def loop(self):
event_loop = server.PromptServer.instance.loop
while True:
await self.semaphore.acquire()
event_loop.create_task(self.fetch(*(await self.queue.get())))
def reset(self, url):
with self.queue.lock:
if url in self.queue.consumed:
self.queue.consumed.pop(url)
hashloc = os.path.join('fetches', string_hash(url))
if os.path.exists(hashloc):
os.remove(hashloc)
def enqueue(self, url, priority=0):
return self.queue.enqueue_checked(url, priority)
async def fetch(self, priority, url, future):
chunk_size = 2**25 #32MB
headers = {}
if url.startswith(base_url):
headers.update(await get_header())
filename = os.path.join('fetches', string_hash(url))
try:
async with self.cs.get(url, headers=headers) as r:
with open(filename, 'wb') as f:
async for chunk in r.content.iter_chunked(chunk_size):
f.write(chunk)
if not r.content.is_eof():
awaken = asyncio.Future()
self.queue.requeue(awaken, url)
await awaken
future.set_result(filename)
except:
future.set_result(None)
raise
finally:
self.semaphore.release()
return
fetch_loop = FetchLoop()
async def prepare_file(url, path, priority):
hashloc = os.path.join('fetches', string_hash(url))
if not os.path.exists(hashloc):
hashloc = await fetch_loop.enqueue(url, priority)
if os.path.exists(path):
os.remove(path)
os.makedirs(os.path.split(path)[0], exist_ok=True)
#TODO consider if symlinking would be better
os.link(hashloc, path)

ORGANIZATION = os.environ.get('SALAD_ORGANIZATION', None)
if ORGANIZATION is not None:
base_url_path = '/organizations/' + ORGANIZATION +'/files'
base_url = 'https://storage-api.salad.com' + base_url_path
class NetCheckpoint:
def __init__(self):
self.requestloop = RequestLoop()
@@ -139,10 +256,12 @@ def reset(self, unique_id=None):
if unique_id is not None:
if os.path.exists(f"input/checkpoint/{unique_id}.checkpoint"):
os.remove(f"input/checkpoint/{unique_id}.checkpoint")
fetch_loop.reset('/'.join([base_url, self.uid, 'checkpoint', f'{unique_id}.checkpoint']))
return
os.makedirs("input/checkpoint", exist_ok=True)
for file in os.listdir("input/checkpoint"):
os.remove(os.path.join("input/checkpoint", file))
fetch_loop.reset('/'.join([base_url, self.uid, 'checkpoint', file]))

class FileCheckpoint:
def store(self, unique_id, tensors, metadata, priority=0):
@@ -178,49 +297,53 @@ def file_hash(filename):
while n := f.readinto(b):
h.update(b)
return h.hexdigest()
def string_hash(s):
h = hashlib.sha256()
h.update(s.encode('utf-8'))
return h.hexdigest()
def fetch_remote_file(url, filepath, file_hash=None):
assert filepath.find("..") == -1, "Paths may not contain .."
return prepare_file(url, filepath, -1)

async def fetch_remote_file(session, file, semaphore):
filename = os.path.join("input", file['filepath'])
assert filename.find("..") == -1, "Paths may not contain .."
if os.path.exists(filename) and 'hash' in file and file_hash(filename) == file['hash']:
return
if file['url'].startswith('https://storage-api.salad.com/'):
headers = await get_header()
else:
headers = {}
async with semaphore:
async with session.get(file['url'], headers=headers) as r:
with open(filename, 'wb') as fd:
async for chunk in r.content.iter_chunked(2**16):
fd.write(chunk)

async def fetch_remote_files(remote_files, uid=None):
#TODO: Add requested support for zip files?
async with aiohttp.ClientSession() as s:
base_url = 'https://storage-api.salad.com/organizations/' + ORGANIZATION +'/files'
if uid is not None:
checkpoint_base = '/'.join([base_url, uid, 'checkpoint'])
async with s.get(base_url, headers=await get_header()) as r:
js = await r.json()
files = js['files']
checkpoints = list(filter(lambda x: x['url'].startswith(checkpoint_base), files))
for cp in checkpoints:
cp['filepath'] = os.path.join('checkpoint',
cp['url'][len(checkpoint_base)+1:])
remote_files = itertools.chain(remote_files, checkpoints)
semaphore = asyncio.Semaphore(5)
fetches = [asyncio.create_task(fetch_remote_file(s, f, semaphore)) for f in remote_files]
if len(fetches) > 0:
await asyncio.gather(*fetches)
if uid is not None:
checkpoint_base = '/'.join([base_url_path, uid, 'checkpoint'])
checkpoint_base = 'https://storage-api.salad.com'+ checkpoint_base
async with fetch_loop.cs.get(base_url, headers=await get_header()) as r:
js = await r.json()
files = js['files']
checkpoints = list(filter(lambda x: x['url'].startswith(checkpoint_base), files))
for cp in checkpoints:
cp['filepath'] = os.path.join('input/checkpoint',
cp['url'][len(checkpoint_base)+1:])
remote_files = itertools.chain(remote_files, checkpoints)
fetches = []
for f in remote_files:
fetches.append(fetch_remote_file(f['url'],f['filepath'], f.get('file_hash', None)))
if len(fetches) > 0:
await asyncio.gather(*fetches)

completion_futures = {}
def add_future(json_data):
index = max(completion_futures.keys())
json_data['extra_data']['completion_future'] = index
return json_data
server.PromptServer.instance.add_on_prompt_handler(add_future)

prompt_route = next(filter(lambda x: x.path == '/prompt' and x.method == 'POST',
server.PromptServer.instance.routes))
original_post_prompt = prompt_route.handler
async def post_prompt_remote(request):
if 'dump_req' in os.environ:
with open('resp-dump.txt', 'wb') as f:
f.write(await request.read())
import sys
sys.exit()
json_data = await request.json()
if "SALAD_ORGANIZATION" in os.environ:
extra_data = json_data.get("extra_data", {})
#NOTE: Rendered obsolete by existing infrastructure, can be pruned
remote_files = extra_data.get("remote_files", [])
uid = json_data.get("client_id", 'local')
checkpoint.uid = uid
@@ -261,42 +384,32 @@ def callback(self, step, denoised, x, total_steps):
if step == int(os.environ['FORCE_CRASH_AT']):
raise Exception("Simulated Crash")

original_recursive_execute = execution.recursive_execute
original_recursive_execute = execution.execute
def recursive_execute_injection(*args):

unique_id = args[3]
class_type = args[1][unique_id]['class_type']
class_type = args[1].get_node(unique_id)['class_type']
extra_data = args[4]
if 'checkpoints' in extra_data:
checkpoint.update(extra_data.pop('checkpoints'))
if 'prompt_checked' not in args[4]:
metadata = checkpoint.get('prompt')[1]
if metadata is None or json.loads(metadata['prompt']) != args[1]:
checkpoint.reset()
checkpoint.store('prompt', {'x': torch.ones(1)},
{'prompt': json.dumps(args[1])}, priority=2)
args[4]['prompt_checked'] = True
if class_type in SAMPLER_NODES:
data, metadata = checkpoint.get(unique_id)
if metadata is not None and 'step' in metadata:
args[1][unique_id]['inputs']['latent_image'] = ['checkpointed'+unique_id, 0]
args[2]['checkpointed'+unique_id] = [[{'samples': data['x']}]]
args[1].get_node(unique_id)['inputs']['latent_image'] = ['checkpointed'+unique_id, 0]
args[2].outputs.set('checkpointed'+unique_id, [[{'samples': data['x']}]])
elif metadata is not None and 'completed' in metadata:
outputs = json.loads(metadata['completed'])
for x in range(len(outputs)):
if outputs[x] == 'tensor':
outputs[x] = list(data[str(x)])
elif outputs[x] == 'latent':
outputs[x] = [{'samples': l} for l in data[str(x)]]
args[2][unique_id] = outputs
args[2].outputs.set(unique_id, outputs)
return True, None, None

res = original_recursive_execute(*args)
#Conditionally save node output
#TODO: determine which non-sampler nodes are worth saving
if class_type in SAMPLER_NODES and unique_id in args[2]:
if class_type in SAMPLER_NODES and args[2].outputs.get(unique_id) is not None:
data = {}
outputs = args[2][unique_id].copy()
outputs = args[2].outputs.get(unique_id).copy()
for x in range(len(outputs)):
if isinstance(outputs[x][0], torch.Tensor):
data[str(x)] = torch.stack(outputs[x])
@@ -306,9 +419,15 @@ def recursive_execute_injection(*args):
outputs[x] = 'latent'
checkpoint.store(unique_id, data, {'completed': json.dumps(outputs)}, priority=1)
return res
original_execute = execution.PromptExecutor.execute
def execute_injection(*args, **kwargs):
metadata = checkpoint.get('prompt')[1]
if metadata is None or json.loads(metadata['prompt']) != args[1]:
checkpoint.reset()
checkpoint.store('prompt', {'x': torch.ones(1)},
{'prompt': json.dumps(args[1])}, priority=2)
original_execute(*args, **kwargs)

comfy.samplers.KSAMPLER = CheckpointSampler
execution.recursive_execute = recursive_execute_injection

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
execution.execute = recursive_execute_injection
execution.PromptExecutor.execute = execute_injection
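
For orientation, a hedged usage sketch (not part of this commit) of the extracted fetch path: prepare_file schedules a URL on the shared FetchLoop, caches the download under fetches/ keyed by a hash of the URL, and hard-links the cached file into the requested location. The sketch assumes SALAD_ORGANIZATION is set so the module-level base_url exists, that this file is importable as workflowcheckpointing, and that the coroutine runs on ComfyUI's own server loop; the URL and paths are hypothetical.

import workflowcheckpointing as wc
import server

async def warm_input_file():
    # Lower priority values pop first from the heapq-backed FetchQueue;
    # the checkpoint-restore path in this commit enqueues with priority -1.
    await wc.prepare_file(
        'https://example.com/inputs/image.png',  # hypothetical URL
        'input/image.png',
        0,
    )

# FetchLoop's futures belong to ComfyUI's server loop, so schedule there:
# server.PromptServer.instance.loop.create_task(warm_input_file())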
