Skip to content

Commit

Permalink
Allow editing workflows on CPUs (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
punitda authored Aug 21, 2024
1 parent 55b1154 commit 6480e12
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 5 deletions.
36 changes: 35 additions & 1 deletion backend/src/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import asyncio
import os
import json
Expand All @@ -15,6 +14,7 @@
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
import httpx
from httpx import ReadTimeout


from dotenv import load_dotenv
Expand Down Expand Up @@ -171,6 +171,40 @@ async def delete_app(app_id: str):
return {"app_id": app_id, "deleted": True}


@app.get("/apps/{app_name}/workflow-urls", dependencies=[Depends(verify_api_key)])
async def get_workflow_urls(app_name: str):
try:
workspace = await run_modal_command("modal profile current")
edit_url = f"https://{workspace}--{app_name}-editingworkflow-get-tunnel-url.modal.run"
run_url = f"https://{workspace}--{app_name}-comfyworkflow-ui.modal.run"

logger.info("GET request to url %s", edit_url)

# Set a 30-second timeout
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(edit_url)
response.raise_for_status()
result = response.json()
result["run_url"] = run_url
logger.info("Tunnel url %s", result)
logger.info("Run url %s", run_url)
return result
except httpx.ReadTimeout as e:
logger.error(
"Request timed out while making a request to %s", e.request.url)
raise HTTPException(
status_code=504, detail="Request timed out while fetching edit URL") from e
except httpx.HTTPError as e:
logger.error("HTTP %d %s error occurred while making a request to %s",
response.status_code, response.reason_phrase, e.request.url)
raise HTTPException(
status_code=500, detail=f"Failed to fetch edit URL: {str(e)}") from e
except Exception as e:
logger.error("An error occurred: %s", str(e), exc_info=True)
raise HTTPException(
status_code=500, detail="Internal server error") from e


async def deploy_app(payload: CreateAppPayload):
folder_path = f"/app/builds/{payload.machine_name}"
cp_process = await asyncio.create_subprocess_exec("cp", "-r", "/app/src/template", folder_path)
Expand Down
56 changes: 53 additions & 3 deletions backend/src/template/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os
from config import config

from modal import (App, Image, web_server, build, Secret)
from modal import (App, Image, web_server, build,
Secret, method, forward, web_endpoint, Queue)
from helpers import (models_volume, MODELS_PATH, MOUNT_PATH,
download_models, unzip_insight_face_models)

Expand All @@ -21,7 +22,6 @@
.apt_install("git")
.pip_install(dependencies)
.run_commands("comfy --skip-prompt install --nvidia")
.run_commands("comfy --version")
.copy_local_file(f"{current_directory}/custom_nodes.json", "/root/")
.run_commands("comfy --skip-prompt node install-deps --deps=/root/custom_nodes.json")
.copy_local_file(f"{current_directory}/models.json", "/root/")
Expand All @@ -45,7 +45,7 @@
@app.cls(
gpu=gpu_config,
image=comfyui_image,
timeout=300,
timeout=idle_timeout,
container_idle_timeout=idle_timeout,
allow_concurrent_inputs=100,
# Restrict to 1 container because we want to our ComfyUI session state
Expand Down Expand Up @@ -74,3 +74,53 @@ def _run_comfyui_server(self, port=8188):
@web_server(8188, startup_timeout=60)
def ui(self):
self._run_comfyui_server()


@app.cls(
cpu=4.0,
memory=16384,
image=comfyui_image,
timeout=idle_timeout,
)
class EditingWorkflow:
@build()
def download(self):
with open("/root/models.json", 'r', encoding='utf-8') as file:
models = json.load(file)
downloaded = download_models(models, os.environ["CIVITAI_TOKEN"])
models_volume.commit()
if downloaded:
print(
"Copying models to correct directory - This might take a few more seconds")
shutil.copytree(
MODELS_PATH, "/root/comfy/ComfyUI/models", dirs_exist_ok=True)
print("Models copied!!")
unzip_insight_face_models()

@method()
def run_comfy_in_tunnel(self, q):
with forward(8888) as tunnel:
url = tunnel.url
print(f"Starting ComfyUI at {url}")
q.put(url)
subprocess.run(
[
"comfy",
"--skip-prompt",
"launch",
"--",
"--cpu",
"--listen",
"0.0.0.0",
"--port",
"8888",
],
check=False
)

@web_endpoint(method="GET")
def get_tunnel_url(self):
with Queue.ephemeral() as q:
self.run_comfy_in_tunnel.spawn(q)
url = q.get()
return {"edit_url": url}
59 changes: 59 additions & 0 deletions web/app/components/ui/alert.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import * as React from "react"
import { cva, type VariantProps } from "class-variance-authority"

import { cn } from "~/lib/utils"

const alertVariants = cva(
"relative w-full rounded-lg border p-4 [&>svg~*]:pl-7 [&>svg+div]:translate-y-[-3px] [&>svg]:absolute [&>svg]:left-4 [&>svg]:top-4 [&>svg]:text-foreground",
{
variants: {
variant: {
default: "bg-background text-foreground",
destructive:
"border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive",
},
},
defaultVariants: {
variant: "default",
},
}
)

const Alert = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement> & VariantProps<typeof alertVariants>
>(({ className, variant, ...props }, ref) => (
<div
ref={ref}
role="alert"
className={cn(alertVariants({ variant }), className)}
{...props}
/>
))
Alert.displayName = "Alert"

const AlertTitle = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLHeadingElement>
>(({ className, ...props }, ref) => (
<h5
ref={ref}
className={cn("mb-1 font-medium leading-none tracking-tight", className)}
{...props}
/>
))
AlertTitle.displayName = "AlertTitle"

const AlertDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("text-sm [&_p]:leading-relaxed", className)}
{...props}
/>
))
AlertDescription.displayName = "AlertDescription"

