Skip to content

Commit

Permalink
Update unload method from vLLM to properly free resources (#1077)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb authored Dec 3, 2024
1 parent fa13ae1 commit f8e41cd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- --fuzzy-match-generates-todo

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.8.1
hooks:
- id: ruff
args: [--fix]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ distilabel = "distilabel.cli.app:app"
"distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin"

[project.optional-dependencies]
dev = ["ruff == 0.6.2", "pre-commit >= 3.5.0"]
dev = ["ruff == 0.8.1", "pre-commit >= 3.5.0"]
docs = [
"mkdocs-material >=9.5.17",
"mkdocstrings[python] >= 0.24.0",
Expand Down
21 changes: 21 additions & 0 deletions src/distilabel/models/llms/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import gc
import json
from functools import cached_property
from typing import (
Expand Down Expand Up @@ -214,11 +216,30 @@ def load(self) -> None:

def unload(self) -> None:
    """Unloads the `vLLM` model.

    Tears down the vLLM engine and its distributed state first, then drops the
    local references and delegates to the mixin/parent unload hooks. The order
    matters: `_cleanup_vllm_model` dereferences `self._model`, so it must run
    before `_model` is set to `None`.
    """
    # Destroy vLLM's distributed/parallel state and free GPU memory held by
    # the engine before dropping our references to it.
    self._cleanup_vllm_model()
    self._model = None  # type: ignore
    self._tokenizer = None  # type: ignore
    # Release the CUDA device placement bookkeeping for this instance.
    CudaDevicePlacementMixin.unload(self)
    # Run any base-class unload logic (e.g. generic LLM cleanup).
    super().unload()

def _cleanup_vllm_model(self) -> None:
    """Destroys vLLM's distributed state and frees the GPU memory it holds.

    vLLM keeps model-parallel / distributed process groups and the model
    executor alive even after the engine object is dropped; this explicitly
    destroys them so the CUDA memory is actually returned.
    """
    # Imported lazily so this module does not require torch/vllm at import time.
    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment,
        destroy_model_parallel,
    )

    # Tear down vLLM's parallel state before deleting the executor that owns
    # the model weights on the GPU.
    destroy_model_parallel()
    destroy_distributed_environment()
    del self._model.llm_engine.model_executor
    del self._model
    # destroy_process_group raises AssertionError when no default process
    # group was ever initialized (e.g. single-GPU run) — safe to ignore.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    # Force collection of the now-unreferenced engine objects, then flush
    # CUDA's caching allocator so the freed memory is visible to other code.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Ensure all pending CUDA work (including the frees) has completed.
        torch.cuda.synchronize()

@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
Expand Down

0 comments on commit f8e41cd

Please sign in to comment.