Skip to content

Commit

Permalink
feat: comfyui stop (#435)
Browse files Browse the repository at this point in the history
* feat: comfyui stop

* fix: pipeline stop

* fix: nit

* fix: process restart

* fix: timeout issue
  • Loading branch information
varshith15 authored Feb 27, 2025
1 parent fa6dc7c commit c80e471
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
7 changes: 7 additions & 0 deletions runner/app/live/pipelines/comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,10 @@ def update_params(self, **params):
logging.info(f"ComfyUI Pipeline Prompt: {new_params.prompt}")
self.client.set_prompt(new_params.prompt)
self.params = new_params

#TODO: This is a hack to stop the ComfyStreamClient. Use the comfystream api to stop the client in 0.0.2
async def stop(self):
logging.info("Stopping ComfyUI pipeline")
if self.client.comfy_client.is_running:
await self.client.comfy_client.__aexit__(None, None, None)
logging.info("ComfyUI pipeline stopped")
7 changes: 7 additions & 0 deletions runner/app/live/pipelines/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,10 @@ def update_params(self, **params):
**params: Implementation-specific parameters
"""
pass

async def stop(self):
"""Stop the pipeline.
Called once when the pipeline is no longer needed.
"""
pass
16 changes: 13 additions & 3 deletions runner/app/live/streamer/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@ def __init__(self, pipeline_name: str):
self.start_time = 0.0

async def stop(self):
await asyncio.to_thread(self._stop_sync)
self._stop_sync()

def _stop_sync(self):
self.done.set()

if not self.process.is_alive():
logging.info("Process already not alive")
return

logging.info("Terminating pipeline process")
self.process.terminate()

stopped = False
try:
self.process.join(timeout=5)
self.process.join(timeout=10)
stopped = True
except Exception as e:
logging.error(f"Process join error: {e}")
Expand Down Expand Up @@ -100,6 +100,7 @@ def get_recent_logs(self, n=10) -> list[str]:

def process_loop(self):
self._setup_logging()
pipeline = None

def report_error(error_msg: str):
error_event = {
Expand Down Expand Up @@ -170,6 +171,15 @@ def report_error(error_msg: str):
report_error(f"Error processing frame: {e}")
except Exception as e:
report_error(f"Error in process run method: {e}")
finally:
self._cleanup_pipeline(pipeline)

def _cleanup_pipeline(self, pipeline):
if pipeline is not None:
try:
asyncio.get_event_loop().run_until_complete(pipeline.stop())
except Exception as e:
logging.error(f"Error stopping pipeline: {e}")

def _setup_logging(self):
level = (
Expand Down

0 comments on commit c80e471

Please sign in to comment.