export { Alert, AlertTitle, AlertDescription }
108 changes: 108 additions & 0 deletions web/app/routes/app.$appName.edit/route.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { LoaderFunctionArgs, json } from "@remix-run/node";
import { Link, useLoaderData } from "@remix-run/react";
import { useEffect, useState } from "react";
import { Alert, AlertDescription, AlertTitle } from "~/components/ui/alert";
import { X } from "lucide-react";

import LoadingIndicator from "~/components/loading-indicator";

export async function loader({ params }: LoaderFunctionArgs) {
const appName = params.appName;
const url = `${process.env.APP_BUILDER_API_BASE_URL}/apps/${appName}/workflow-urls`;

try {
const response = await fetch(url, {
method: "GET",
headers: {
X_API_KEY: process.env.APP_BUILDER_API_KEY!,
},
});

if (!response.ok) {
throw new Error("Failed to fetch edit URL");
}

const data = await response.json();
console.log("Edit workflow url", data["edit_url"]);
console.log("Run workflow url", data["run_url"]);
return json({
editUrl: data["edit_url"] as string,
runUrl: data["run_url"] as string,
});
} catch (error) {
console.error("Error fetching edit workflow URL:", error);
return json({ error: "Failed to load edit workflow URL" }, { status: 500 });
}
}

export default function AppEditPage() {
const data = useLoaderData<typeof loader>();
const [isLoading, setIsLoading] = useState(true);
const [showAlert, setShowAlert] = useState(true);

// Once the edit url is loaded, wait 15 seconds before setting isLoading to false
// This is to prevent the iframe from loading too quickly and giving error because the tunnel is not ready
useEffect(() => {
if ("editUrl" in data) {
const timer = setTimeout(() => {
setIsLoading(false);
}, 15000);

return () => clearTimeout(timer);
}
}, [data]);

if ("error" in data) {
return <div className="text-rose-500">{data.error}</div>;
}

return (
<div className="h-screen flex flex-col relative">
<div className="flex-grow relative">
{isLoading ? (
<LoadingIndicator />
) : (
<>
{data.editUrl && (
<iframe
title={data.editUrl}
src={data.editUrl}
className="w-full h-full border-0"
/>
)}
{showAlert && (
<div className="absolute top-4 left-4 right-4 z-10">
<Alert variant="default" className="pr-12 relative">
<AlertTitle>Heads up!</AlertTitle>
<AlertDescription>
You can use this page to edit your workflows. It runs on CPU
to avoid GPU costs while editing your workflows.Please save
the workflow file before closing the page.
<br />
Please use this{" "}
<Link
to={data.runUrl}
className="underline"
rel="noopener noreferrer"
target="_blank"
>
link
</Link>{" "}
to run your workflows on GPUs
</AlertDescription>
<button
onClick={() => setShowAlert(false)}
className="absolute top-2 right-2 p-1 rounded-full hover:bg-gray-200 transition-colors"
aria-label="Close alert"
>
<X size={16} />
</button>
</Alert>
</div>
)}
</>
)}
</div>
</div>
);
}
2 changes: 1 addition & 1 deletion web/app/routes/apps/route.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ function AppsLayout({ apps }: AppsLayoutProps) {
<td className="py-4 pl-4 pr-8 sm:pl-6 lg:pl-8">
<div className="truncate text-sm font-medium leading-6 text-primary/80 underline">
<Link
to={app.url}
to={`/app/${app.description}/edit`}
rel="noopener noreferrer"
target="_blank"
>
Expand Down

0 comments on commit 6480e12

Please sign in to comment.