Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add load stages #760

Merged
merged 34 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dac9f84
Create N replicas per `Step`
gabrielmbmb Jun 20, 2024
372e679
Update `_BatchManager` to handle batch sorting uncertainty
gabrielmbmb Jun 21, 2024
a7173f3
Add multiple replicas test
gabrielmbmb Jun 21, 2024
6761c4e
Merge branch 'develop' into multiple-workers
gabrielmbmb Jun 21, 2024
ff824b5
Fix unit tests
gabrielmbmb Jun 24, 2024
b9c0d63
Fix `next_expected_seq_no` needed to be updated if
gabrielmbmb Jun 24, 2024
e1daa84
Update `set_next_expected_batch_seq_no` only if no `data`
gabrielmbmb Jun 25, 2024
b093a25
Fix `next_expected_seq_no` with `routing_batch_function`
gabrielmbmb Jun 25, 2024
ac931d7
Remove prints
gabrielmbmb Jun 26, 2024
ac7ef09
Add `StepResource` import
gabrielmbmb Jun 26, 2024
5c685e9
Add missing return type hint
gabrielmbmb Jun 26, 2024
849a806
Add `StepResources` docs
gabrielmbmb Jun 26, 2024
31392ee
Add `get_steps_load_stages` method
gabrielmbmb Jun 26, 2024
886e9a5
Update to load steps in stages
gabrielmbmb Jun 26, 2024
31e52df
Add `_teardown` method
gabrielmbmb Jun 26, 2024
744ea6e
Add load stages
gabrielmbmb Jun 27, 2024
1c89923
Merge branch 'develop' into steps-load-stages
gabrielmbmb Jun 27, 2024
972d1d7
Add printing info about stages
gabrielmbmb Jun 27, 2024
727b020
Refactor load stages to avoid race conditions
gabrielmbmb Jun 27, 2024
5792284
Add load stages integration test
gabrielmbmb Jun 27, 2024
4713f5b
Fix unit tests
gabrielmbmb Jun 27, 2024
8a00327
Add unit tests for new methods
gabrielmbmb Jun 27, 2024
a3300e1
Move send last batch message
gabrielmbmb Jun 28, 2024
ae04b52
Refactor to make it work with routing batch function
gabrielmbmb Jun 28, 2024
7bcfcec
Add integration test for load stages & routing batch function
gabrielmbmb Jun 28, 2024
8073db6
Update docs to tell about resources as runtime parameters
gabrielmbmb Jun 28, 2024
d83fe3b
Merge branch 'develop' into steps-load-stages
gabrielmbmb Jun 28, 2024
5207736
Add missing doc pages
gabrielmbmb Jun 28, 2024
960d5d3
Merge branch 'develop' into steps-load-stages
gabrielmbmb Jul 1, 2024
b9091a3
Update to load stages from cache
gabrielmbmb Jul 1, 2024
f489987
Fix bugs requesting initial batches
gabrielmbmb Jul 1, 2024
d804dc2
Add integration tests for recovering states from cache
gabrielmbmb Jul 1, 2024
4942121
Remove atexit
gabrielmbmb Jul 1, 2024
b5605fb
Fix docstring typos
gabrielmbmb Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/mixins/requirements.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: distilabel.mixins.requirements.RequirementsMixin
1 change: 1 addition & 0 deletions docs/api/mixins/runtime_parameters.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: distilabel.mixins.runtime_parameters.RuntimeParametersMixin
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,18 @@ with Pipeline(name="resources") as pipeline:
)
```

In the example above, we're creating a `PrometheusEval` task (remember that `Task`s are `Step`s) that will use `vLLM` to serve `prometheus-eval/prometheus-7b-v2.0` model. This task is resource intensive as it requires an LLM, which in turn requires a GPU to run fast. With that in mind, we have specified the `resources` required for the task using the [`StepResources`][distilabel.steps.base.StepResources] class, and we have defined that we need `1` GPU and `1` CPU per replica of the task. In addition, we have defined that we need `2` replicas i.e. we will run two instances of the task so the computation for the whole dataset runs faster. When running the pipeline, `distilabel` will create the tasks in nodes that have available the specified resources.
In the example above, we're creating a `PrometheusEval` task (remember that `Task`s are `Step`s) that will use `vLLM` to serve `prometheus-eval/prometheus-7b-v2.0` model. This task is resource intensive as it requires an LLM, which in turn requires a GPU to run fast. With that in mind, we have specified the `resources` required for the task using the [`StepResources`][distilabel.steps.base.StepResources] class, and we have defined that we need `1` GPU and `1` CPU per replica of the task. In addition, we have defined that we need `2` replicas i.e. we will run two instances of the task so the computation for the whole dataset runs faster. In addition, `StepResources` uses the [RuntimeParametersMixin][distilabel.mixins.runtime_parameters.RuntimeParametersMixin], so we can also specify the resources for each step when running the pipeline:

```python
...

