Merge pull request #36 from PySpur-Dev/feat/subworkflows
Feat/subworkflows
srijanpatel authored Dec 3, 2024
2 parents f65a523 + 0467f87 commit 2069a7d
Showing 10 changed files with 47 additions and 77 deletions.
54 changes: 5 additions & 49 deletions backend/app/api/run_management.py
```diff
@@ -1,13 +1,10 @@
-import json
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 from fastapi import APIRouter, Depends, HTTPException
 from sqlalchemy.orm import Session
-from trio import TaskStatus

-from ..schemas.run_schemas import RunResponseSchema, RunStatusResponseSchema
+from ..schemas.run_schemas import RunResponseSchema
 from ..database import get_db
-from ..models.run_model import RunModel, RunStatus
-from ..models.task_model import TaskStatus
+from ..models.run_model import RunModel

 router = APIRouter()

@@ -33,50 +30,9 @@ def list_runs(
     return runs


-@router.get("/{run_id}/status/", response_model=RunStatusResponseSchema)
+@router.get("/{run_id}/status/", response_model=RunResponseSchema)
 def get_run_status(run_id: str, db: Session = Depends(get_db)):
     run = db.query(RunModel).filter(RunModel.id == run_id).first()
     if not run:
         raise HTTPException(status_code=404, detail="Run not found")
-    output_file_id = None
-    if run.status == RunStatus.COMPLETED:
-        # find output file id
-        output_file = run.output_file
-        if output_file:
-            output_file_id = output_file.id
-
-    tasks = run.tasks
-    tasks_meta = [
-        {
-            "node_id": task.node_id,
-            "status": task.status,
-            "inputs": task.inputs,
-            "outputs": task.outputs,
-            "run_time": task.run_time,
-            "subworkflow": json.loads(task.subworkflow) if task.subworkflow else None,
-            "subworkflow_output": (
-                json.loads(task.subworkflow_output) if task.subworkflow_output else None
-            ),
-        }
-        for task in tasks
-    ]
-    # fail if any task has failed
-    for task in tasks:
-        if task.status == TaskStatus.FAILED:
-            # update run status to failed
-            run.status = RunStatus.FAILED
-            db.commit()
-            break
-
-    combined_task_outputs: Dict[str, Any] = {}
-    for task in tasks:
-        combined_task_outputs[task.node_id] = task.outputs
-    return RunStatusResponseSchema(
-        id=run.id,
-        status=run.status,
-        start_time=run.start_time,
-        end_time=run.end_time,
-        outputs=combined_task_outputs,
-        tasks=tasks_meta,
-        output_file_id=output_file_id,
-    )
+    return run
```
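The status endpoint collapses to a bare query-and-return: instead of hand-assembling a `RunStatusResponseSchema` (with `json.loads` on the stored subworkflow strings, failure propagation, and per-node output merging), it returns the ORM `RunModel` and lets FastAPI serialize it through `RunResponseSchema`, which now carries the nested tasks (see `task_schemas.py` below). Note the old server-side failure propagation (flipping the run to `FAILED` when any task failed) is simply deleted in this diff, not relocated. A minimal sketch of the pattern, with stand-in names rather than the project's actual classes:

```python
# Stand-in sketch: declare the response shape once, return the ORM object,
# and let response_model + from_attributes do the serialization,
# nested relationships included.
from typing import List
from pydantic import BaseModel

class TaskOut(BaseModel):
    node_id: str

    class Config:
        from_attributes = True  # read fields off object attributes

class RunOut(BaseModel):
    id: str
    tasks: List[TaskOut]  # an ORM relationship serializes here too

    class Config:
        from_attributes = True

# The route body then reduces to:
# @router.get("/{run_id}/status/", response_model=RunOut)
# def get_run_status(...):
#     ...
#     return run  # ORM instance; FastAPI validates it into RunOut
```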
1 change: 0 additions & 1 deletion backend/app/api/workflow_run.py
```diff
@@ -22,7 +22,6 @@
 from ..execution.workflow_executor import WorkflowExecutor
 from ..dataset.ds_util import get_ds_iterator, get_ds_column_names
 from ..execution.task_recorder import TaskRecorder
-from ..evals.evaluator import prepare_and_evaluate_dataset, load_yaml_config
 from ..utils.workflow_version_utils import fetch_workflow_version
 from ..execution.workflow_execution_context import WorkflowExecutionContext
```
16 changes: 5 additions & 11 deletions backend/app/execution/task_recorder.py
```diff
@@ -1,5 +1,4 @@
 from datetime import datetime
-import json
 from pydantic import BaseModel
 from ..schemas.workflow_schemas import WorkflowDefinitionSchema
 from ..models.task_model import TaskModel, TaskStatus
@@ -17,17 +16,11 @@ def create_task(
         self,
         node_id: str,
         inputs: Dict[str, Any],
-        subworkflow: Optional[WorkflowDefinitionSchema] = None,
     ):
-        if subworkflow:
-            subworkflow_val = json.dumps(subworkflow.model_dump())
-        else:
-            subworkflow_val = None
         task = TaskModel(
             run_id=self.run_id,
             node_id=node_id,
             inputs=inputs,
-            subworkflow=subworkflow_val,
         )
         self.db.add(task)
         self.db.commit()
@@ -57,10 +50,11 @@ def update_task(
         if end_time:
             task.end_time = end_time
         if subworkflow:
-            task.subworkflow = json.dumps(subworkflow.model_dump())
+            task.subworkflow = subworkflow.model_dump()
         if subworkflow_output:
-            task.subworkflow_output = json.dumps(
-                {k: v.model_dump() for k, v in subworkflow_output.items()}
-            )
+            task.subworkflow_output = {
+                k: v.model_dump() for k, v in subworkflow_output.items()
+            }
         self.db.add(task)
         self.db.commit()
+        return
```
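With `subworkflow` and `subworkflow_output` now typed as JSON columns (see `task_model.py` below), the recorder assigns plain dicts and drops the manual `json.dumps`/`json.loads` round-trip on both the write path here and the read path in `run_management.py`: SQLAlchemy serializes on flush and hands back dicts on load. A self-contained sketch of that round-trip under illustrative names (not the project's actual `TaskModel`):

```python
# Illustrative round-trip through a JSON column; assumes SQLAlchemy 2.x
# and uses in-memory SQLite. Class and field names are stand-ins.
from typing import Optional
from sqlalchemy import JSON, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Task(Base):
    __tablename__ = "tasks"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    subworkflow: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Task(id="t1", subworkflow={"nodes": [], "links": []}))
    session.commit()
    loaded = session.get(Task, "t1")
    assert loaded is not None
    # The attribute is already a dict -- no json.loads needed on read.
    assert loaded.subworkflow == {"nodes": [], "links": []}
```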
8 changes: 7 additions & 1 deletion backend/app/execution/workflow_executor.py
```diff
@@ -26,7 +26,13 @@ def __init__(
         context: Optional[WorkflowExecutionContext] = None,
     ):
         self.workflow = workflow
-        self.task_recorder = task_recorder
+        if task_recorder:
+            self.task_recorder = task_recorder
+        elif context and context.run_id and context.db_session:
+            print("Creating task recorder from context")
+            self.task_recorder = TaskRecorder(context.db_session, context.run_id)
+        else:
+            self.task_recorder = None
         self.context = context
         self._node_dict: Dict[str, WorkflowNodeSchema] = {}
         self._dependencies: Dict[str, Set[str]] = {}
```
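The executor can now mint its own `TaskRecorder` from the `WorkflowExecutionContext` (its `run_id` and `db_session`) when no recorder is passed in, so a nested sub-workflow executor records tasks under the parent run without every call site threading a recorder through. The resolution order in isolation, as a hedged sketch with simplified stand-ins for `TaskRecorder` and `WorkflowExecutionContext`:

```python
# Stand-in sketch of the fallback: an explicit dependency wins, then the
# context, then recording is disabled.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Ctx:  # stand-in for WorkflowExecutionContext
    run_id: Optional[str] = None
    db_session: Optional[Any] = None

class Recorder:  # stand-in for TaskRecorder
    def __init__(self, db_session: Any, run_id: str) -> None:
        self.db_session, self.run_id = db_session, run_id

def resolve_recorder(recorder: Optional[Recorder], ctx: Optional[Ctx]) -> Optional[Recorder]:
    if recorder:
        return recorder  # caller-supplied recorder takes precedence
    if ctx and ctx.run_id and ctx.db_session:
        return Recorder(ctx.db_session, ctx.run_id)  # derive from context
    return None  # no recording (e.g. ad-hoc runs without a DB session)
```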
4 changes: 2 additions & 2 deletions backend/app/models/task_model.py
```diff
@@ -42,8 +42,8 @@ class TaskModel(BaseModel):
         DateTime, default=datetime.now(timezone.utc)
     )
     end_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
-    subworkflow: Mapped[Optional[str]] = mapped_column(String, nullable=True)
-    subworkflow_output: Mapped[Optional[str]] = mapped_column(String, nullable=True)
+    subworkflow: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
+    subworkflow_output: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)

     parent_task: Mapped[Optional["TaskModel"]] = relationship(
         "TaskModel", remote_side=[id], backref="subtasks"
```
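Retyping the columns from `String` to `JSON` moves (de)serialization into SQLAlchemy. One caveat with plain `JSON` columns: in-place mutation of a stored dict is not change-tracked, which is why reassigning the whole value (as `update_task` does above) is the safe pattern; `MutableDict` is the usual escape hatch if in-place edits are ever needed. A sketch, assuming SQLAlchemy 2.x, with a stand-in model rather than `TaskModel`:

```python
# With plain JSON, `row.data["k"] = v` would not mark the row dirty;
# wrapping the type with MutableDict adds that tracking. Reassigning the
# whole dict (as TaskRecorder.update_task does) works either way.
from typing import Any, Optional
from sqlalchemy import JSON, String
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Example(Base):
    __tablename__ = "example"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    data: Mapped[Optional[dict[str, Any]]] = mapped_column(
        MutableDict.as_mutable(JSON), nullable=True
    )
```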
2 changes: 1 addition & 1 deletion backend/app/models/workflow_model.py
```diff
@@ -1,7 +1,7 @@
 from sqlalchemy import Computed, Integer, String, DateTime, JSON
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 from datetime import datetime, timezone
-from typing import List, Optional, Any
+from typing import Optional, Any
 from .base_model import BaseModel


```
14 changes: 3 additions & 11 deletions backend/app/schemas/run_schemas.py
```diff
@@ -2,8 +2,9 @@
 from pydantic import BaseModel
 from datetime import datetime

-from ..schemas.workflow_schemas import WorkflowVersionResponseSchema
+from .workflow_schemas import WorkflowVersionResponseSchema
 from ..models.run_model import RunStatus
+from .task_schemas import TaskResponseSchema


 class StartRunRequestSchema(BaseModel):
@@ -24,6 +25,7 @@ class RunResponseSchema(BaseModel):
     output_file_id: Optional[str]
     start_time: Optional[datetime]
     end_time: Optional[datetime]
+    tasks: List[TaskResponseSchema]

     class Config:
         from_attributes = True
@@ -36,16 +38,6 @@ class PartialRunRequestSchema(BaseModel):
     partial_outputs: Optional[Dict[str, Dict[str, Any]]] = None


-class RunStatusResponseSchema(BaseModel):
-    id: str
-    status: RunStatus
-    start_time: Optional[datetime]
-    end_time: Optional[datetime]
-    tasks: List[Dict[str, Any]]
-    outputs: Optional[Dict[str, Any]]
-    output_file_id: Optional[str]
-
-
 class BatchRunRequestSchema(BaseModel):
     dataset_id: str
     mini_batch_size: int = 10
```
22 changes: 22 additions & 0 deletions backend/app/schemas/task_schemas.py
```diff
@@ -0,0 +1,22 @@
+from typing import Dict, Optional, Any
+from pydantic import BaseModel
+from datetime import datetime
+from ..models.task_model import TaskStatus
+from .workflow_schemas import WorkflowDefinitionSchema
+
+
+class TaskResponseSchema(BaseModel):
+    id: str
+    run_id: str
+    node_id: str
+    parent_task_id: Optional[str]
+    status: TaskStatus
+    inputs: Optional[Any]
+    outputs: Optional[Any]
+    start_time: Optional[datetime]
+    end_time: Optional[datetime]
+    subworkflow: Optional[WorkflowDefinitionSchema]
+    subworkflow_output: Optional[Dict[str, Any]]
+
+    class Config:
+        from_attributes = True  # Enable ORM mode
```
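The new `TaskResponseSchema` mirrors `TaskModel` field-for-field, and `from_attributes = True` (Pydantic's ORM mode) is what lets the `/status/` endpoint above return ORM rows directly. A stand-in demo of what the flag enables, assuming Pydantic v2: `model_validate` reads attributes off any object, coercing e.g. a raw status string into the enum.

```python
# Stand-in classes, not the project's schema: from_attributes lets
# model_validate consume any attribute-bearing object, not just dicts.
from enum import Enum
from types import SimpleNamespace
from typing import Optional
from pydantic import BaseModel

class Status(str, Enum):
    PENDING = "PENDING"
    COMPLETED = "COMPLETED"

class MiniTask(BaseModel):
    id: str
    status: Status
    outputs: Optional[dict] = None

    class Config:
        from_attributes = True  # same flag TaskResponseSchema sets

row = SimpleNamespace(id="task_1", status="COMPLETED", outputs={"answer": 42})
task = MiniTask.model_validate(row)
assert task.status is Status.COMPLETED
```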
1 change: 1 addition & 0 deletions backend/app/schemas/workflow_schemas.py
```diff
@@ -27,6 +27,7 @@ class WorkflowNodeSchema(BaseModel):
     coordinates: Optional[WorkflowNodeCoordinatesSchema] = (
         None  # Position of the node in the workflow
     )
+    subworkflow: Optional["WorkflowDefinitionSchema"] = None  # Sub-workflow definition

     @field_validator("node_type")
     def type_must_be_in_factory(cls, v: str):
```
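This one-line addition is the core of the feature: a node can now embed an entire workflow. The quoted annotation is a forward reference, since `WorkflowDefinitionSchema` is defined later in the module and itself contains nodes, making the two types mutually recursive. A hedged sketch (simplified names and fields, not the project's schemas) of how such a pair resolves in Pydantic v2:

```python
# Simplified stand-ins: a node may carry a whole workflow, and a workflow
# is a list of nodes. The string annotation defers resolution;
# model_rebuild() finalizes it once both classes exist.
from typing import List, Optional
from pydantic import BaseModel

class Node(BaseModel):
    id: str
    subworkflow: Optional["Workflow"] = None  # forward reference

class Workflow(BaseModel):
    nodes: List[Node]

Node.model_rebuild()  # resolve "Workflow" now that it is defined

wf = Workflow(nodes=[Node(id="outer", subworkflow=Workflow(nodes=[Node(id="inner")]))])
assert wf.nodes[0].subworkflow is not None
assert wf.nodes[0].subworkflow.nodes[0].id == "inner"
```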
2 changes: 1 addition & 1 deletion backend/requirements.txt
```diff
@@ -2,7 +2,7 @@ asyncio
 numpy<2.0.0
 tiktoken==0.7.0
 python-dotenv==1.0.1
-openai==1.47.1
+openai==1.55.3
 scikit-learn==1.5.2
 tenacity>=8.2.0,<8.4.0
 black==24.8.0
```
