From e2ba8b15c8ee4a2a656de1b34d785fbd912b776f Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Fri, 28 Jun 2024 21:19:10 -0700 Subject: [PATCH] feat(engine): Implement schedules (#214) * feat(engine): Add create schedule endpoint * feat(engine): Add rest of the CRUD operations for schedules * chore: Lower log level for action result * feat(cli): Add cli utils * feat(cli): Add schedules cli * feat(engine): Use config for temporal cluster queue * feat(engine): Add better error visibility on http 422 * refactor(playbook): Use slack secret over slack_channel * docs: Update api docstrings * feat(engine): Add more metdata fields to Schedule * feat(engine): Make Schedule status online by default * feat(ui): Update schedules UI * feat(engine): Add update schedules * feat(cli): Add cli update schedules --- .../workspace/canvas/trigger-node.tsx | 30 ++- .../workspace/panel/trigger-panel.tsx | 38 ++- frontend/src/types/schemas.ts | 9 +- .../aws-guardduty-to-slack.yml | 4 +- tracecat/api/app.py | 225 ++++++++++++++---- tracecat/auth/credentials.py | 19 +- tracecat/cli/_utils.py | 43 ++++ tracecat/cli/dev.py | 5 +- tracecat/cli/main.py | 3 +- tracecat/cli/schedule.py | 142 +++++++++++ tracecat/cli/workflow.py | 35 +-- tracecat/config.py | 3 + tracecat/contexts.py | 6 +- tracecat/db/schemas.py | 15 +- tracecat/dsl/dispatcher.py | 5 +- tracecat/dsl/schedules.py | 167 +++++++++++++ tracecat/dsl/workflow.py | 2 +- tracecat/identifiers/__init__.py | 13 +- tracecat/identifiers/schedules.py | 15 ++ tracecat/identifiers/workflow.py | 5 + tracecat/types/api.py | 41 +++- 21 files changed, 702 insertions(+), 123 deletions(-) create mode 100644 tracecat/cli/schedule.py create mode 100644 tracecat/dsl/schedules.py create mode 100644 tracecat/identifiers/schedules.py diff --git a/frontend/src/components/workspace/canvas/trigger-node.tsx b/frontend/src/components/workspace/canvas/trigger-node.tsx index ecbb1ab8b..d906f83b2 100644 --- 
a/frontend/src/components/workspace/canvas/trigger-node.tsx +++ b/frontend/src/components/workspace/canvas/trigger-node.tsx @@ -127,7 +127,7 @@ export default React.memo(function TriggerNode({ - +
Schedules @@ -136,9 +136,31 @@ export default React.memo(function TriggerNode({ - {workflow.schedules.map(({ id, cron }) => ( - - {cron} + {workflow.schedules.map(({ status, every }, idx) => ( + + +
+ + {every} + + +
+
))}
diff --git a/frontend/src/components/workspace/panel/trigger-panel.tsx b/frontend/src/components/workspace/panel/trigger-panel.tsx index 846fa0036..a665a24e4 100644 --- a/frontend/src/components/workspace/panel/trigger-panel.tsx +++ b/frontend/src/components/workspace/panel/trigger-panel.tsx @@ -265,30 +265,42 @@ export function ScheduleControls({ schedules }: { schedules: Schedule[] }) {
- +
Schedules
+ + ID + Status + Inputs + Every +
- {schedules.map(({ id, cron }) => ( - - {cron} + {schedules.map(({ id, status, inputs, every }) => ( + + {id} + {status} + {JSON.stringify(inputs)} + {every} ))} - - + + + +
diff --git a/frontend/src/types/schemas.ts b/frontend/src/types/schemas.ts index 578900458..803f2c2d9 100644 --- a/frontend/src/types/schemas.ts +++ b/frontend/src/types/schemas.ts @@ -39,8 +39,13 @@ export type Webhook = z.infer export const scheduleSchema = z .object({ status: z.enum(["online", "offline"]), - entrypoint_payload: z.record(z.any()), - cron: z.string(), + inputs: z.record(z.any()), + cron: z.string().nullish(), + every: z.string(), + offset: z.string().nullable(), + start_at: strAsDate.nullable(), + end_at: strAsDate.nullable(), + workflow_id: z.string(), }) .and(resourceSchema) export type Schedule = z.infer diff --git a/playbooks/alert_management/aws-guardduty-to-slack.yml b/playbooks/alert_management/aws-guardduty-to-slack.yml index a03466d11..bacfe4b89 100644 --- a/playbooks/alert_management/aws-guardduty-to-slack.yml +++ b/playbooks/alert_management/aws-guardduty-to-slack.yml @@ -23,7 +23,7 @@ actions: - pull_aws_guardduty_findings run_if: ${{ FN.is_empty(ACTIONS.pull_aws_guardduty_findings.result) }} args: - url: ${{ SECRETS.slack_channel.SLACK_WEBHOOK }} + url: ${{ SECRETS.slack.SLACK_WEBHOOK }} method: POST headers: Content-Type: application/json @@ -61,7 +61,7 @@ actions: # Assign each SMAC finding to a variable named `smac` for_each: ${{ for var.smac in ACTIONS.reshape_findings_into_smac.result }} args: - channel: ${{ SECRETS.slack_channel.SLACK_CHANNEL }} + channel: ${{ SECRETS.slack.SLACK_CHANNEL }} text: GuardDuty findings blocks: - type: header diff --git a/tracecat/api/app.py b/tracecat/api/app.py index 654c0f443..9217e361e 100644 --- a/tracecat/api/app.py +++ b/tracecat/api/app.py @@ -17,6 +17,7 @@ UploadFile, status, ) +from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.params import Body from fastapi.responses import ORJSONResponse, StreamingResponse @@ -31,6 +32,7 @@ stream_case_completions, ) from tracecat.auth.credentials import ( + TemporaryRole, 
authenticate_service, authenticate_user, authenticate_user_or_service, @@ -53,11 +55,11 @@ WorkflowDefinition, WorkflowRun, ) +from tracecat.dsl import dispatcher, schedules from tracecat.dsl.common import DSLInput # TODO: Clean up API params / response "zoo" # lots of repetition and inconsistency -from tracecat.dsl.dispatcher import dispatch_workflow from tracecat.dsl.graph import RFGraph from tracecat.logging import logger from tracecat.middleware import RequestLoggingMiddleware @@ -80,6 +82,7 @@ CreateWorkflowParams, Event, EventSearchParams, + SearchScheduleParams, SearchSecretsParams, SecretResponse, ServiceCallbackAction, @@ -88,6 +91,7 @@ TriggerWorkflowRunParams, UDFArgsValidationResponse, UpdateActionParams, + UpdateScheduleParams, UpdateSecretParams, UpdateUserParams, UpdateWorkflowParams, @@ -210,6 +214,16 @@ async def tracecat_exception_handler(request: Request, exc: TracecatException): ) +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Improves visibility of 422 errors.""" + exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") + logger.error(f"{request}: {exc_str}") + return ORJSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=exc_str + ) + + @app.get("/", include_in_schema=False) def root() -> dict[str, str]: return {"message": "Hello world.
I am the API."} @@ -353,7 +367,7 @@ async def incoming_webhook( logger.info(dsl_input.dump_yaml()) - asyncio.create_task(dispatch_workflow(dsl_input, wf_id=path)) + asyncio.create_task(dispatcher.dispatch_workflow(dsl_input, wf_id=path)) return {"status": "ok"} @@ -479,7 +493,7 @@ async def webhook_callback( logger.info(dsl_input.dump_yaml()) - asyncio.create_task(dispatch_workflow(dsl_input, wf_id=path)) + asyncio.create_task(dispatcher.dispatch_workflow(dsl_input, wf_id=path)) return {"status": "ok", "message": "Webhook dispatched"} case None: @@ -1060,7 +1074,7 @@ async def trigger_workflow_run( path = "workflow4" with Path(f"/app/tracecat/static/workflows/{path}.yaml").resolve().open() as f: dsl_yaml = f.read() - await dispatch_workflow(dsl_yaml) + await dispatcher.dispatch_workflow(dsl_yaml) return StartWorkflowResponse( status="ok", message="Workflow started.", id=workflow_id @@ -1153,58 +1167,100 @@ def update_webhook( # ----- Workflow Schedules ----- # -@app.get("/workflows/{workflow_id}/schedules", tags=["triggers"]) -def list_schedules( - role: Annotated[Role, Depends(authenticate_user_or_service)], - workflow_id: str, +@app.get("/schedules", tags=["schedules"]) +async def list_schedules( + role: Annotated[Role, Depends(authenticate_user)], + workflow_id: identifiers.WorkflowID | None = None, ) -> list[Schedule]: - """**[WORK IN PROGRESS]** List all schedules for a workflow.""" + """List all schedules for a workflow.""" with Session(engine) as session: - statement = select(Schedule).where( - Schedule.owner_id == role.user_id, - Schedule.workflow_id == workflow_id, - ) + statement = select(Schedule).where(Schedule.owner_id == role.user_id) + if workflow_id: + statement = statement.where(Schedule.workflow_id == workflow_id) result = session.exec(statement) - return result.all() + try: + return result.all() + except NoResultFound as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found" + ) from e -@app.post( - 
"/workflows/{workflow_id}/schedules", - status_code=status.HTTP_201_CREATED, - tags=["triggers"], -) -def create_schedule( - role: Annotated[Role, Depends(authenticate_user_or_service)], - workflow_id: str, +@app.post("/schedules", tags=["schedules"]) +async def create_schedule( + role: Annotated[Role, Depends(authenticate_user)], params: CreateScheduleParams, -) -> None: - """**[WORK IN PROGRESS]** Create a schedule for a workflow.""" +) -> Schedule: + """Create a schedule for a workflow.""" - schedule = Schedule( - owner_id=role.user_id, - cron=params.cron, - entrypoint_payload=params.entrypoint_payload, - entrypoint_ref=params.entrypoint_ref, - workflow_id=workflow_id, - ) - with Session(engine) as session: - session.add(schedule) - session.commit() - session.refresh(schedule) + with Session(engine) as session, logger.contextualize(role=role): + result = session.exec( + select(WorkflowDefinition) + .where(WorkflowDefinition.workflow_id == params.workflow_id) + .order_by(WorkflowDefinition.version.desc()) + ) + try: + if not (defn_data := result.first()): + raise NoResultFound("No workflow definition found for workflow ID") + except NoResultFound as e: + logger.opt(exception=e).error("Invalid workflow ID", error=e) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Invalid workflow ID" + ) from e + + schedule = Schedule( + owner_id=role.user_id, **params.model_dump(exclude_unset=True) + ) + session.refresh(defn_data) + defn = WorkflowDefinition.model_validate(defn_data) + dsl = defn.content + if params.inputs: + dsl.trigger_inputs = params.inputs + try: + # Set the role for the schedule as the tracecat-runner + with TemporaryRole( + type="service", user_id=defn.owner_id, service_id="tracecat-runner" + ) as sch_role: + handle = await schedules.create_schedule( + workflow_id=params.workflow_id, + schedule_id=schedule.id, + dsl=dsl, + every=params.every, + offset=params.offset, + start_at=params.start_at, + end_at=params.end_at, + ) + 
logger.info( + "Created schedule", + handle_id=handle.id, + workflow_id=params.workflow_id, + schedule_id=schedule.id, + sch_role=sch_role, + ) -@app.get("/workflows/{workflow_id}/schedules/{schedule_id}", tags=["triggers"]) + session.add(schedule) + session.commit() + session.refresh(schedule) + return schedule + except Exception as e: + session.rollback() + logger.opt(exception=e).error("Error creating schedule", error=e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating schedule", + ) from e + + +@app.get("/schedules/{schedule_id}", tags=["schedules"]) def get_schedule( - role: Annotated[Role, Depends(authenticate_user_or_service)], - schedule_id: str, - workflow_id: str, + role: Annotated[Role, Depends(authenticate_user)], + schedule_id: identifiers.ScheduleID, ) -> Schedule: - """**[WORK IN PROGRESS]** Get a schedule from a workflow.""" + """Get a schedule from a workflow.""" with Session(engine) as session: statement = select(Schedule).where( - Schedule.owner_id == role.user_id, - Schedule.id == schedule_id, - Schedule.workflow_id == workflow_id, + Schedule.owner_id == role.user_id, Schedule.id == schedule_id ) result = session.exec(statement) try: @@ -1215,18 +1271,60 @@ def get_schedule( ) from e -@app.delete("/workflows/{workflow_id}/schedules/{schedule_id}", tags=["triggers"]) -def delete_schedule( - role: Annotated[Role, Depends(authenticate_user_or_service)], - schedule_id: str, - workflow_id: str, +@app.post("/schedules/{schedule_id}", tags=["schedules"]) +async def update_schedule( + role: Annotated[Role, Depends(authenticate_user)], + schedule_id: identifiers.ScheduleID, + params: UpdateScheduleParams, +) -> Schedule: + """Update a schedule from a workflow. 
You cannot update the Workflow Definition, but you can update other fields.""" + with Session(engine) as session: + statement = select(Schedule).where( + Schedule.owner_id == role.user_id, Schedule.id == schedule_id + ) + result = session.exec(statement) + try: + schedule = result.one() + except NoResultFound as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found" + ) from e + + try: + # (1) Synchronize with Temporal + await schedules.update_schedule(schedule_id, params) + + # (2) Update the schedule + for key, value in params.model_dump(exclude_unset=True).items(): + # Safety: params have been validated + setattr(schedule, key, value) + + session.add(schedule) + session.commit() + session.refresh(schedule) + return schedule + except Exception as e: + session.rollback() + logger.opt(exception=e).error("Error creating schedule", error=e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating schedule", + ) from e + + +@app.delete( + "/schedules/{schedule_id}", + status_code=status.HTTP_204_NO_CONTENT, + tags=["schedules"], +) +async def delete_schedule( + role: Annotated[Role, Depends(authenticate_user)], + schedule_id: identifiers.ScheduleID, ) -> None: - """**[WORK IN PROGRESS]** Delete a schedule from a workflow.""" + """Delete a schedule from a workflow.""" with Session(engine) as session: statement = select(Schedule).where( - Schedule.owner_id == role.user_id, - Schedule.id == schedule_id, - Schedule.workflow_id == workflow_id, + Schedule.owner_id == role.user_id, Schedule.id == schedule_id ) result = session.exec(statement) try: @@ -1235,10 +1333,33 @@ def delete_schedule( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found" ) from e + + try: + await schedules.delete_schedule(schedule_id) + except Exception as e: + logger.error("Error deleting schedule", error=e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + 
detail="Error deleting schedule", + ) from e session.delete(schedule) session.commit() +@app.get("/schedules/search", tags=["schedules"]) +def search_schedules( + role: Annotated[Role, Depends(authenticate_user)], + params: SearchScheduleParams, +) -> list[Schedule]: + """**[WORK IN PROGRESS]** Search for schedules.""" + with Session(engine) as session: + statement = select(Schedule).where(Schedule.owner_id == role.user_id) + results = session.exec(statement) + schedules = results.all() + return schedules + + +@app.post("/schedules/search") # ----- Actions ----- # diff --git a/tracecat/auth/credentials.py b/tracecat/auth/credentials.py index 7051f6b18..c05348242 100644 --- a/tracecat/auth/credentials.py +++ b/tracecat/auth/credentials.py @@ -4,8 +4,9 @@ import hashlib import os +from contextlib import contextmanager from functools import partial -from typing import Annotated, Any +from typing import Annotated, Any, Literal import httpx import orjson @@ -242,3 +243,19 @@ async def authenticate_user_or_service( if api_key: return await _get_role_from_service_key(request, api_key) raise HTTP_EXC("Could not validate credentials") + + +@contextmanager +def TemporaryRole( + type: Literal["user", "service"] = "service", + user_id: str | None = None, + service_id: str | None = None, +): + """An async context manager to authenticate a user or service.""" + prev_role = ctx_role.get() + temp_role = Role(type=type, user_id=user_id, service_id=service_id) + ctx_role.set(temp_role) + try: + yield temp_role + finally: + ctx_role.set(prev_role) diff --git a/tracecat/cli/_utils.py b/tracecat/cli/_utils.py index 5ff5a11f4..dd35d05ac 100644 --- a/tracecat/cli/_utils.py +++ b/tracecat/cli/_utils.py @@ -1,16 +1,30 @@ +from pathlib import Path + import httpx +import orjson +import rich +import typer from rich.table import Table from ._config import config def user_client() -> httpx.AsyncClient: + """Returns an asynchronous httpx client with the user's JWT token.""" return 
httpx.AsyncClient( headers={"Authorization": f"Bearer {config.jwt_token}"}, base_url=config.api_url, ) +def user_client_sync() -> httpx.Client: + """Returns a synchronous httpx client with the user's JWT token.""" + return httpx.Client( + headers={"Authorization": f"Bearer {config.jwt_token}"}, + base_url=config.api_url, + ) + + def dynamic_table(data: list[dict[str, str]], title: str) -> Table: # Dynamically add columns based on the keys of the JSON objects table = Table(title=title) @@ -22,3 +36,32 @@ def dynamic_table(data: list[dict[str, str]], title: str) -> Table: for item in data: table.add_row(*[str(value) for value in item.values()]) return table + + +def read_input(data: str) -> dict[str, str]: + """Read data from a file or JSON string. + + If the data starts with '@', it is treated as a file path. + Else it is treated as a JSON string. + """ + if data[0] == "@": + p = Path(data[1:]) + if not p.exists(): + raise typer.BadParameter(f"File {p} does not exist") + if p.suffix != ".json": + raise typer.BadParameter(f"File {p} is not a JSON file") + with p.open() as f: + data = f.read() + try: + return orjson.loads(data) + except orjson.JSONDecodeError as e: + raise typer.BadParameter(f"Invalid JSON: {e}") from e + + +def handle_response(res: httpx.Response) -> httpx.Response: + if res.status_code == 422: + rich.print("[red]Validation error[/red]") + rich.print(res.json()) + raise typer.Exit() + res.raise_for_status() + return res diff --git a/tracecat/cli/dev.py b/tracecat/cli/dev.py index afcdf8cc1..1f15c6d60 100644 --- a/tracecat/cli/dev.py +++ b/tracecat/cli/dev.py @@ -15,7 +15,7 @@ from pydantic import BaseModel from ._config import config -from ._utils import user_client +from ._utils import read_input, user_client app = typer.Typer(no_args_is_help=True, help="Dev tools.") @@ -37,7 +37,8 @@ def api( data: str = typer.Option(None, "--data", "-d", help="JSON Payload to send"), ): """Commit a workflow definition to the database.""" - payload = 
orjson.loads(data) if data else None + + payload = read_input(data) if data else None result = asyncio.run(hit_api_endpoint(method, endpoint, payload)) rich.print("Hit the endpoint successfully!") rich.print(result, len(result)) diff --git a/tracecat/cli/main.py b/tracecat/cli/main.py index 05cd6c9e9..478a53cb9 100644 --- a/tracecat/cli/main.py +++ b/tracecat/cli/main.py @@ -1,7 +1,7 @@ import typer from dotenv import find_dotenv, load_dotenv -from . import dev, secret, workflow +from . import dev, schedule, secret, workflow load_dotenv(find_dotenv()) app = typer.Typer(no_args_is_help=True, pretty_exceptions_show_locals=False) @@ -26,6 +26,7 @@ def tracecat( app.add_typer(workflow.app, name="workflow") app.add_typer(dev.app, name="dev") app.add_typer(secret.app, name="secret") +app.add_typer(schedule.app, name="schedule") if __name__ == "__main__": typer.run(app) diff --git a/tracecat/cli/schedule.py b/tracecat/cli/schedule.py new file mode 100644 index 000000000..66866e0a0 --- /dev/null +++ b/tracecat/cli/schedule.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import asyncio + +import rich +import typer +from rich.console import Console + +from tracecat.types.api import CreateScheduleParams, UpdateScheduleParams + +from ._utils import ( + dynamic_table, + handle_response, + read_input, + user_client, + user_client_sync, +) + +app = typer.Typer(no_args_is_help=True, help="Manage schedules.") + + +@app.command(help="Create a schedule", no_args_is_help=True) +def create( + workflow_id: str = typer.Argument(..., help="Workflow ID"), + data: str = typer.Option( + None, "--data", "-d", help="JSON Payload to send (trigger context)" + ), + every: str = typer.Option( + None, "--every", "-e", help="Interval at which the schedule should run" + ), + offset: str = typer.Option( + None, "--offset", "-o", help="Offset from the start of the interval" + ), +): + """Create a new schedule.""" + + inputs = read_input(data) if data else None + + params = 
CreateScheduleParams( + workflow_id=workflow_id, + every=every, + offset=offset, + inputs=inputs, + ) + + with user_client_sync() as client: + res = client.post( + "/schedules", + json=params.model_dump(exclude_unset=True, exclude_none=True, mode="json"), + ) + handle_response(res) + + rich.print(res.json()) + + +@app.command(name="list", help="List all schedules") +def list_schedules( + workflow_id: str = typer.Argument(None, help="Workflow ID"), + as_table: bool = typer.Option(False, "--table", "-t", help="Display as table"), +): + """List all schedules.""" + + params = {} + if workflow_id: + params["workflow_id"] = workflow_id + with user_client_sync() as client: + res = client.get("/schedules", params=params) + handle_response(res) + + result = res.json() + if as_table: + table = dynamic_table(result, "Schedules") + Console().print(table) + else: + rich.print(result) + + +@app.command(help="Delete schedules", no_args_is_help=True) +def delete( + schedule_ids: list[str] = typer.Argument( + ..., help="IDs of the schedules to delete" + ), +): + """Delete schedules""" + + delete = typer.confirm(f"Are you sure you want to delete {schedule_ids!r}") + if not delete: + rich.print("Aborted") + return + + async def _delete(): + async with user_client() as client, asyncio.TaskGroup() as tg: + for sch_id in schedule_ids: + tg.create_task(client.delete(f"/schedules/{sch_id}")) + + asyncio.run(_delete()) + + +@app.command(help="Update a schedule", no_args_is_help=True) +def update( + schedule_id: str = typer.Argument(..., help="ID of the schedule to update."), + inputs: str = typer.Option( + None, "--data", "-d", help="JSON Payload to send (trigger context)" + ), + every: str = typer.Option( + None, "--every", help="Interval at which the schedule should run" + ), + offset: str = typer.Option( + None, "--offset", help="Offset from the start of the interval" + ), + online: bool = typer.Option(None, "--online", help="Set the schedule to online"), + offline: bool = 
typer.Option(None, "--offline", help="Set the schedule to offline"), +): + """Update a schedule""" + if online and offline: + raise typer.BadParameter("Cannot set both online and offline") + + params = UpdateScheduleParams( + inputs=read_input(inputs) if inputs else None, + every=every, + offset=offset, + status="online" if online else "offline", + ) + with user_client_sync() as client: + res = client.post( + f"/schedules/{schedule_id}", + json=params.model_dump(exclude_unset=True, exclude_none=True, mode="json"), + ) + handle_response(res) + rich.print(res.json()) + + +@app.command(help="Inspect a schedule", no_args_is_help=True) +def inspect( + schedule_id: str = typer.Argument(..., help="ID of the schedule to inspect"), +): + """Inspect a schedule""" + + with user_client_sync() as client: + res = client.get(f"/schedules/{schedule_id}") + handle_response(res) + rich.print(res.json()) diff --git a/tracecat/cli/workflow.py b/tracecat/cli/workflow.py index 128a25640..7102a5787 100644 --- a/tracecat/cli/workflow.py +++ b/tracecat/cli/workflow.py @@ -11,7 +11,7 @@ from tracecat.types.api import WebhookResponse from tracecat.types.headers import CustomHeaders -from ._utils import dynamic_table, user_client +from ._utils import dynamic_table, read_input, user_client app = typer.Typer(no_args_is_help=True, help="Manage workflows.") @@ -113,7 +113,7 @@ async def _list_workflows(): async with user_client() as client: res = await client.get("/workflows") res.raise_for_status() - return dynamic_table(res.json(), "Workflows") + return res.json() async def _get_cases(workflow_id: str): @@ -163,11 +163,17 @@ def commit( @app.command(name="list", help="List all workflow definitions") -def list_workflows(): +def list_workflows( + as_json: bool = typer.Option(False, "--json", help="Display as JSON"), +): """Commit a workflow definition to the database.""" rich.print("Listing all workflows") - table = asyncio.run(_list_workflows()) - Console().print(table) + result = 
asyncio.run(_list_workflows()) + if as_json: + rich.print(result) + else: + result = dynamic_table(result, "Workflows") + Console().print(result) @app.command(help="Run a workflow", no_args_is_help=True) @@ -185,23 +191,8 @@ def run( ): """Triggers a webhook to run a workflow.""" rich.print(f"Running workflow {workflow_id!r} {"proxied" if proxy else 'directly'}") - if data[0] == "@": - p = Path(data[1:]) - if not p.exists(): - raise typer.BadParameter(f"File {p} does not exist") - if p.suffix != ".json": - raise typer.BadParameter(f"File {p} is not a JSON file") - with p.open() as f: - data = f.read() - - asyncio.run( - _run_workflow( - workflow_id, - payload=orjson.loads(data) if data else None, - proxy=proxy, - test=test, - ) - ) + payload = read_input(data) if data else None + asyncio.run(_run_workflow(workflow_id, payload=payload, proxy=proxy, test=test)) @app.command(help="Activate a workflow", no_args_is_help=True) diff --git a/tracecat/config.py b/tracecat/config.py index 4d03986b7..8cb660db2 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -34,6 +34,9 @@ TEMPORAL__CLUSTER_NAMESPACE = os.environ.get( "TEMPORAL__CLUSTER_NAMESPACE", "default" ) # Temporal namespace +TEMPORAL__CLUSTER_QUEUE = os.environ.get( + "TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue" +) # Temporal task queue TEMPORAL__TLS_ENABLED = os.environ.get("TEMPORAL__TLS_ENABLED", False) TEMPORAL__TLS_ENABLED = os.environ.get("TEMPORAL__TLS_ENABLED", False) TEMPORAL__TLS_CLIENT_CERT = os.environ.get("TEMPORAL__TLS_CLIENT_CERT") diff --git a/tracecat/contexts.py b/tracecat/contexts.py index 37bbb388f..22c7800bf 100644 --- a/tracecat/contexts.py +++ b/tracecat/contexts.py @@ -13,9 +13,9 @@ class RunContext(BaseModel): - wf_id: identifiers.workflow.WorkflowID - wf_exec_id: identifiers.workflow.WorkflowExecutionID - wf_run_id: identifiers.workflow.WorkflowRunID + wf_id: identifiers.WorkflowID + wf_exec_id: identifiers.WorkflowExecutionID | identifiers.WorkflowScheduleID + wf_run_id: 
identifiers.WorkflowRunID ctx_run: ContextVar[RunContext] = ContextVar("run", default=None) diff --git a/tracecat/db/schemas.py b/tracecat/db/schemas.py index beb9c391b..ffd585484 100644 --- a/tracecat/db/schemas.py +++ b/tracecat/db/schemas.py @@ -1,6 +1,6 @@ """Database schemas for Tracecat.""" -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Self import pyarrow as pa @@ -313,12 +313,13 @@ class Schedule(Resource, table=True): id: str = Field( default_factory=id_factory("sch"), nullable=False, unique=True, index=True ) - status: str = "offline" # "online" or "offline" - cron: str - entrypoint_payload: dict[str, Any] = Field( - default_factory=dict, sa_column=Column(JSON) - ) - + status: str = "online" # "online" or "offline" + cron: str | None = None + inputs: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON)) + every: timedelta = Field(..., description="ISO 8601 duration string") + offset: timedelta | None = Field(None, description="ISO 8601 duration string") + start_at: datetime | None = Field(None, description="ISO 8601 datetime string") + end_at: datetime | None = Field(None, description="ISO 8601 datetime string") # Relationships workflow_id: str | None = Field( sa_column=Column(String, ForeignKey("workflow.id", ondelete="CASCADE")) diff --git a/tracecat/dsl/dispatcher.py b/tracecat/dsl/dispatcher.py index ceee447d1..4d260cc8f 100644 --- a/tracecat/dsl/dispatcher.py +++ b/tracecat/dsl/dispatcher.py @@ -1,13 +1,12 @@ import asyncio import json -import os import sys from typing import Any from loguru import logger from pydantic import BaseModel -from tracecat import identifiers +from tracecat import config, identifiers from tracecat.contexts import ctx_role from tracecat.dsl.common import DSLInput, get_temporal_client from tracecat.dsl.workflow import DSLContext, DSLRunArgs, DSLWorkflow @@ -33,7 +32,7 @@ async def dispatch_workflow( DSLWorkflow.run, DSLRunArgs(dsl=dsl, role=role, wf_id=wf_id), 
id=wf_exec_id, - task_queue=os.environ.get("TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue"), + task_queue=config.TEMPORAL__CLUSTER_QUEUE, **kwargs, ) logger.debug(f"Workflow result:\n{json.dumps(result, indent=2)}") diff --git a/tracecat/dsl/schedules.py b/tracecat/dsl/schedules.py new file mode 100644 index 000000000..8cf380ccc --- /dev/null +++ b/tracecat/dsl/schedules.py @@ -0,0 +1,167 @@ +import re +from datetime import datetime, timedelta +from typing import Any, TypeVar + +from pydantic import ValidationInfo, ValidatorFunctionWrapHandler, WrapValidator +from temporalio.client import ( + Schedule, + ScheduleActionStartWorkflow, + ScheduleHandle, + ScheduleIntervalSpec, + ScheduleSpec, + ScheduleUpdate, + ScheduleUpdateInput, +) + +from tracecat import config, identifiers +from tracecat.contexts import ctx_role +from tracecat.dsl.common import DSLInput, get_temporal_client +from tracecat.dsl.workflow import DSLRunArgs, DSLWorkflow +from tracecat.types.api import UpdateScheduleParams + +T = TypeVar("T") + +EASY_TD_PATTERN = ( + r"^" # Start of string + r"(?:(?P\d+)w)?" # Match weeks + r"(?:(?P\d+)d)?" # Match days + r"(?:(?P\d+)h)?" # Match hours + r"(?:(?P\d+)m)?" # Match minutes + r"(?:(?P\d+)s)?" # Match seconds + r"$" # End of string +) + + +class EasyTimedelta: + def __new__(cls): + return WrapValidator(cls.maybe_str2timedelta) + + @classmethod + def maybe_str2timedelta( + cls, v: T, handler: ValidatorFunctionWrapHandler, info: ValidationInfo + ) -> T: + if isinstance(v, str): + # If it's a string, try to parse it as a timedelta + try: + return string_to_timedelta(v) + except ValueError: + pass + # Otherwise, handle as normal + return handler(v, info) + + +def string_to_timedelta(time_str: str) -> timedelta: + # Regular expressions to match different time units with named capture groups + pattern = re.compile( + r"^" # Start of string + r"(?:(?P\d+)w)?" # Match weeks + r"(?:(?P\d+)d)?" # Match days + r"(?:(?P\d+)h)?" # Match hours + r"(?:(?P\d+)m)?" 
# Match minutes + r"(?:(?P\d+)s)?" # Match seconds + r"$" # End of string + ) + match = pattern.match(time_str) + + if not match: + raise ValueError("Invalid time format") + + # Extracting the values, defaulting to 0 if not present + weeks = int(match.group("weeks") or 0) + days = int(match.group("days") or 0) + hours = int(match.group("hours") or 0) + minutes = int(match.group("minutes") or 0) + seconds = int(match.group("seconds") or 0) + + # Check if all values are zero + if all(v == 0 for v in (weeks, days, hours, minutes, seconds)): + raise ValueError("Invalid time format. All values are zero.") + + # Creating a timedelta object + return timedelta( + days=days, weeks=weeks, hours=hours, minutes=minutes, seconds=seconds + ) + + +async def _get_handle(sch_id: identifiers.ScheduleID) -> ScheduleHandle: + client = await get_temporal_client() + return client.get_schedule_handle(sch_id) + + +async def create_schedule( + workflow_id: identifiers.WorkflowID, + schedule_id: identifiers.ScheduleID, + dsl: DSLInput, + # Schedule config + every: timedelta, + offset: timedelta | None = None, + start_at: datetime | None = None, + end_at: datetime | None = None, + **kwargs: Any, +) -> ScheduleHandle: + client = await get_temporal_client() + + workflow_schedule_id = f"{workflow_id}:{schedule_id}" + return await client.create_schedule( + schedule_id, + Schedule( + action=ScheduleActionStartWorkflow( + DSLWorkflow.run, + # The args that should run in the scheduled workflow + DSLRunArgs(dsl=dsl, role=ctx_role.get(), wf_id=workflow_id), + id=workflow_schedule_id, + task_queue=config.TEMPORAL__CLUSTER_QUEUE, + ), + spec=ScheduleSpec( + intervals=[ScheduleIntervalSpec(every=every, offset=offset)], + start_at=start_at, + end_at=end_at, + ), + ), + **kwargs, + ) + + +async def delete_schedule(schedule_id: identifiers.ScheduleID) -> ScheduleHandle: + handle = await _get_handle(schedule_id) + try: + return await handle.delete() + except Exception as e: + if "workflow execution already 
completed" in str(e): + # This is fine, we can ignore this error + return + raise e + + +async def update_schedule( + schedule_id: identifiers.ScheduleID, params: UpdateScheduleParams +) -> ScheduleUpdate: + async def _update_schedule(input: ScheduleUpdateInput) -> ScheduleUpdate: + set_fields = params.model_dump(exclude_unset=True) + action = input.description.schedule.action + spec = input.description.schedule.spec + state = input.description.schedule.state + + if "status" in set_fields: + state.paused = set_fields["status"] != "online" + if isinstance(action, ScheduleActionStartWorkflow): + if "inputs" in set_fields: + action.args[0].dsl.trigger_inputs = set_fields["inputs"] + else: + raise NotImplementedError( + "Only ScheduleActionStartWorkflow is supported for now." + ) + # We only support one interval per schedule for now + if "every" in set_fields: + spec.intervals[0].every = set_fields["every"] + if "offset" in set_fields: + spec.intervals[0].offset = set_fields["offset"] + if "start_at" in set_fields: + spec.start_at = set_fields["start_at"] + if "end_at" in set_fields: + spec.end_at = set_fields["end_at"] + + return ScheduleUpdate(schedule=input.description.schedule) + + handle = await _get_handle(schedule_id) + return await handle.update(_update_schedule) diff --git a/tracecat/dsl/workflow.py b/tracecat/dsl/workflow.py index 292c86526..f3a985405 100644 --- a/tracecat/dsl/workflow.py +++ b/tracecat/dsl/workflow.py @@ -428,7 +428,7 @@ async def run_udf(input: UDFActionInput) -> Any: else: result = await udf.run_async(args) - act_logger.info("Result", result=result) + act_logger.debug("Result", result=result) return result diff --git a/tracecat/identifiers/__init__.py b/tracecat/identifiers/__init__.py index 9cdce0335..0d7f08874 100644 --- a/tracecat/identifiers/__init__.py +++ b/tracecat/identifiers/__init__.py @@ -38,10 +38,16 @@ """ -from tracecat.identifiers import action, workflow +from tracecat.identifiers import action, schedules, workflow from 
tracecat.identifiers.action import ActionID, ActionKey, ActionRef from tracecat.identifiers.resource import id_factory -from tracecat.identifiers.workflow import WorkflowExecutionID, WorkflowID, WorkflowRunID +from tracecat.identifiers.schedules import ScheduleID +from tracecat.identifiers.workflow import ( + WorkflowExecutionID, + WorkflowID, + WorkflowRunID, + WorkflowScheduleID, +) __all__ = [ "ActionID", @@ -49,8 +55,11 @@ "ActionRef", "WorkflowID", "WorkflowExecutionID", + "WorkflowScheduleID", "WorkflowRunID", + "ScheduleID", "id_factory", "action", "workflow", + "schedules", ] diff --git a/tracecat/identifiers/schedules.py b/tracecat/identifiers/schedules.py new file mode 100644 index 000000000..022f36d08 --- /dev/null +++ b/tracecat/identifiers/schedules.py @@ -0,0 +1,15 @@ +"""Schedule identifiers.""" + +from typing import Annotated + +from pydantic import StringConstraints + +ScheduleID = Annotated[str, StringConstraints(pattern=r"sch-[0-9a-f]{32}")] +"""A unique ID for a schedule. + +This is the equivalent of a Schedule ID in Temporal. + +Examples +-------- +- 'sch-77932a0b140a4465a1a25a5c95edcfb8' +""" diff --git a/tracecat/identifiers/workflow.py b/tracecat/identifiers/workflow.py index 1ff18c03c..be23c3968 100644 --- a/tracecat/identifiers/workflow.py +++ b/tracecat/identifiers/workflow.py @@ -6,6 +6,11 @@ from tracecat.identifiers.resource import ResourcePrefix, generate_resource_id +WorkflowScheduleID = Annotated[ + str, StringConstraints(pattern=r"wf-[0-9a-f]{32}:sch-[0-9a-f]{32}") +] +"""A unique ID for a scheduled workflow.""" + WorkflowID = Annotated[str, StringConstraints(pattern=r"wf-[0-9a-f]{32}")] """A unique ID for a workflow. 
diff --git a/tracecat/types/api.py b/tracecat/types/api.py index 5b4289d3e..8fa3d1191 100644 --- a/tracecat/types/api.py +++ b/tracecat/types/api.py @@ -1,11 +1,12 @@ from __future__ import annotations import json -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator +from tracecat import identifiers from tracecat.db.schemas import ActionRun, Resource, Schedule, WorkflowRun from tracecat.dsl.common import DSLInput from tracecat.types.generics import ListModel @@ -334,12 +335,6 @@ class UDFArgsValidationResponse(BaseModel): detail: Any | None = None -class CreateScheduleParams(BaseModel): - entrypoint_ref: str - entrypoint_payload: dict[str, Any] | None = None - cron: str - - class CommitWorkflowResponse(BaseModel): workflow_id: str status: Literal["success", "failure"] @@ -352,3 +347,33 @@ class ServiceCallbackAction(BaseModel): action: Literal["webhook"] payload: dict[str, Any] metadata: dict[str, Any] + + +class CreateScheduleParams(BaseModel): + workflow_id: identifiers.WorkflowID + inputs: dict[str, Any] | None = None + cron: str | None = None + every: timedelta = Field(..., description="ISO 8601 duration string") + offset: timedelta | None = Field(None, description="ISO 8601 duration string") + start_at: datetime | None = Field(None, description="ISO 8601 datetime string") + end_at: datetime | None = Field(None, description="ISO 8601 datetime string") + status: Literal["online", "offline"] = "online" + + +class UpdateScheduleParams(BaseModel): + inputs: dict[str, Any] | None = None + cron: str | None = None + every: timedelta | None = Field(None, description="ISO 8601 duration string") + offset: timedelta | None = Field(None, description="ISO 8601 duration string") + start_at: datetime | None = Field(None, description="ISO 8601 datetime string") + end_at: datetime | None = 
Field(None, description="ISO 8601 datetime string") + status: Literal["online", "offline"] | None = None + + +class SearchScheduleParams(BaseModel): + workflow_id: str | None = None + limit: int = 100 + order_by: str = "created_at" + query: str | None = None + group_by: list[str] | None = None + agg: str | None = None