Skip to content

Commit

Permalink
Signals and ATP v2 (#98)
Browse files Browse the repository at this point in the history
* Progress towards signals in SDK

This should be backwards-compatible

* Added signals to ATP server

Also moved ATP server into a class

* Add ATP v2 support to client

Also fixed a few enum issues

* Fix linting errors

* Fix linting error

* Ignore linting issue

* Fix linting error

* Fix errors introduced while trying to fix linting errors

* Fix missing params in test

* Update dependencies

* Update pyyaml

* Fixed typo

* Fix predefined schema

* Bypass class limitations by passing method names instead

* Delay signal retrieval

* Added missing field, and added extra None check

* Fix parameter passed into test case

* Fix missing input to server read loop

* Fix wrong signal ID

Also refactored to improve the code

* Fix extra parameter passed into function

* Fix missing parameter passed into function

* Fix missing deserialization step

* Fix missed function rename, and fix linting err

* Reduce redundancy, and fix wrong var passed

* Added client done message, and signal atp tests

* Fix linting errors

* Remove join and add flush

Also did a mild refactor of function

* Change when read thread is launched

* Ignore sigint and manage stdout correctly

* Fix linting errors, and added comments

* Fix ordering problem with steps and signals

* Added coverage config file

* Remove unused import

* Removed print statements, and added fail when debug logs aren't empty

* Remove coverage config

This appears to be causing a failure without any reason why
  • Loading branch information
jaredoconnell authored Sep 21, 2023
1 parent a730428 commit 44b451e
Show file tree
Hide file tree
Showing 9 changed files with 978 additions and 385 deletions.
588 changes: 329 additions & 259 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ homepage = "https://github.com/arcalot/arcaflow-plugin-sdk-python"
[tool.poetry.dependencies]
python = "^3.9"
cbor2 = "^5.4.3"
PyYAML = "^5.4"
PyYAML = "^6.0.1"

[tool.poetry.group.dev.dependencies]
coverage = "^6.5.0"
Expand Down
273 changes: 196 additions & 77 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,29 @@
import dataclasses
import io
import os
import signal
import sys
import typing
import threading
import signal

import cbor2

from enum import Enum

from arcaflow_plugin_sdk import schema


class MessageType(Enum):
"""
An integer ID that indicates the type of runtime message that is stored
in the data field. The corresponding class can then be used to deserialize
the inner data. Look at the go SDK for the reference data structure.
"""
WORK_DONE = 1
SIGNAL = 2
CLIENT_DONE = 3


@dataclasses.dataclass
class HelloMessage:
"""
Expand All @@ -50,73 +64,143 @@ class HelloMessage:
_HELLO_MESSAGE_SCHEMA = schema.build_object_schema(HelloMessage)


def _handle_exit(_signo, _stack_frame):
print("Exiting normally")
sys.exit(0)
def signal_handler(_sig, _frame):
pass # Do nothing


def run_plugin(
s: schema.SchemaType,
stdin: io.FileIO,
stdout: io.FileIO,
stderr: io.FileIO,
) -> int:
"""
This function wraps running a plugin.
"""
if os.isatty(stdout.fileno()):
print("Cannot run plugin in ATP mode on an interactive terminal.")
return 1
class ATPServer:
stdin: io.FileIO
stdout: io.FileIO
stderr: io.FileIO
step_object: typing.Any

signal.signal(signal.SIGTERM, _handle_exit)
try:
decoder = cbor2.decoder.CBORDecoder(stdin)
encoder = cbor2.encoder.CBOREncoder(stdout)
def __init__(
self,
stdin: io.FileIO,
stdout: io.FileIO,
stderr: io.FileIO,
) -> None:
self.stdin = stdin
self.stdout = stdout
self.stderr = stderr

# Decode empty "start output" message.
decoder.decode()
def run_plugin(
self,
plugin_schema: schema.SchemaType,
) -> int:
"""
This function wraps running a plugin.
"""
signal.signal(signal.SIGINT, signal_handler) # Ignore sigint. Only care about arcaflow signals.
if os.isatty(self.stdout.fileno()):
print("Cannot run plugin in ATP mode on an interactive terminal.")
return 1
try:
decoder = cbor2.decoder.CBORDecoder(self.stdin)
encoder = cbor2.encoder.CBOREncoder(self.stdout)

start = HelloMessage(1, s)
serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start)
encoder.encode(serialized_message)
stdout.flush()
# Decode empty "start output" message.
decoder.decode()

message = decoder.decode()
except SystemExit:
return 0
try:
if message is None:
stderr.write("Work start message is None.")
return 1
if message["id"] is None:
stderr.write("Work start message is missing the 'id' field.")
return 1
if message["config"] is None:
stderr.write("Work start message is missing the 'config' field.")
# Serialize then send HelloMessage
start_hello_message = HelloMessage(2, plugin_schema)
serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start_hello_message)
encoder.encode(serialized_message)
self.stdout.flush()

# Can fail here if only getting schema.
work_start_msg = decoder.decode()
except SystemExit:
return 0
try:
if work_start_msg is None:
self.stderr.write("Work start message is None.")
return 1
if work_start_msg["id"] is None:
self.stderr.write("Work start message is missing the 'id' field.")
return 1
if work_start_msg["config"] is None:
self.stderr.write("Work start message is missing the 'config' field.")
return 1

# Run the step
original_stdout = sys.stdout
original_stderr = sys.stderr
out_buffer = io.StringIO()
sys.stdout = out_buffer
sys.stderr = out_buffer
# Run the read loop
read_thread = threading.Thread(target=self.run_server_read_loop, args=(
plugin_schema, # Plugin schema
work_start_msg["id"], # step ID
decoder, # Decoder
))
read_thread.start()
output_id, output_data = plugin_schema.call_step(
work_start_msg["id"],
plugin_schema.unserialize_step_input(work_start_msg["id"], work_start_msg["config"])
)

# Send WorkDoneMessage in a RuntimeMessage
encoder.encode(
{
"id": MessageType.WORK_DONE.value,
"data": {
"output_id": output_id,
"output_data": plugin_schema.serialize_output(
work_start_msg["id"], output_id, output_data
),
"debug_logs": out_buffer.getvalue(),
}
}
)
self.stdout.flush() # Sends it to the ATP client immediately. Needed so it can realize it's done.
read_thread.join() # Wait for the read thread to finish.
# Don't reset stdout/stderr until after the read thread is done.
sys.stdout = original_stdout
sys.stderr = original_stderr
except SystemExit:
return 1
original_stdout = sys.stdout
original_stderr = sys.stderr
out_buffer = io.StringIO()
sys.stdout = out_buffer
sys.stderr = out_buffer
output_id, output_data = s.call_step(
message["id"], s.unserialize_input(message["id"], message["config"])
)
sys.stdout = original_stdout
sys.stderr = original_stderr
encoder.encode(
{
"output_id": output_id,
"output_data": s.serialize_output(
message["id"], output_id, output_data
),
"debug_logs": out_buffer.getvalue(),
}
)
stdout.flush()
except SystemExit:
return 1
return 0
return 0

def run_server_read_loop(
self,
plugin_schema: schema.SchemaType,
step_id: str,
decoder: cbor2.decoder.CBORDecoder,
) -> None:
try:
while True:
# Decode the message
runtime_msg = decoder.decode()
msg_id = runtime_msg["id"]
# Validate
if msg_id is None:
self.stderr.write("Runtime message is missing the 'id' field.")
return
# Then take action
if msg_id == MessageType.SIGNAL.value:
signal_msg = runtime_msg["data"]
received_step_id = signal_msg["step_id"]
received_signal_id = signal_msg["signal_id"]
if received_step_id != step_id: # Ensure they match.
self.stderr.write(f"Received step ID in the signal message '{received_step_id}'"
f"does not match expected step ID '{step_id}'")
return
unserialized_data = plugin_schema.unserialize_signal_handler_input(
received_step_id,
received_signal_id,
signal_msg["data"]
)
# The data is verified and unserialized. Now call the signal.
plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data)
elif msg_id == MessageType.CLIENT_DONE.value:
return
else:
self.stderr.write(f"Unknown kind of runtime message: {msg_id}")

except cbor2.CBORDecodeError as err:
self.stderr.write(f"Error while decoding CBOR: {err}")


class PluginClientStateException(Exception):
Expand Down Expand Up @@ -158,7 +242,7 @@ def start_output(self) -> None:

def read_hello(self) -> HelloMessage:
"""
This function reads the intial "Hello" message from the plugin.
This function reads the initial "Hello" message from the plugin.
"""
message = self.decoder.decode()
return _HELLO_MESSAGE_SCHEMA.unserialize(message)
Expand All @@ -173,22 +257,57 @@ def start_work(self, step_id: str, input_data: any):
"config": input_data,
}
)
self.stdin.flush()

def send_signal(self, step_id: str, signal_id: str, input_data: any):
"""
This function sends any signals to the plugin.
"""
self.send_runtime_message(MessageType.SIGNAL, {
"step_id": step_id,
"signal_id": signal_id,
"data": input_data,
}
)

def send_client_done(self):
self.send_runtime_message(MessageType.CLIENT_DONE, {})

def send_runtime_message(self, message_type: MessageType, data: any):
self.encoder.encode(
{
"id": message_type.value,
"data": data,
}
)
self.stdin.flush()

def read_results(self) -> (str, any, str):
"""
This function reads the results of an execution from the plugin.
This function reads the signals and results of an execution from the plugin.
"""
message = self.decoder.decode()
if message["output_id"] is None:
raise PluginClientStateException(
"Missing 'output_id' in CBOR message. Possibly wrong order of calls?"
)
if message["output_data"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
if message["debug_logs"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
return message["output_id"], message["output_data"], message["debug_logs"]
while True:
runtime_msg = self.decoder.decode()
msg_id = runtime_msg["id"]
if msg_id == MessageType.WORK_DONE.value:
signal_msg = runtime_msg["data"]
if signal_msg["output_id"] is None:
raise PluginClientStateException(
"Missing 'output_id' in CBOR message. Possibly wrong order of calls?"
)
if signal_msg["output_data"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
if signal_msg["debug_logs"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
return signal_msg["output_id"], signal_msg["output_data"], signal_msg["debug_logs"]
elif msg_id == MessageType.SIGNAL.value:
# Do nothing. Should change in the future.
continue
else:
raise PluginClientStateException(
f"Received unknown runtime message ID {msg_id}"
)
Loading

0 comments on commit 44b451e

Please sign in to comment.