Skip to content

Commit

Permalink
[Hardware][Intel GPU] add XPU bf16 support (#12392)
Browse files Browse the repository at this point in the history
Signed-off-by: Kunshang Ji <[email protected]>
  • Loading branch information
jikunshang authored Feb 2, 2025
1 parent f8ece6e commit f256ebe
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/getting_started/installation/gpu/xpu.inc.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ VLLM_TARGET_DEVICE=xpu python setup.py install

:::{note}
- FP16 is the default data type in the current XPU backend. The BF16 data
type will be supported in the future.
type is supported on Intel Data Center GPU; it is not yet supported on Intel Arc GPU.
:::

## Set up using Docker
Expand Down
23 changes: 20 additions & 3 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# check and update model config
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
logger.warning(
"bfloat16 is not fully supported on XPU, casting to float16.")
model_config.dtype = torch.float16
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
logger.warning(
"bfloat16 is only supported on Intel Data Center GPU, "
"Intel Arc GPU is not supported yet. Your device is %s,"
"which is not supported. will fallback to float16",
cls.get_device_name())
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
Expand Down Expand Up @@ -116,3 +121,15 @@ def get_current_memory_usage(cls,
) -> float:
torch.xpu.reset_peak_memory_stats(device)
return torch.xpu.max_memory_allocated(device)

@classmethod
def device_support_bf16(cls) -> bool:
    """Return True if the current XPU device supports bfloat16.

    Only Intel Data Center GPUs support bf16; Intel Arc GPUs do not.
    Unrecognized devices conservatively report no bf16 support (callers
    fall back to float16).
    """
    device_name = cls.get_device_name().lower()
    # Substring match on the marketing name reported by the driver.
    if "arc" in device_name:
        return False
    if "data center gpu" in device_name:
        return True
    # Unknown device family: warn and report no bf16 support.
    logger.warning("Unknown device name %s, always use float16",
                   device_name)
    return False

0 comments on commit f256ebe

Please sign in to comment.