Skip to content

Commit

Permalink
add new cli commands
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Jan 13, 2025
1 parent 00c8bb2 commit 482f25f
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 25 deletions.
299 changes: 298 additions & 1 deletion src/prefect/cli/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Command line interface for working with Prefect
"""

import asyncio
import inspect
import os
import shlex
import signal
Expand All @@ -10,11 +12,13 @@
import sys
import textwrap
from pathlib import Path
from typing import Any

import anyio
import anyio.abc
import typer
import uvicorn
from rich.table import Table

import prefect
import prefect.settings
Expand Down Expand Up @@ -49,15 +53,17 @@
help="Start a Prefect server instance and interact with the database",
)
database_app = PrefectTyper(name="database", help="Interact with the database.")
services_app = PrefectTyper(name="services", help="Interact with the services.")
server_app.add_typer(database_app)
server_app.add_typer(services_app)
app.add_typer(server_app)

logger = get_logger(__name__)

PID_FILE = "server.pid"


def generate_welcome_blurb(base_url, ui_enabled: bool):
def generate_welcome_blurb(base_url: str, ui_enabled: bool):
blurb = textwrap.dedent(
r"""
___ ___ ___ ___ ___ ___ _____
Expand Down Expand Up @@ -488,3 +494,294 @@ async def stamp(revision: str):
app.console.print("Stamping database with revision ...")
await run_sync_in_worker_thread(alembic_stamp, revision=revision)
exit_with_success("Stamping database with revision succeeded!")


def _discover_services() -> (
tuple[
list[type[prefect.server.services.loop_service.LoopService]],
dict[str, prefect.settings.Setting],
]
):
"""Discover all available services and their settings"""

from prefect.server.events.services import triggers
from prefect.server.services import (
cancellation_cleanup,
flow_run_notifications,
foreman,
late_runs,
loop_service,
pause_expirations,
scheduler,
task_run_recorder,
telemetry,
)

# Map of service names to their settings
service_settings = {
"Telemetry": prefect.settings.PREFECT_SERVER_ANALYTICS_ENABLED,
"TaskRunRecorder": prefect.settings.PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED,
"EventPersister": prefect.settings.PREFECT_API_SERVICES_EVENT_PERSISTER_ENABLED,
"Distributor": prefect.settings.PREFECT_API_EVENTS_STREAM_OUT_ENABLED,
"Scheduler": prefect.settings.PREFECT_API_SERVICES_SCHEDULER_ENABLED,
"RecentDeploymentsScheduler": prefect.settings.PREFECT_API_SERVICES_SCHEDULER_ENABLED,
"MarkLateRuns": prefect.settings.PREFECT_API_SERVICES_LATE_RUNS_ENABLED,
"FailExpiredPauses": prefect.settings.PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_ENABLED,
"CancellationCleanup": prefect.settings.PREFECT_API_SERVICES_CANCELLATION_CLEANUP_ENABLED,
"FlowRunNotifications": prefect.settings.PREFECT_API_SERVICES_FLOW_RUN_NOTIFICATIONS_ENABLED,
"Foreman": prefect.settings.PREFECT_API_SERVICES_FOREMAN_ENABLED,
"ReactiveTriggers": prefect.settings.PREFECT_API_SERVICES_TRIGGERS_ENABLED,
"ProactiveTriggers": prefect.settings.PREFECT_API_SERVICES_TRIGGERS_ENABLED,
"Actions": prefect.settings.PREFECT_API_SERVICES_TRIGGERS_ENABLED,
}

# Find all service classes by inspecting modules
service_modules = [
cancellation_cleanup,
flow_run_notifications,
foreman,
late_runs,
pause_expirations,
scheduler,
task_run_recorder,
telemetry,
triggers,
]

discovered_services: list[type[loop_service.LoopService]] = []
for module in service_modules:
for _, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and issubclass(obj, loop_service.LoopService)
and obj != loop_service.LoopService
):
discovered_services.append(obj)

return discovered_services, service_settings


def _get_service_map(
discovered_services: list[type[prefect.server.services.loop_service.LoopService]],
service_settings: dict[str, prefect.settings.Setting],
) -> dict[str, Any]:
"""Create a map of service names to their classes and settings"""
return {
service_class.__name__: (
service_class,
service_settings.get(service_class.__name__, False),
)
for service_class in discovered_services
}


@services_app.command(aliases=["list"])
async def list_services():
"""List all services"""
import inspect

discovered_services, service_settings = _discover_services()
service_map = _get_service_map(discovered_services, service_settings)

# Get currently running services
running_services = _check_for_running_services()

table = Table(
title="Available Services",
expand=True,
)
table.add_column("Name", style="blue", no_wrap=True)
table.add_column("Status", style="green", no_wrap=True)
table.add_column("Description", style="cyan", no_wrap=False)

for name, (service_class, setting) in sorted(service_map.items()):
enabled = setting.value()
running = name.lower() in running_services

if running:
status = "Running"
status_style = "green"
elif enabled:
status = "Enabled"
status_style = "yellow"
else:
status = "Disabled"
status_style = "red"

description = ""
if doc := inspect.getdoc(service_class):
description = doc.split("\n")[0].strip()
if len(description) > 60:
description = description[:57] + "..."

table.add_row(name, status, description, style=status_style)

app.console.print(table)