if __name__ == "__main__":
pipeline.run(
parameters={
prometheus.name: {"resources": {"replicas": 2, "cpus": 1, "gpus": 1}}
}
)
```

And that's it! When running the pipeline, `distilabel` will create the tasks in nodes that have available the specified resources.

5 changes: 4 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ theme:
logo: assets/logo.svg
favicon: assets/logo.svg
icon:
repo: fontawesome/brands/github-alt
repo: fontawesome/brands/github
features:
- navigation.instant
- navigation.sections
Expand Down Expand Up @@ -210,6 +210,9 @@ nav:
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
- Typing: "api/pipeline/typing.md"
- Utils: "api/pipeline/utils.md"
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
- Distiset: "api/distiset.md"
- CLI: "api/cli.md"
- Community:
Expand Down
43 changes: 42 additions & 1 deletion src/distilabel/pipeline/_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Iterable,
List,
Set,
Tuple,
Type,
Union,
)
Expand Down Expand Up @@ -258,6 +259,46 @@ def get_total_replica_count(self) -> int:
"""
return sum([self.get_step_replica_count(step_name) for step_name in self.G])

def get_steps_load_stages(self) -> Tuple[List[List[str]], List[List[str]]]:
"""Gets the stages in which the `Step`s of the `Pipeline` should be loaded. Stages
are determined by `GlobalStep`s as they receive all the data at once, which means
that a `GlobalStep` is not required to be loaded until all their previous steps
have finished their execution, and the successors of the global step are not required
to be loaded until the global has finished.

Returns:
A tuple with the first element containing asorted list by stage containing
lists with the names of the steps of the stage, and the second element a list
sorted by stage containing lists with the names of the last steps of the stage.
"""

def _get_stage_last_steps(stage_steps: List[str]) -> List[str]:
subgraph = self.G.subgraph(stage_steps)
return sorted(
[node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]
)

stages = []
current_stage = []
stages_last_steps = []

for step_name in nx.topological_sort(self.G):
step: "_Step" = self.get_step(step_name)[STEP_ATTR_NAME]
if not step.is_global:
current_stage.append(step_name)
else:
stages.append(current_stage)
stages_last_steps.append(_get_stage_last_steps(current_stage))
stages.append([step_name])
stages_last_steps.append([step_name])
current_stage = []

if current_stage:
stages.append(current_stage)
stages_last_steps.append(_get_stage_last_steps(current_stage))

return stages, stages_last_steps

def validate(self) -> None:
"""Validates that the `Step`s included in the pipeline are correctly connected and
have the correct inputs and outputs.
Expand Down Expand Up @@ -299,7 +340,7 @@ def validate(self) -> None:
# Validate routing batch function (if any)
predecessors = list(self.get_step_predecessors(step.name)) # type: ignore
self._validate_convergence_step(
step,
step, # type: ignore
predecessors,
steps_receiving_routed_batches, # type: ignore
)
Expand Down
Loading
Loading