[Hardware][Intel GPU] add XPU bf16 support #12392

Merged · 3 commits · Feb 2, 2025
Changes from all commits
2 changes: 1 addition & 1 deletion docs/source/getting_started/installation/gpu/xpu.inc.md
@@ -36,7 +36,7 @@ VLLM_TARGET_DEVICE=xpu python setup.py install

 :::{note}
 - FP16 is the default data type in the current XPU backend. The BF16 data
-  type will be supported in the future.
+  type is supported on Intel Data Center GPU but is not yet supported on Intel Arc GPU.
 :::

## Set up using Docker
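As background for the platform change below, here is a hedged usage sketch (not part of this PR) of requesting bfloat16 explicitly on an XPU build of vLLM; the model name is only illustrative. On Intel Data Center GPU the requested dtype is kept, while on Arc or unrecognized devices the check in `vllm/platforms/xpu.py` falls back to float16 with a warning.

```python
# Hedged sketch: request bfloat16 explicitly on an XPU build of vLLM.
# The model name is illustrative; any bf16-capable checkpoint would do.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", dtype="bfloat16")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```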
23 changes: 20 additions & 3 deletions vllm/platforms/xpu.py
@@ -66,9 +66,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # check and update model config
         model_config = vllm_config.model_config
         if model_config.dtype == torch.bfloat16:
-            logger.warning(
-                "bfloat16 is not fully supported on XPU, casting to float16.")
-            model_config.dtype = torch.float16
+            bf16_supported = cls.device_support_bf16()
+            if not bf16_supported:
+                logger.warning(
+                    "bfloat16 is only supported on Intel Data Center GPU; "
+                    "it is not yet supported on Intel Arc GPU. Your device "
+                    "is %s, falling back to float16.",
+                    cls.get_device_name())
+                model_config.dtype = torch.float16
         if not model_config.enforce_eager:
             logger.warning(
                 "CUDA graph is not supported on XPU, fallback to the eager "
@@ -116,3 +121,15 @@ def get_current_memory_usage(cls,
                                  ) -> float:
         torch.xpu.reset_peak_memory_stats(device)
         return torch.xpu.max_memory_allocated(device)
+
+    @classmethod
+    def device_support_bf16(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        if device_name.count("arc") > 0:
+            return False
+        elif device_name.count("data center gpu") > 0:
+            return True
+        else:
+            logger.warning("Unknown device name %s, always use float16",
+                           device_name)
+            return False
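To make the new check easier to follow in isolation, here is a standalone sketch of the same name-based logic; the device-name strings are examples, while the real method obtains the name from `cls.get_device_name()`.

```python
# Standalone sketch of the name-based bf16 check above (example device names).
def device_support_bf16(device_name: str) -> bool:
    name = device_name.lower()
    if "arc" in name:                # e.g. "Intel(R) Arc(TM) A770 Graphics"
        return False
    if "data center gpu" in name:    # e.g. "Intel(R) Data Center GPU Max 1100"
        return True
    return False                     # unknown devices conservatively use float16


assert device_support_bf16("Intel(R) Data Center GPU Max 1100")
assert not device_support_bf16("Intel(R) Arc(TM) A770 Graphics")
```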