Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Refactor RayRunner so that we can add tracing #3163

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 204 additions & 135 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from ray.data.block import Block as RayDatasetBlock
from ray.data.dataset import Dataset as RayDataset

from daft.execution.physical_plan import MaterializedPhysicalPlan

Check warning on line 76 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L76

Added line #L76 was not covered by tests
from daft.logical.builder import LogicalPlanBuilder
from daft.plan_scheduler import PhysicalPlanScheduler

Expand Down Expand Up @@ -655,6 +656,141 @@
self._actor_pools[name].teardown()
del self._actor_pools[name]

def _construct_dispatch_batch(
self,
execution_id: str,
tasks: MaterializedPhysicalPlan,
dispatches_allowed: int,
) -> tuple[list[PartitionTask], bool]:
"""Constructs a batch of PartitionTasks that should be dispatched

Args:

execution_id: The ID of the current execution.
tasks: The iterator over the physical plan.
dispatches_allowed (int): The maximum number of tasks that can be dispatched in this batch.
Returns:

tuple[list[PartitionTask], bool]: A tuple containing:
- A list of PartitionTasks to be dispatched.
- A pagination boolean indicating whether or not there are more tasks to be had by calling _construct_dispatch_batch again
"""
tasks_to_dispatch: list[PartitionTask] = []

# Loop until:
# - Reached the limit of the number of tasks we are allowed to dispatch
# - Encounter a `None` as the next step (short-circuit and return has_next=False)
while len(tasks_to_dispatch) < dispatches_allowed and self._is_active(execution_id):
next_step = next(tasks)

# CASE: Blocked on already dispatched tasks
# Early terminate and mark has_next=False
if next_step is None:
return tasks_to_dispatch, False

# CASE: A final result
# Place it in the result queue (potentially block on space to be available)
elif isinstance(next_step, MaterializedResult):
self._place_in_queue(execution_id, next_step)

# CASE: No-op task
# Just run it locally immediately.
elif len(next_step.instructions) == 0:
logger.debug("Running task synchronously in main thread: %s", next_step)
assert (
len(next_step.partial_metadatas) == 1
), "No-op tasks must have one output by definition, since there are no instructions to run"
[single_partial] = next_step.partial_metadatas
if single_partial.num_rows is None:
[single_meta] = ray.get(get_metas.remote(next_step.inputs))
accessor = PartitionMetadataAccessor.from_metadata_list(

Check warning on line 706 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L705-L706

Added lines #L705 - L706 were not covered by tests
[single_meta.merge_with_partial(single_partial)]
)
else:
accessor = PartitionMetadataAccessor.from_metadata_list(
[
PartitionMetadata(
num_rows=single_partial.num_rows,
size_bytes=single_partial.size_bytes,
boundaries=single_partial.boundaries,
)
]
)

next_step.set_result([RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs])
next_step.set_done()

# CASE: Actual task that needs to be dispatched
else:
tasks_to_dispatch.append(next_step)

return tasks_to_dispatch, True

def _dispatch_tasks(
self,
tasks_to_dispatch: list[PartitionTask],
daft_execution_config: PyDaftExecutionConfig,
) -> Iterator[tuple[PartitionTask, list[ray.ObjectRef]]]:
"""Iteratively Dispatches a batch of tasks to the Ray backend"""

for task in tasks_to_dispatch:
if task.actor_pool_id is None:
results = _build_partitions(daft_execution_config, task)
else:
actor_pool = self._actor_pools.get(task.actor_pool_id)
assert actor_pool is not None, "Ray actor pool must live for as long as the tasks."
results = _build_partitions_on_actor_pool(task, actor_pool)
logger.debug("%s -> %s", task, results)

yield task, results

def _await_tasks(
self,
inflight_ref_to_task_id: dict[ray.ObjectRef, str],
) -> list[ray.ObjectRef]:
"""Awaits for tasks to be completed. Returns tasks that are ready.

NOTE: This method blocks until at least 1 task is ready. Then it will return as many ready tasks as it can.
"""
if len(inflight_ref_to_task_id) == 0:
return []

Check warning on line 756 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L756

Added line #L756 was not covered by tests

