Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decorator bug fixes #29

Merged
merged 1 commit into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 52 additions & 31 deletions ray_provider/decorators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,19 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):

def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)
self.config = config

if not isinstance(self.num_cpus, (int, float)):
Expand All @@ -55,6 +62,12 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
num_cpus=self.num_cpus,
num_gpus=self.num_gpus,
memory=self.memory,
resources=self.ray_resources,
fetch_logs=self.fetch_logs,
wait_for_completion=self.wait_for_completion,
job_timeout_seconds=self.job_timeout_seconds,
poll_interval=self.poll_interval,
xcom_task_key=self.xcom_task_key,
**kwargs,
)

Expand All @@ -66,19 +79,25 @@ def execute(self, context: Context) -> Any:
:return: The result of the Ray job execution.
:raises AirflowException: If job submission fails.
"""
tmp_dir = mkdtemp(prefix="ray_")
tmp_dir = None
try:
py_source = self.get_python_source().splitlines() # type: ignore
function_body = textwrap.dedent("\n".join(py_source[1:]))
if self.is_decorated_function:
self.log.info(
f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
)
tmp_dir = mkdtemp(prefix="ray_")
py_source = self.get_python_source().splitlines() # type: ignore
function_body = textwrap.dedent("\n".join(py_source))

script_filename = os.path.join(tmp_dir, "script.py")
with open(script_filename, "w") as file:
all_args_str = self._build_args_str()
script_body = f"{function_body}\n{self._extract_function_name()}({all_args_str})"
file.write(script_body)

self.entrypoint = "python script.py"
self.runtime_env["working_dir"] = tmp_dir

script_filename = os.path.join(tmp_dir, "script.py")
with open(script_filename, "w") as file:
all_args_str = self._build_args_str()
script_body = f"{function_body}\n{self._extract_function_name()}({all_args_str})"
file.write(script_body)

self.entrypoint = "python script.py"
self.runtime_env["working_dir"] = tmp_dir
self.log.info("Running ray job...")

result = super().execute(context)
Expand All @@ -87,7 +106,7 @@ def execute(self, context: Context) -> Any:
self.log.error(f"Failed during execution with error: {e}")
raise AirflowException("Job submission failed") from e
finally:
if os.path.exists(tmp_dir):
if tmp_dir and os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)

def _build_args_str(self) -> str:
Expand All @@ -109,22 +128,24 @@ def _extract_function_name(self) -> str:
return self.python_callable.__name__


def ray_task(
python_callable: Callable[..., Any] | None = None,
multiple_outputs: bool | None = None,
**kwargs: Any,
) -> TaskDecorator:
"""
Decorator to define a task that submits a Ray job.
class task:
@staticmethod
def ray(
python_callable: Callable[..., Any] | None = None,
multiple_outputs: bool | None = None,
**kwargs: Any,
) -> TaskDecorator:
"""
Decorator to define a task that submits a Ray job.

:param python_callable: The callable function to decorate.
:param multiple_outputs: If True, will return multiple outputs.
:param kwargs: Additional keyword arguments.
:return: The decorated task.
"""
return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_RayDecoratedOperator,
**kwargs,
)
:param python_callable: The callable function to decorate.
:param multiple_outputs: If True, will return multiple outputs.
:param kwargs: Additional keyword arguments.
:return: The decorated task.
"""
return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_RayDecoratedOperator,
**kwargs,
)
146 changes: 123 additions & 23 deletions tests/decorators/test_ray_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest
from airflow.decorators.base import _TaskDecorator
from airflow.exceptions import AirflowException
from airflow.utils.context import Context

from ray_provider.decorators.ray import _RayDecoratedOperator, ray_task
from ray_provider.decorators.ray import _RayDecoratedOperator, task
from ray_provider.operators.ray import SubmitRayJob

DEFAULT_DATE = "2023-01-01"
Expand All @@ -15,30 +16,79 @@ class TestRayDecoratedOperator:

def test_initialization(self):
config = {
"host": "http://localhost:8265",
"conn_id": "ray_default",
"entrypoint": "python my_script.py",
"runtime_env": {"pip": ["ray"]},
"num_cpus": 2,
"num_gpus": 1,
"memory": "1G",
"memory": 1024,
"resources": {"custom_resource": 1},
"fetch_logs": True,
"wait_for_completion": True,
"job_timeout_seconds": 300,
"poll_interval": 30,
"xcom_task_key": "ray_result",
}

def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

assert operator.conn_id == "ray_default"
assert operator.entrypoint == "python my_script.py"
assert operator.runtime_env == {"pip": ["ray"]}
assert operator.num_cpus == 2
assert operator.num_gpus == 1
assert operator.memory == "1G"
assert operator.memory == 1024
assert operator.ray_resources == {"custom_resource": 1}
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 300
assert operator.poll_interval == 30
assert operator.xcom_task_key == "ray_result"

def test_initialization_defaults(self):
config = {}

def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

assert operator.conn_id == ""
assert operator.entrypoint == "python script.py"
assert operator.runtime_env == {}
assert operator.num_cpus == 1
assert operator.num_gpus == 0
assert operator.memory is None
assert operator.resources is None
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 600
assert operator.poll_interval == 60
assert operator.xcom_task_key is None

def test_invalid_config_raises_exception(self):
config = {
"num_cpus": "invalid_number",
}

def dummy_callable():
pass

with pytest.raises(TypeError):
_RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

config["num_cpus"] = 1
config["num_gpus"] = "invalid_number"
with pytest.raises(TypeError):
_RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

@patch.object(_RayDecoratedOperator, "get_python_source")
@patch.object(SubmitRayJob, "execute")
def test_execute(self, mock_super_execute, mock_get_python_source):
def test_execute_decorated_function(self, mock_super_execute, mock_get_python_source):
config = {
"entrypoint": "python my_script.py",
"runtime_env": {"pip": ["ray"]},
}

Expand All @@ -47,52 +97,102 @@ def dummy_callable():

context = MagicMock(spec=Context)
operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

mock_get_python_source.return_value = "def my_function():\n pass\n"
mock_get_python_source.return_value = "def dummy_callable():\n pass\n"
mock_super_execute.return_value = "success"

result = operator.execute(context)

assert result == "success"
assert operator.entrypoint == "python script.py"
assert "working_dir" in operator.runtime_env

def test_missing_host_config(self):
@patch.object(SubmitRayJob, "execute")
def test_execute_with_entrypoint(self, mock_super_execute):
config = {
"entrypoint": "python my_script.py",
}

def dummy_callable():
pass

context = MagicMock(spec=Context)
operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
mock_super_execute.return_value = "success"

result = operator.execute(context)

assert result == "success"
assert operator.entrypoint == "python my_script.py"

def test_invalid_config_raises_exception(self):
config = {
"host": "http://localhost:8265",
"entrypoint": "python my_script.py",
"runtime_env": {"pip": ["ray"]},
"num_cpus": "invalid_number",
}
@patch.object(SubmitRayJob, "execute")
def test_execute_failure(self, mock_super_execute):
config = {}

def dummy_callable():
pass

with pytest.raises(TypeError):
_RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
context = MagicMock(spec=Context)
operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
mock_super_execute.side_effect = Exception("Ray job failed")

with pytest.raises(AirflowException):
operator.execute(context)

class TestRayTaskDecorator:
def test_build_args_str(self):
config = {}

def dummy_callable(arg1, arg2, kwarg1="default"):
pass

operator = _RayDecoratedOperator(
task_id="test_task",
config=config,
python_callable=dummy_callable,
op_args=["value1", "value2"],
op_kwargs={"kwarg1": "custom"},
)

args_str = operator._build_args_str()
assert args_str == "'value1', 'value2', kwarg1='custom'"

def test_extract_function_name(self):
config = {}

def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

function_name = operator._extract_function_name()
assert function_name == "dummy_callable"


class TestRayTaskDecorator:
def test_ray_task_decorator(self):
@task.ray()
def dummy_function():
return "dummy"

decorator = ray_task(python_callable=dummy_function)
assert isinstance(decorator, _TaskDecorator)
assert isinstance(dummy_function, _TaskDecorator)

def test_ray_task_decorator_with_multiple_outputs(self):
@task.ray(multiple_outputs=True)
def dummy_function():
return {"key": "value"}

decorator = ray_task(python_callable=dummy_function, multiple_outputs=True)
assert isinstance(decorator, _TaskDecorator)
assert isinstance(dummy_function, _TaskDecorator)

def test_ray_task_decorator_with_config(self):
config = {
"num_cpus": 2,
"num_gpus": 1,
"memory": 1024,
}

@task.ray(**config)
def dummy_function():
return "dummy"

assert isinstance(dummy_function, _TaskDecorator)
# We can't directly access the config here, but we can check if the decorator was applied
assert dummy_function.operator_class == _RayDecoratedOperator
Loading