diff --git a/README.md b/README.md index c3bc6012..d275306a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,10 @@ - [Config](#config) * [Upload image to AWS S3](#upload-image-to-aws-s3) - [Use the Docker image on RunPod](#use-the-docker-image-on-runpod) +- [API specification](#api-specification) + * [JSON Payload Structure](#json-payload-structure) + * [Fields](#fields) + + ["input.images" details](#inputimages-details) - [Interact with your RunPod API](#interact-with-your-runpod-api) * [Health status](#health-status) * [Generate an image](#generate-an-image) @@ -24,7 +28,6 @@ - [Local testing](#local-testing) * [Setup](#setup) + [Setup for Windows](#setup-for-windows) - * [Activate virtual env](#activate-virtual-env) * [Test: handler](#test-handler) * [Test: docker image](#test-docker-image) - [Automatically deploy to Docker hub with Github Actions](#automatically-deploy-to-docker-hub-with-github-actions) @@ -43,6 +46,7 @@ ## Features - Run any [ComfyUI](https://github.com/comfyanonymous/ComfyUI) workflow to generate an image +- Provide input images as base64-encoded string - Generated image is either: - Returned as base64-encoded string (default) - Uploaded to AWS S3 ([if AWS S3 is configured](#upload-image-to-aws-s3)) @@ -57,6 +61,10 @@ ## Config +| Environment Variable | Description | Default | +| -------------------- | ---------------------------------------------------------------------------- | ------- | +| `REFRESH_WORKER` | When you want stop the worker after each finished job to have a clean state. | `false` | + ### Upload image to AWS S3 This is only needed if you want to upload the generated picture to AWS S3. If you don't configure this, your image will be exported as base64-encoded string. @@ -97,6 +105,44 @@ This is only needed if you want to upload the generated picture to AWS S3. If yo - Click `deploy` - Your endpoint will be created, you can click on it to see the dashboard +## API specification + +The following is the required structure and format for requests made to the API. + +### JSON Payload Structure + +```json +{ + "input": { + "workflow": {}, + "images": [ + { + "name": "example_image_name.png", + "image": "base64_encoded_string" + } + ] + } +} +``` + +### Fields + +| Field Path | Type | Required | Description | +| ---------------- | ------ | -------- | ------------------------------------------------- | +| `input` | Object | Yes | The top-level object containing the request data. | +| `input.workflow` | Object | Yes | Contains the ComfyUI workflow configuration. | +| `input.images` | Array | No | An array of images. | + + +#### "input.images" details + +| Field Name | Type | Required | Description | +| ---------- | ------ | -------- | ---------------------------------------------------------------------------------------- | +| `name` | String | No | The name of the image. Please use the same name in your workflow to reference the image. | +| `image` | String | No | A base64 encoded string of the image. | + + + ## Interact with your RunPod API - In the [User Settings](https://www.runpod.io/console/serverless/user/settings) click on `API Keys` and then on the `API Key` button @@ -105,6 +151,8 @@ This is only needed if you want to upload the generated picture to AWS S3. If yo - Replace `` with your key - Replace `` with the ID of the endpoint, you find that when you click on your endpoint, it's part of the URLs shown at the bottom of the first box + + ### Health status ```bash @@ -113,16 +161,16 @@ curl -H "Authorization: Bearer " https://api.runpod.ai/v2/ ### Generate an image -You can ether create a new job async by using /run or a sync by using runsync. The example here is using a sync job and waits until the response is delivered. +You can either create a new job async by using `/run` or a sync by using runsync. The example here is using a sync job and waits until the response is delivered. -The API expects a JSON in this form, where `prompt` is the [workflow from ComfyUI, exported as JSON](#how-to-get-the-workflow-from-comfyui): +The API expects a JSON in this form, where `workflow` is the [workflow from ComfyUI, exported as JSON](#how-to-get-the-workflow-from-comfyui): ```json { "input": { - "prompt": { + "workflow": { // ComfyUI workflow - } + }, } } ``` @@ -132,7 +180,7 @@ Please also take a look at the [test_input.json](./test_input.json) to see how t #### Example request with cURL ```bash -curl -X POST -H "Authorization: Bearer " -H "Content-Type: application/json" -d '{"input":{"prompt":{"3":{"inputs":{"seed":1337,"steps":20,"cfg":8,"sampler_name":"euler","scheduler":"normal","denoise":1,"model":["4",0],"positive":["6",0],"negative":["7",0],"latent_image":["5",0]},"class_type":"KSampler"},"4":{"inputs":{"ckpt_name":"sd_xl_base_1.0.safetensors"},"class_type":"CheckpointLoaderSimple"},"5":{"inputs":{"width":512,"height":512,"batch_size":1},"class_type":"EmptyLatentImage"},"6":{"inputs":{"text":"beautiful scenery nature glass bottle landscape, , purple galaxy bottle,","clip":["4",1]},"class_type":"CLIPTextEncode"},"7":{"inputs":{"text":"text, watermark","clip":["4",1]},"class_type":"CLIPTextEncode"},"8":{"inputs":{"samples":["3",0],"vae":["4",2]},"class_type":"VAEDecode"},"9":{"inputs":{"filename_prefix":"ComfyUI","images":["8",0]},"class_type":"SaveImage"}}}}' https://api.runpod.ai/v2//runsync +curl -X POST -H "Authorization: Bearer " -H "Content-Type: application/json" -d '{"input":{"workflow":{"3":{"inputs":{"seed":1337,"steps":20,"cfg":8,"sampler_name":"euler","scheduler":"normal","denoise":1,"model":["4",0],"positive":["6",0],"negative":["7",0],"latent_image":["5",0]},"class_type":"KSampler"},"4":{"inputs":{"ckpt_name":"sd_xl_base_1.0.safetensors"},"class_type":"CheckpointLoaderSimple"},"5":{"inputs":{"width":512,"height":512,"batch_size":1},"class_type":"EmptyLatentImage"},"6":{"inputs":{"text":"beautiful scenery nature glass bottle landscape, , purple galaxy bottle,","clip":["4",1]},"class_type":"CLIPTextEncode"},"7":{"inputs":{"text":"text, watermark","clip":["4",1]},"class_type":"CLIPTextEncode"},"8":{"inputs":{"samples":["3",0],"vae":["4",2]},"class_type":"VAEDecode"},"9":{"inputs":{"filename_prefix":"ComfyUI","images":["8",0]},"class_type":"SaveImage"}}}}' https://api.runpod.ai/v2//runsync # Response with AWS S3 bucket configuration # {"delayTime":2188,"executionTime":2297,"id":"sync-c0cd1eb2-068f-4ecf-a99a-55770fc77391-e1","output":{"message":"https://bucket.s3.region.amazonaws.com/10-23/sync-c0cd1eb2-068f-4ecf-a99a-55770fc77391-e1/c67ad621.png","status":"success"},"status":"COMPLETED"} @@ -150,7 +198,7 @@ curl -X POST -H "Authorization: Bearer " -H "Content-Type: application/ - Close the `Settings` - In the menu, click on the `Save (API Format)` button, which will download a file named `workflow_api.json` -You can now take the content of this file and put it into your `prompt` when interacting with the API. +You can now take the content of this file and put it into your `workflow` when interacting with the API. ## Build the image @@ -166,6 +214,7 @@ Both tests will use the data from [test_input.json](./test_input.json), so make - Make sure you have Python >= 3.10 - Create a virtual environment: `python -m venv venv` +- Activate the virtual environment: `.\venv\Scripts\activate` (Windows) or `source ./venv/bin/activate` (Mac / Linux) - Install the dependencies: `pip install -r requirements.txt` #### Setup for Windows @@ -188,10 +237,6 @@ To run the Docker image on Windows, we need to have WSL2 and a Linux distro (lik - Add your user to the `docker` group, so that you can use Docker without `sudo`: `sudo usermod -aG docker $USER` -### Activate virtual env - -`.\venv\Scripts\activate` - ### Test: handler - Run all tests: `python -m unittest discover` diff --git a/requirements.txt b/requirements.txt index c2663d31..1e4fbbdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -runpod \ No newline at end of file +runpod==1.3.6 \ No newline at end of file diff --git a/src/rp_handler.py b/src/rp_handler.py index c156751d..6faeb50b 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -7,6 +7,7 @@ import os import requests import base64 +from io import BytesIO # Time to wait between API check attempts in milliseconds COMFY_API_AVAILABLE_INTERVAL_MS = 50 @@ -15,9 +16,55 @@ # Time to wait between poll attempts in milliseconds COMFY_POLLING_INTERVAL_MS = 250 # Maximum number of poll attempts -COMFY_POLLING_MAX_RETRIES = 100 +COMFY_POLLING_MAX_RETRIES = 500 # Host where ComfyUI is running COMFY_HOST = "127.0.0.1:8188" +# Enforce a clean state after each job is done +# see https://docs.runpod.io/docs/handler-additional-controls#refresh-worker +REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true" + + +def validate_input(job_input): + """ + Validates the input for the handler function. + + Args: + job_input (dict): The input data to validate. + + Returns: + tuple: A tuple containing the validated data and an error message, if any. + The structure is (validated_data, error_message). + """ + # Validate if job_input is provided + if job_input is None: + return None, "Please provide input" + + # Check if input is a string and try to parse it as JSON + if isinstance(job_input, str): + try: + job_input = json.loads(job_input) + except json.JSONDecodeError: + return None, "Invalid JSON format in input" + + # Validate 'workflow' in input + workflow = job_input.get("workflow") + if workflow is None: + return None, "Missing 'workflow' parameter" + + # Validate 'images' in input, if provided + images = job_input.get("images") + if images is not None: + if not isinstance(images, list) or not all( + "name" in image and "image" in image for image in images + ): + return ( + None, + "'images' must be a list of objects with 'name' and 'image' keys", + ) + + # Return validated data and no error + return {"workflow": workflow, "images": images}, None + def check_server(url, retries=50, delay=500): """ @@ -53,17 +100,73 @@ def check_server(url, retries=50, delay=500): return False -def queue_prompt(prompt): +def upload_images(images): """ - Queue a prompt to be processed by ComfyUI + Upload a list of base64 encoded images to the ComfyUI server using the /upload/image endpoint. Args: - prompt (dict): A dictionary containing the prompt to be processed + images (list): A list of dictionaries, each containing the 'name' of the image and the 'image' as a base64 encoded string. + server_address (str): The address of the ComfyUI server. Returns: - dict: The JSON response from ComfyUI after processing the prompt + list: A list of responses from the server for each image upload. """ - data = json.dumps(prompt).encode("utf-8") + if not images: + return {"status": "success", "message": "No images to upload", "details": []} + + responses = [] + upload_errors = [] + + print(f"runpod-worker-comfy - image(s) upload") + + for image in images: + name = image["name"] + image_data = image["image"] + blob = base64.b64decode(image_data) + + # Prepare the form data + files = { + "image": (name, BytesIO(blob), "image/png"), + "overwrite": (None, "true"), + } + + # POST request to upload the image + response = requests.post(f"http://{COMFY_HOST}/upload/image", files=files) + if response.status_code != 200: + upload_errors.append(f"Error uploading {name}: {response.text}") + else: + responses.append(f"Successfully uploaded {name}") + + if upload_errors: + print(f"runpod-worker-comfy - image(s) upload with errors") + return { + "status": "error", + "message": "Some images failed to upload", + "details": upload_errors, + } + + print(f"runpod-worker-comfy - image(s) upload complete") + return { + "status": "success", + "message": "All images uploaded successfully", + "details": responses, + } + + +def queue_workflow(workflow): + """ + Queue a workflow to be processed by ComfyUI + + Args: + workflow (dict): A dictionary containing the workflow to be processed + + Returns: + dict: The JSON response from ComfyUI after processing the workflow + """ + + # The top level element "prompt" is required by ComfyUI + data = json.dumps({"prompt": workflow}).encode("utf-8") + req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data) return json.loads(urllib.request.urlopen(req).read()) @@ -94,8 +197,9 @@ def base64_encode(img_path): """ with open(img_path, "rb") as image_file: encoded_string = base64.b64encode(image_file.read()).decode("utf-8") - return f"data:image/png;base64,{encoded_string}" - + return f"{encoded_string}" + + def process_output_images(outputs, job_id): """ This function takes the "outputs" from image generation and the job ID, @@ -126,7 +230,7 @@ def process_output_images(outputs, job_id): """ # The path where ComfyUI stores the generated images - COMFY_OUTPUT_PATH = os.environ.get('COMFY_OUTPUT_PATH', "/comfyui/output") + COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output") output_images = {} @@ -140,20 +244,26 @@ def process_output_images(outputs, job_id): # expected image output folder local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}" + print(f"runpod-worker-comfy - {local_image_path}") + # The image is in the output folder if os.path.exists(local_image_path): - print("runpod-worker-comfy - the image exists in the output folder") - - if os.environ.get('BUCKET_ENDPOINT_URL', False): + if os.environ.get("BUCKET_ENDPOINT_URL", False): # URL to image in AWS S3 image = rp_upload.upload_image(job_id, local_image_path) + print( + "runpod-worker-comfy - the image was generated and uploaded to AWS S3" + ) else: # base64 image image = base64_encode(local_image_path) + print( + "runpod-worker-comfy - the image was generated and converted to base64" + ) return { - "status": "success", - "message": image, + "status": "success", + "message": image, } else: print("runpod-worker-comfy - the image does not exist in the output folder") @@ -161,7 +271,7 @@ def process_output_images(outputs, job_id): "status": "error", "message": f"the image does not exist in the specified output folder: {local_image_path}", } - + def handler(job): """ @@ -178,6 +288,15 @@ def handler(job): """ job_input = job["input"] + # Make sure that the input is valid + validated_data, error_message = validate_input(job_input) + if error_message: + return {"error": error_message} + + # Extract validated data + workflow = validated_data["workflow"] + images = validated_data.get("images") + # Make sure that the ComfyUI API is available check_server( f"http://{COMFY_HOST}", @@ -185,29 +304,19 @@ def handler(job): COMFY_API_AVAILABLE_INTERVAL_MS, ) - # Validate input - if job_input is None: - return {"error": "Please provide the 'prompt'"} + # Upload images if they exist + upload_result = upload_images(images) - # Is JSON? - if isinstance(job_input, dict): - prompt = job_input - # Is String? - elif isinstance(job_input, str): - try: - prompt = json.loads(job_input) - except json.JSONDecodeError: - return {"error": "Invalid JSON format in 'prompt'"} - else: - return {"error": "'prompt' must be a JSON object or a JSON-encoded string"} + if upload_result["status"] == "error": + return upload_result - # Queue the prompt + # Queue the workflow try: - queued_prompt = queue_prompt(prompt) - prompt_id = queued_prompt["prompt_id"] - print(f"runpod-worker-comfy - queued prompt with ID {prompt_id}") + queued_workflow = queue_workflow(workflow) + prompt_id = queued_workflow["prompt_id"] + print(f"runpod-worker-comfy - queued workflow with ID {prompt_id}") except Exception as e: - return {"error": f"Error queuing prompt: {str(e)}"} + return {"error": f"Error queuing workflow: {str(e)}"} # Poll for completion print(f"runpod-worker-comfy - wait until image generation is complete") @@ -229,7 +338,11 @@ def handler(job): return {"error": f"Error waiting for image generation: {str(e)}"} # Get the generated image and return it as URL in an AWS bucket or as base64 - return process_output_images(history[prompt_id].get("outputs"), job["id"]) + images_result = process_output_images(history[prompt_id].get("outputs"), job["id"]) + + result = {**images_result, "refresh_worker": REFRESH_WORKER} + + return result # Start the handler only if this script is run directly diff --git a/test_input.json b/test_input.json index 2a0380ce..1210cec3 100644 --- a/test_input.json +++ b/test_input.json @@ -1,6 +1,6 @@ { "input": { - "prompt": { + "workflow": { "3": { "inputs": { "seed": 234234, diff --git a/tests/test_rp_handler.py b/tests/test_rp_handler.py index 23e1197f..be744866 100644 --- a/tests/test_rp_handler.py +++ b/tests/test_rp_handler.py @@ -3,42 +3,94 @@ import sys import os import json +import base64 # Make sure that "src" is known and can be used to import rp_handler.py -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) from src import rp_handler # Local folder for test resources RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES = "./test_resources/images" + class TestRunpodWorkerComfy(unittest.TestCase): - @patch('rp_handler.requests.get') + def test_valid_input_with_workflow_only(self): + input_data = {"workflow": {"key": "value"}} + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNone(error) + self.assertEqual(validated_data, {"workflow": {"key": "value"}, "images": None}) + + def test_valid_input_with_workflow_and_images(self): + input_data = { + "workflow": {"key": "value"}, + "images": [{"name": "image1.png", "image": "base64string"}], + } + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNone(error) + self.assertEqual(validated_data, input_data) + + def test_input_missing_workflow(self): + input_data = {"images": [{"name": "image1.png", "image": "base64string"}]} + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNotNone(error) + self.assertEqual(error, "Missing 'workflow' parameter") + + def test_input_with_invalid_images_structure(self): + input_data = { + "workflow": {"key": "value"}, + "images": [{"name": "image1.png"}], # Missing 'image' key + } + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNotNone(error) + self.assertEqual( + error, "'images' must be a list of objects with 'name' and 'image' keys" + ) + + def test_invalid_json_string_input(self): + input_data = "invalid json" + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNotNone(error) + self.assertEqual(error, "Invalid JSON format in input") + + def test_valid_json_string_input(self): + input_data = '{"workflow": {"key": "value"}}' + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNone(error) + self.assertEqual(validated_data, {"workflow": {"key": "value"}, "images": None}) + + def test_empty_input(self): + input_data = None + validated_data, error = rp_handler.validate_input(input_data) + self.assertIsNotNone(error) + self.assertEqual(error, "Please provide input") + + @patch("rp_handler.requests.get") def test_check_server_server_up(self, mock_requests): mock_response = MagicMock() mock_response.status_code = 200 mock_requests.return_value = mock_response - result = rp_handler.check_server('http://127.0.0.1:8188', 1, 50) + result = rp_handler.check_server("http://127.0.0.1:8188", 1, 50) self.assertTrue(result) - @patch('rp_handler.requests.get') + @patch("rp_handler.requests.get") def test_check_server_server_down(self, mock_requests): mock_requests.get.side_effect = rp_handler.requests.RequestException() - result = rp_handler.check_server('http://127.0.0.1:8188', 1, 50) + result = rp_handler.check_server("http://127.0.0.1:8188", 1, 50) self.assertFalse(result) - @patch('rp_handler.urllib.request.urlopen') + @patch("rp_handler.urllib.request.urlopen") def test_queue_prompt(self, mock_urlopen): mock_response = MagicMock() mock_response.read.return_value = json.dumps({"prompt_id": "123"}).encode() mock_urlopen.return_value = mock_response - result = rp_handler.queue_prompt({"prompt": "test"}) + result = rp_handler.queue_workflow({"prompt": "test"}) self.assertEqual(result, {"prompt_id": "123"}) - @patch('rp_handler.urllib.request.urlopen') + @patch("rp_handler.urllib.request.urlopen") def test_get_history(self, mock_urlopen): # Mock response data as a JSON string - mock_response_data = json.dumps({"key": "value"}).encode('utf-8') + mock_response_data = json.dumps({"key": "value"}).encode("utf-8") # Define a mock response function for `read` def mock_read(): @@ -62,69 +114,117 @@ def mock_read(): self.assertEqual(result, {"key": "value"}) mock_urlopen.assert_called_with("http://127.0.0.1:8188/history/123") - @patch('builtins.open', new_callable=mock_open, read_data=b'test') + @patch("builtins.open", new_callable=mock_open, read_data=b"test") def test_base64_encode(self, mock_file): + test_data = base64.b64encode(b"test").decode("utf-8") + result = rp_handler.base64_encode("dummy_path") - self.assertTrue(result.startswith("data:image/png;base64,")) - @patch('rp_handler.os.path.exists') - @patch('rp_handler.rp_upload.upload_image') - @patch.dict(os.environ, {'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES}) + self.assertEqual(result, test_data) + + @patch("rp_handler.os.path.exists") + @patch("rp_handler.rp_upload.upload_image") + @patch.dict( + os.environ, {"COMFY_OUTPUT_PATH": RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES} + ) def test_bucket_endpoint_not_configured(self, mock_upload_image, mock_exists): mock_exists.return_value = True - mock_upload_image.return_value = 'simulated_uploaded/image.png' - - outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}} - job_id = '123' + mock_upload_image.return_value = "simulated_uploaded/image.png" - result = rp_handler.process_output_images(outputs, job_id) + outputs = {"node_id": {"images": [{"filename": "ComfyUI_00001_.png"}]}} + job_id = "123" - self.assertEqual(result['status'], 'success') - self.assertTrue(result['message'].startswith("data:image/png;base64,")) + result = rp_handler.process_output_images(outputs, job_id) - @patch('rp_handler.os.path.exists') - @patch('rp_handler.rp_upload.upload_image') - @patch.dict(os.environ, {'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, 'BUCKET_ENDPOINT_URL': 'http://example.com'}) + self.assertEqual(result["status"], "success") + + @patch("rp_handler.os.path.exists") + @patch("rp_handler.rp_upload.upload_image") + @patch.dict( + os.environ, + { + "COMFY_OUTPUT_PATH": RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, + "BUCKET_ENDPOINT_URL": "http://example.com", + }, + ) def test_bucket_endpoint_configured(self, mock_upload_image, mock_exists): # Mock the os.path.exists to return True, simulating that the image exists mock_exists.return_value = True # Mock the rp_upload.upload_image to return a simulated URL - mock_upload_image.return_value = 'http://example.com/uploaded/image.png' + mock_upload_image.return_value = "http://example.com/uploaded/image.png" # Define the outputs and job_id for the test - outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}} - job_id = '123' + outputs = {"node_id": {"images": [{"filename": "ComfyUI_00001_.png"}]}} + job_id = "123" # Call the function under test result = rp_handler.process_output_images(outputs, job_id) # Assertions - self.assertEqual(result['status'], 'success') - self.assertEqual(result['message'], 'http://example.com/uploaded/image.png') - mock_upload_image.assert_called_once_with(job_id, './test_resources/images/ComfyUI_00001_.png') - - - @patch('rp_handler.os.path.exists') - @patch('rp_handler.rp_upload.upload_image') - @patch.dict(os.environ, { - 'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, - 'BUCKET_ENDPOINT_URL': 'http://example.com', - 'BUCKET_ACCESS_KEY_ID': '', - 'BUCKET_SECRET_ACCESS_KEY': '' - }) - def test_bucket_image_upload_fails_env_vars_wrong_or_missing(self, mock_upload_image, mock_exists): + self.assertEqual(result["status"], "success") + self.assertEqual(result["message"], "http://example.com/uploaded/image.png") + mock_upload_image.assert_called_once_with( + job_id, "./test_resources/images/ComfyUI_00001_.png" + ) + + @patch("rp_handler.os.path.exists") + @patch("rp_handler.rp_upload.upload_image") + @patch.dict( + os.environ, + { + "COMFY_OUTPUT_PATH": RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, + "BUCKET_ENDPOINT_URL": "http://example.com", + "BUCKET_ACCESS_KEY_ID": "", + "BUCKET_SECRET_ACCESS_KEY": "", + }, + ) + def test_bucket_image_upload_fails_env_vars_wrong_or_missing( + self, mock_upload_image, mock_exists + ): # Simulate the file existing in the output path mock_exists.return_value = True # When AWS credentials are wrong or missing, upload_image should return 'simulated_uploaded/...' - mock_upload_image.return_value = 'simulated_uploaded/image.png' + mock_upload_image.return_value = "simulated_uploaded/image.png" - outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}} - job_id = '123' + outputs = {"node_id": {"images": [{"filename": "ComfyUI_00001_.png"}]}} + job_id = "123" result = rp_handler.process_output_images(outputs, job_id) # Check if the image was saved to the 'simulated_uploaded' directory - self.assertIn('simulated_uploaded', result['message']) - self.assertEqual(result['status'], 'success') \ No newline at end of file + self.assertIn("simulated_uploaded", result["message"]) + self.assertEqual(result["status"], "success") + + @patch("rp_handler.requests.post") + def test_upload_images_successful(self, mock_post): + mock_response = unittest.mock.Mock() + mock_response.status_code = 200 + mock_response.text = "Successfully uploaded" + mock_post.return_value = mock_response + + test_image_data = base64.b64encode(b"Test Image Data").decode("utf-8") + + images = [{"name": "test_image.png", "image": test_image_data}] + + responses = rp_handler.upload_images(images) + + self.assertEqual(len(responses), 3) + self.assertEqual(responses["status"], "success") + + @patch("rp_handler.requests.post") + def test_upload_images_failed(self, mock_post): + mock_response = unittest.mock.Mock() + mock_response.status_code = 400 + mock_response.text = "Error uploading" + mock_post.return_value = mock_response + + test_image_data = base64.b64encode(b"Test Image Data").decode("utf-8") + + images = [{"name": "test_image.png", "image": test_image_data}] + + responses = rp_handler.upload_images(images) + + self.assertEqual(len(responses), 3) + self.assertEqual(responses["status"], "error")