@services_app.command(aliases=["stop"])
async def stop_services():
"""Stop all background services"""
pid_dir = Path(PREFECT_HOME.value() / "services")
if not pid_dir.exists():
exit_with_success("No services are running in the background.")

if not (pid_files := list(pid_dir.glob("*.pid"))):
exit_with_success("No services are running in the background.")

app.console.print("\n[yellow]Shutting down...[/]")
for pid_file in pid_files:
service_name = pid_file.stem.title() # Display in title case
try:
pid = int(pid_file.read_text())
try:
os.kill(pid, signal.SIGTERM)
app.console.print(f"[dim]✓ {service_name}[/]")
except ProcessLookupError:
app.console.print(
f"[yellow]Process for {service_name} was not running[/]"
)
except (ValueError, OSError) as e:
app.console.print(f"[red]✗ {service_name}: {str(e)}[/]")
finally:
pid_file.unlink(missing_ok=True)

try:
pid_dir.rmdir()
except OSError:
pass

app.console.print("\n[green]All services stopped.[/]")


def _check_for_running_services() -> list[str]:
"""Check for any running services and return their names. Also cleans up stale PID files."""
pid_dir = Path(PREFECT_HOME.value() / "services")
if not pid_dir.exists():
return []

running_services: list[str] = []
for pid_file in pid_dir.glob("*.pid"):
try:
pid = int(pid_file.read_text())
os.kill(pid, 0)
# Store lowercase for matching, but title case for display
running_services.append(pid_file.stem)
except (ProcessLookupError, ValueError, OSError):
# Process doesn't exist or invalid PID, clean up stale file
pid_file.unlink(missing_ok=True)

return running_services


@services_app.command(aliases=["start"])
async def start_services(
background: bool = typer.Option(
False, "--background", "-b", help="Run the services in the background"
),
):
"""Start all enabled Prefect services"""
if running_services := _check_for_running_services():
app.console.print(
"\n[yellow]Services are already running in the background:[/]"
)
for service in running_services:
app.console.print(f"[dim]• {service}[/]")
app.console.print(
"\n[blue]Use[/] [yellow]`prefect server services stop`[/] [blue]to stop them first.[/]"
)
return

discovered_services, service_settings = _discover_services()
service_map = _get_service_map(discovered_services, service_settings)

service_instances: list[prefect.server.services.loop_service.LoopService] = []
for _, (service_class, setting) in service_map.items():
if setting.value():
service_instances.append(service_class())

if not service_instances:
exit_with_error("No services are enabled!")

if background:
pid_dir = Path(PREFECT_HOME.value() / "services")
pid_dir.mkdir(parents=True, exist_ok=True)

processes: list[subprocess.Popen[Any]] = []
for service in service_instances:
module_name = service.__class__.__module__.split(".")[-1]
command = [
sys.executable,
"-m",
f"prefect.server.services.{module_name}",
]

process = subprocess.Popen(
command,
env={**os.environ},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
processes.append(process)

pid_file = pid_dir / f"{service.name.lower()}.pid"
pid_file.write_text(str(process.pid))

try:
if process.poll() is not None:
app.console.print(f"[red]✗ {service.name}: Failed to start[/]")
stderr = process.stderr.read().decode() if process.stderr else ""
if stderr:
app.console.print(f"[red]{stderr}[/]")
continue
except Exception as e:
app.console.print(f"[red]✗ {service.name}: {str(e)}[/]")
continue

app.console.print(f"[dim]✓ {service.name}[/]")

app.console.print(
"\n[green]Services are running in the background.[/]"
"\n[blue]Use[/] [yellow]`prefect server services list`[/] [blue]to check their status.[/]"
"\n[blue]Use[/] [yellow]`prefect server services stop`[/] [blue]to stop them.[/]"
)
else:
app.console.print("\n[blue]Starting services... Press CTRL+C to stop[/]\n")

service_tasks: list[
tuple[asyncio.Task[None], prefect.server.services.loop_service.LoopService]
] = []

for service in service_instances:
task = asyncio.create_task(service.start())
service_tasks.append((task, service))
app.console.print(f"[dim]✓ {service.name}[/]")

app.console.print() # Add blank line after startup
shutdown_event = asyncio.Event()

def handle_signal(signum: int, frame: Any):
app.console.print("\n[yellow]Shutting down...[/]")
for task, _ in service_tasks:
task.cancel()
asyncio.get_event_loop().call_soon_threadsafe(shutdown_event.set)

signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)

try:
await asyncio.gather(*(task for task, _ in service_tasks))
except asyncio.CancelledError:
await shutdown_event.wait()
results = await asyncio.gather(
*(task for task, _ in service_tasks), return_exceptions=True
)
for (_, service), result in zip(service_tasks, results):
if isinstance(result, asyncio.CancelledError):
app.console.print(f"[dim]Stopped {service.name}[/]")
elif isinstance(result, Exception):
app.console.print(f"[red]Failed {service.name}: {result}[/]")
else:
app.console.print(f"[dim]Stopped {service.name}[/]")
finally:
app.console.print("\n[green]All services stopped.[/]")
2 changes: 2 additions & 0 deletions src/prefect/server/events/services/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ async def stop(self) -> None:


class ProactiveTriggers(LoopService):
"""A loop service that runs the proactive triggers consumer"""

def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any):
super().__init__(
loop_seconds=(
Expand Down
Loading

0 comments on commit 482f25f

Please sign in to comment.