Merge branch 'habana_main' into enable-prefix-caching
huijjj authored and irteam sudo-account committed Oct 29, 2024
2 parents b6b369a + 2a38e6f commit b84fc09
Showing 16 changed files with 520 additions and 84 deletions.
4 changes: 4 additions & 0 deletions README_GAUDI.md
@@ -282,6 +282,10 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
- `PT_HPU_LAZY_MODE`: if `0`, the PyTorch Eager backend for Gaudi is used; if `1`, the PyTorch Lazy backend is used (default: `1`)
- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: must be set to `true` for tensor parallel inference with HPU Graphs (see the sketch below)
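
A minimal sketch of how the two variables above might be set for a tensor-parallel HPU run (the model name and TP size are placeholders; exporting the variables in the shell before launch works just as well):

```python
import os

# Assumption: the HPU bridge reads these at initialization, so set them before
# importing vLLM / Habana frameworks.
os.environ["PT_HPU_LAZY_MODE"] = "1"                   # lazy backend (the default)
os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"  # required for TP with HPU Graphs

from vllm import LLM

llm = LLM(model="meta-llama/Llama-2-7b-hf",  # placeholder model
          tensor_parallel_size=2)            # placeholder TP size
```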

# Quantization and FP8 model calibration process

The FP8 model calibration procedure is described as part of the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
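
As a hedged illustration of how a calibrated model might then be served: `kv_cache_dtype="fp8_inc"` and `weights_load_device` are options touched by this commit, while `quantization="inc"` and the `QUANT_CONFIG` variable are assumptions based on the HabanaAI calibration flow, not confirmed by this diff:

```python
import os
from vllm import LLM

# Assumption: QUANT_CONFIG points at the quantization JSON produced by the
# vllm-hpu-extension calibration procedure linked above (hypothetical path).
os.environ["QUANT_CONFIG"] = "/path/to/maxabs_quant.json"

llm = LLM(
    model="meta-llama/Llama-2-70b-hf",  # placeholder model
    quantization="inc",                 # assumption: Intel Neural Compressor backend
    kv_cache_dtype="fp8_inc",           # kv-cache dtype choice added in this commit
    weights_load_device="cpu",          # engine argument added in this commit
)
```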

# Troubleshooting: Tweaking HPU Graphs

If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs as described below:
3 changes: 2 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -47,6 +47,7 @@ def main(args: argparse.Namespace):
distributed_executor_backend=args.distributed_executor_backend,
otlp_traces_endpoint=args.otlp_traces_endpoint,
enable_prefix_caching=args.enable_prefix_caching,
max_num_seqs=args.batch_size,
)

sampling_params = SamplingParams(
@@ -179,7 +180,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
38 changes: 36 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -90,6 +90,10 @@ def run_vllm(
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
weights_load_device: Optional[str] = None,
use_padding_aware_scheduling: bool = False,
max_num_seqs: int = 256,
max_num_prefill_seqs: Optional[int] = None,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -115,6 +119,10 @@
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
weights_load_device=weights_load_device,
use_padding_aware_scheduling=use_padding_aware_scheduling,
max_num_seqs=max_num_seqs,
max_num_prefill_seqs=max_num_prefill_seqs,
)

# Add the requests to the engine.
@@ -181,6 +189,10 @@ async def run_vllm_async(
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
disable_frontend_multiprocessing: bool = False,
weights_load_device: Optional[str] = None,
use_padding_aware_scheduling: bool = False,
max_num_seqs: int = 256,
max_num_prefill_seqs: Optional[int] = None,
) -> float:
from vllm import SamplingParams
engine_args = AsyncEngineArgs(
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
weights_load_device=weights_load_device,
use_padding_aware_scheduling=use_padding_aware_scheduling,
max_num_prefill_seqs=max_num_prefill_seqs,
)

async with build_async_engine_client_from_engine_args(
@@ -342,7 +357,9 @@ def main(args: argparse.Namespace):
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc
args.disable_async_output_proc, args.weights_load_device,
args.use_padding_aware_scheduling, args.max_num_seqs,
args.max_num_prefill_seqs
]

if args.async_engine:
@@ -446,7 +463,7 @@ def main(args: argparse.Namespace):
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
@@ -540,6 +557,23 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument("--weights-load-device",
type=str,
default=None,
choices=DEVICE_OPTIONS,
help='Device on which weights are loaded.')
parser.add_argument("--use-padding-aware-scheduling",
action='store_true',
default=False,
help="Enable padding-aware scheduling.")
parser.add_argument("--max-num-seqs",
type=int,
default=256,
help="Maximum number of requests for single decode.")
parser.add_argument("--max-num-prefill-seqs",
type=int,
default=None,
help="Maximum number of requests for single prefill.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
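
For reference, a minimal sketch of how the new benchmark flags map onto the engine, mirroring the `run_vllm` call in the diff above (the model name and values are placeholders):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",          # placeholder model
    weights_load_device="cpu",          # --weights-load-device
    use_padding_aware_scheduling=True,  # --use-padding-aware-scheduling
    max_num_seqs=128,                   # --max-num-seqs (decode batch cap)
    max_num_prefill_seqs=16,            # --max-num-prefill-seqs (prefill batch cap)
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
```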
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@79f3aa7
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f
10 changes: 10 additions & 0 deletions tests/kernels/test_pos_encoding.py
@@ -5,6 +5,7 @@
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol
@@ -20,6 +21,9 @@
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
if current_platform.is_hpu():
import habana_frameworks.torch as htorch
CUDA_DEVICES = ['hpu']


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -65,6 +69,8 @@ def test_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)
if current_platform.is_hpu():
htorch.core.mark_step()
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
torch.testing.assert_close(out_query,
@@ -120,6 +126,8 @@ def test_batched_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)
if current_platform.is_hpu():
htorch.core.mark_step()
out_query, out_key = rope.forward(positions,
query,
key,
@@ -193,6 +201,8 @@ def test_batched_rotary_embedding_multi_lora(
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key,
query_offsets)
if current_platform.is_hpu():
htorch.core.mark_step()
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
# Compare the results.
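
The `mark_step()` calls added above follow the usual HPU lazy-mode idiom: queued lazy ops are only materialized at a graph break, so the reference computation is flushed before the in-place custom kernel mutates the shared `query`/`key` tensors. A standalone sketch of the pattern (shapes and values are arbitrary, and it falls back to CPU when no HPU stack is installed):

```python
import torch

try:
    import habana_frameworks.torch as htorch
    device = "hpu"
except ImportError:
    htorch, device = None, "cpu"

x = torch.randn(4, 8, device=device)
x0 = x.clone()                 # keep a copy of the original values
ref = x * 2.0                  # "reference" result computed from the original x
if htorch is not None:
    htorch.core.mark_step()    # flush queued lazy ops before the in-place mutation
x.add_(1.0)                    # in-place update, analogous to the in-place rope kernel
print(torch.equal(ref.cpu(), (x0 * 2.0).cpu()))  # True: ref used the pre-mutation values
```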