# Await on (any) task to be ready with an unlimited timeout
ray.wait(
list(inflight_ref_to_task_id.keys()),
num_returns=1,
timeout=None,
fetch_local=False,
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: slight behavior change here compared to previous awaiting logic

I call a ray.wait here but discard the outputs to just wait on one item to be ready with timeout=None.

Then I subsequently call ray.wait again with a 0.01 timeout to actually retrieve a batch of ready tasks.

I think this logic is a little easier to follow, and gets rid of the weird loop over ("next_one", "next_batch") that we had earlier. Also shouldn't have too much of a performance impact.


# Now, grab as many ready tasks as possible with a 0.01s timeout
timeout = 0.01
num_returns = len(inflight_ref_to_task_id)
readies, _ = ray.wait(
list(inflight_ref_to_task_id.keys()),
num_returns=num_returns,
timeout=timeout,
fetch_local=False,
)

return readies

def _is_active(self, execution_id: str):
"""Checks if the execution for the provided `execution_id` is still active"""
return self.active_by_df.get(execution_id, False)

def _place_in_queue(self, execution_id: str, item: ray.ObjectRef):
"""Places a result into the queue for the provided `execution_id

NOTE: This will block and poll busily until space is available on the queue
`"""
while self._is_active(execution_id):
try:
self.results_by_df[execution_id].put(item, timeout=0.1)
break
except Full:
pass

Check warning on line 792 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L791-L792

Added lines #L791 - L792 were not covered by tests

def _run_plan(
self,
plan_scheduler: PhysicalPlanScheduler,
Expand All @@ -663,14 +799,6 @@
) -> None:
# Get executable tasks from plan scheduler.
results_buffer_size = self.results_buffer_size_by_df[result_uuid]
tasks = plan_scheduler.to_partition_tasks(
psets,
self,
# Attempt to subtract 1 from results_buffer_size because the return Queue size is already 1
# If results_buffer_size=1 though, we can't do much and the total buffer size actually has to be >= 2
# because we have two buffers (the Queue and the buffer inside the `materialize` generator)
None if results_buffer_size is None else max(results_buffer_size - 1, 1),
)

daft_execution_config = self.execution_configs_objref_by_df[result_uuid]
inflight_tasks: dict[str, PartitionTask[ray.ObjectRef]] = dict()
Expand All @@ -684,25 +812,35 @@
f"{datetime.replace(datetime.now(), second=0, microsecond=0).isoformat()[:-3]}.json"
)

def is_active():
return self.active_by_df.get(result_uuid, False)

def place_in_queue(item):
while is_active():
try:
self.results_by_df[result_uuid].put(item, timeout=0.1)
break
except Full:
pass
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got rid of these weird locally-defined callbacks by making them methods, so that I can call them elsewhere without having to pass them around.


with profiler(profile_filename):
tasks = plan_scheduler.to_partition_tasks(
psets,
self,
# Attempt to subtract 1 from results_buffer_size because the return Queue size is already 1
# If results_buffer_size=1 though, we can't do much and the total buffer size actually has to be >= 2
# because we have two buffers (the Queue and the buffer inside the `materialize` generator)
None if results_buffer_size is None else max(results_buffer_size - 1, 1),
)
try:
next_step = next(tasks)

while is_active(): # Loop: Dispatch -> await.
while is_active(): # Loop: Dispatch (get tasks -> batch dispatch).
tasks_to_dispatch: list[PartitionTask] = []

