Skip to content

Commit

Permalink
sarkar/Add htrandom generator for hpu (#246)
Browse files Browse the repository at this point in the history
To repro:

start server:
`VLLM_SKIP_WARMUP=true python -m vllm.entrypoints.openai.api_server`

send a request (this works fine):
```
 curl -v http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{"model": "facebook/opt-125m","prompt": "The future of AI is ","max_tokens": 100,"temperature": 0}'
```

if request has a seed it fails:
```
curl -v http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{"model": "facebook/opt-125m","prompt": "The future of AI is ","max_tokens": 100,"temperature": 0, "seed" : 37}'
```

Failure happens here:

[vllm-fork/vllm/model_executor/sampling_metadata.py at habana_main ·
HabanaAI/vllm-fork](https://github.com/HabanaAI/vllm-fork/blob/habana_main/vllm/model_executor/sampling_metadata.py#L220)

```
if sampling_params.seed is not None:
                seq_group_metadata.state.generator = torch.Generator(
                    device=device).manual_seed(sampling_params.seed)
```
 

`RuntimeError: Device type HPU is not supported for torch.Generator()
api.`

This PR fixes above issue by using htrandom [Intel Gaudi PyTorch Python
API (habana_frameworks.torch) — Gaudi Documentation 1.17.1
documentation](https://docs.habana.ai/en/latest/PyTorch/Reference/Python_Packages.html?highlight=htrandom#random-number-generator-apis)
  • Loading branch information
ssarkar2 authored Oct 28, 2024
1 parent 4fd5c4c commit 2a38e6f
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
Expand Down Expand Up @@ -266,8 +267,14 @@ def _prepare_seq_groups(

if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
generator = torch.Generator(device=device).manual_seed(
sampling_params.seed)
if current_platform.is_hpu():
import habana_frameworks.torch.hpu.random as htrandom
generator = \
htrandom.default_generators[
0].manual_seed(sampling_params.seed)
else:
generator = torch.Generator(device=device).manual_seed(
sampling_params.seed)
if generators is not None:
generators[seq_group_metadata.request_id] = generator

Expand Down

0 comments on commit 2a38e6f

Please sign in to comment.