###
# Scheduling Loop:
#
# DispatchBatching ─► Dispatch
# ▲ │ ───────► Await
# └────────────────────────┘ │
# ▲ │
# └───────────────────────────────┘
###
while self._is_active(result_uuid):
###
# Dispatch Loop:
#
# DispatchBatching ─► Dispatch
# ▲ │
# └────────────────────────┘
###
while self._is_active(result_uuid):
# Update available cluster resources
# TODO: improve control loop code to be more understandable and dynamically adjust backlog
cores: int = max(
next(num_cpus_provider) - self.reserved_cores, 1
Expand All @@ -711,136 +849,67 @@
dispatches_allowed = max_inflight_tasks - len(inflight_tasks)
dispatches_allowed = min(cores, dispatches_allowed)

# Loop: Get a batch of tasks.
while len(tasks_to_dispatch) < dispatches_allowed and is_active():
if next_step is None:
# Blocked on already dispatched tasks; await some tasks.
break

elif isinstance(next_step, MaterializedResult):
# A final result.
place_in_queue(next_step)
next_step = next(tasks)

# next_step is a task.

# If it is a no-op task, just run it locally immediately.
elif len(next_step.instructions) == 0:
logger.debug("Running task synchronously in main thread: %s", next_step)
assert (
len(next_step.partial_metadatas) == 1
), "No-op tasks must have one output by definition, since there are no instructions to run"
[single_partial] = next_step.partial_metadatas
if single_partial.num_rows is None:
[single_meta] = ray.get(get_metas.remote(next_step.inputs))
accessor = PartitionMetadataAccessor.from_metadata_list(
[single_meta.merge_with_partial(single_partial)]
)
else:
accessor = PartitionMetadataAccessor.from_metadata_list(
[
PartitionMetadata(
num_rows=single_partial.num_rows,
size_bytes=single_partial.size_bytes,
boundaries=single_partial.boundaries,
)
]
)

next_step.set_result(
[RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs]
)
next_step.set_done()
next_step = next(tasks)

else:
# Add the task to the batch.
tasks_to_dispatch.append(next_step)
next_step = next(tasks)

# Dispatch the batch of tasks.
# Dispatch Batching
tasks_to_dispatch, has_next = self._construct_dispatch_batch(
result_uuid,
tasks,
dispatches_allowed,
)

logger.debug(
"%ss: RayRunner dispatching %s tasks",
(datetime.now() - start).total_seconds(),
len(tasks_to_dispatch),
)

if not is_active():
if not self._is_active(result_uuid):
break

for task in tasks_to_dispatch:
if task.actor_pool_id is None:
results = _build_partitions(daft_execution_config, task)
else:
actor_pool = self._actor_pools.get(task.actor_pool_id)
assert actor_pool is not None, "Ray actor pool must live for as long as the tasks."
results = _build_partitions_on_actor_pool(task, actor_pool)
logger.debug("%s -> %s", task, results)
# Dispatch
for task, result_obj_refs in self._dispatch_tasks(
tasks_to_dispatch,
daft_execution_config,
):
inflight_tasks[task.id()] = task
for result in results:
for result in result_obj_refs:
inflight_ref_to_task[result] = task.id()

pbar.mark_task_start(task)

if dispatches_allowed == 0 or next_step is None:
break

# Await a batch of tasks.
# (Awaits the next task, and then the next batch of tasks within 10ms.)

dispatch = datetime.now()
completed_task_ids = []
for wait_for in ("next_one", "next_batch"):
if not is_active():
break

if wait_for == "next_one":
num_returns = 1
timeout = None
elif wait_for == "next_batch":
num_returns = len(inflight_ref_to_task)
timeout = 0.01 # 10ms

if num_returns == 0:
# Break the dispatch batching/dispatch loop if no more dispatches allowed, or physical plan
# needs work for forward progress
if dispatches_allowed == 0 or not has_next:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: behavior change here. We now use a has_next variable that is returned from self._construct_dispatch_batch to figure out whether or now we should break the DispatchBatching -> Dispatch loop.

Previously the scheduling loop had to keep track of next_step which was super weird and unwieldy, and caused us to call next(tasks) in a bunch of random places scattered throughout the loop.

break

readies, _ = ray.wait(
list(inflight_ref_to_task.keys()),
num_returns=num_returns,
timeout=timeout,
fetch_local=False,
)

for ready in readies:
if ready in inflight_ref_to_task:
task_id = inflight_ref_to_task[ready]
completed_task_ids.append(task_id)
# Mark the entire task associated with the result as done.
task = inflight_tasks[task_id]
task.set_done()

if isinstance(task, SingleOutputPartitionTask):
del inflight_ref_to_task[ready]
elif isinstance(task, MultiOutputPartitionTask):
for partition in task.partitions():
del inflight_ref_to_task[partition]

pbar.mark_task_done(task)
del inflight_tasks[task_id]

logger.debug(
"%ss to await results from %s", (datetime.now() - dispatch).total_seconds(), completed_task_ids
)

if next_step is None:
next_step = next(tasks)
###
# Await:
# Wait for some work to be completed from the current wave's dispatch
# Then we perform the necessary record-keeping on tasks that were retrieved as ready.
###
readies = self._await_tasks(inflight_ref_to_task)
for ready in readies:
if ready in inflight_ref_to_task:
task_id = inflight_ref_to_task[ready]

# Mark the entire task associated with the result as done.
task = inflight_tasks[task_id]
task.set_done()

if isinstance(task, SingleOutputPartitionTask):
del inflight_ref_to_task[ready]
elif isinstance(task, MultiOutputPartitionTask):
for partition in task.partitions():
del inflight_ref_to_task[partition]

pbar.mark_task_done(task)
del inflight_tasks[task_id]

except StopIteration as e:
place_in_queue(e)
self._place_in_queue(result_uuid, e)

# Ensure that all Exceptions are correctly propagated to the consumer before reraising to kill thread
except Exception as e:
place_in_queue(e)
self._place_in_queue(result_uuid, e)
pbar.close()
raise

Expand Down
Loading