
xpu: test py_limited_api with SyclExtension
This commit extends the existing CUDA test to cover the XPU SyclExtension
case for the same feature, `py_limited_api`.

NOTE: THE CHANGE CANNOT BE MERGED AS IS.

The change requires an update of the commit pin for torch-xpu-ops.

Requires: intel/torch-xpu-ops#1405
Signed-off-by: Dmitry Rogozhkin <[email protected]>
dvrogozh committed Feb 26, 2025
1 parent adf0f4f commit f2a5265
Showing 6 changed files with 81 additions and 25 deletions.
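
For context: `py_limited_api=True` asks `torch.utils.cpp_extension` to build the module against CPython's limited API, so one built wheel can work across Python versions; this particular test extension goes further and uses no Python C API at all (hence "python_agnostic"). Below is a minimal sketch of declaring such an extension with `SyclExtension`; the package, module, and source names are illustrative, not taken from this commit:

# Minimal sketch, assuming a PyTorch build with XPU/SYCL support.
# "mypkg" and "csrc/ops.sycl" are hypothetical names.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, SyclExtension

setup(
    name="mypkg",
    ext_modules=[
        SyclExtension(
            "mypkg._sycl_ops",
            sources=["csrc/ops.sycl"],  # .sycl sources go through the SYCL compiler
            py_limited_api=True,        # restrict to CPython's stable/limited ABI
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)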
@@ -3,9 +3,19 @@
 import torch


-so_files = list(Path(__file__).parent.glob("_C*.so"))
-assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
-torch.ops.load_library(so_files[0])
+if torch.cuda.is_available():
+    cuda_so_files = list(Path(__file__).parent.glob("cuda*.so"))
+    assert (
+        len(cuda_so_files) == 1
+    ), f"Expected one cuda*.so file, found {len(cuda_so_files)}"
+    torch.ops.load_library(cuda_so_files[0])
+
+if torch.xpu.is_available():
+    sycl_so_files = list(Path(__file__).parent.glob("sycl*.so"))
+    assert (
+        len(sycl_so_files) == 1
+    ), f"Expected one sycl*.so file, found {len(sycl_so_files)}"
+    torch.ops.load_library(sycl_so_files[0])

 from . import ops

@@ -15,12 +25,19 @@
 # The following is used to assert the ultra_norm op is properly loaded and
 # calculates correct results upon import of this extension.

-inputs = [
-    torch.tensor([1.0, 2.0, 3.0], device="cuda"),
-    torch.tensor([-4.0, -5.0, -6.0], device="cuda"),
-]
-
-assert torch.equal(
-    ops.ultra_norm(inputs),
-    torch.norm(torch.tensor([1.0, 2.0, 3.0, -4.0, -5.0, -6.0], device="cuda")),
-)
+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+for device in devices:
+    inputs = [
+        torch.tensor([1.0, 2.0, 3.0], device=device),
+        torch.tensor([-4.0, -5.0, -6.0], device=device),
+    ]
+
+    assert torch.equal(
+        ops.ultra_norm(inputs),
+        torch.norm(torch.tensor([1.0, 2.0, 3.0, -4.0, -5.0, -6.0], device=device)),
+    )
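
Once the matching .so is loaded, the operator is also reachable directly through the `torch.ops` namespace. A small usage sketch, assuming an XPU-enabled PyTorch and the built wheel installed:

import torch
import python_agnostic  # importing runs the load_library calls above

if torch.xpu.is_available():
    ts = [torch.tensor([3.0, 4.0], device="xpu")]
    # ultra_norm of a single [3, 4] tensor is its L2 norm, i.e. 5.0
    print(torch.ops.python_agnostic.ultra_norm(ts))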
19 additions (new file)
@@ -0,0 +1,19 @@
+#include <ATen/ops/_foreach_norm_native.h>
+#include <ATen/ops/cat_xpu_dispatch.h>
+#include <ATen/ops/norm_xpu_dispatch.h>
+#include <ATen/ops/unsqueeze.h>
+#include <torch/library.h>
+
+at::Tensor ultra_norm(at::TensorList inputs) {
+  auto res = at::native::foreach_tensor_norm_xpu(inputs);
+  std::vector<at::Tensor> unsqueezed;
+  for (const auto& scalar_tensor : res) {
+    unsqueezed.push_back(at::unsqueeze(scalar_tensor, 0));
+  }
+  auto stacked = at::xpu::cat(unsqueezed);
+  return at::xpu::norm(stacked, 2, at::IntArrayRef{}, false);
+}
+
+TORCH_LIBRARY_IMPL(python_agnostic, XPU, m) {
+  m.impl("python_agnostic::ultra_norm", &ultra_norm);
+}
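
For intuition: the kernel computes the L2 norm of the per-tensor L2 norms, which is mathematically the L2 norm of all elements concatenated. A device-agnostic pure-PyTorch reference (a sketch, not part of the commit):

import torch

def ultra_norm_reference(inputs):
    # Per-tensor L2 norms, mirroring the foreach norm in the kernel above
    per_tensor = [t.norm() for t in inputs]
    # Stack the scalar norms and reduce them with a final L2 norm
    stacked = torch.cat([n.unsqueeze(0) for n in per_tensor])
    return torch.norm(stacked)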
36 changes: 26 additions & 10 deletions test/cpp_extensions/python_agnostic_extension/setup.py
@@ -9,7 +9,8 @@

 from setuptools import setup

-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, SyclExtension


 ROOT_DIR = Path(__file__).parent
@@ -40,17 +41,32 @@ def get_extension():
         "cxx": ["-fdiagnostics-color=always"],
     }

-    sources = list(CSRC_DIR.glob("**/*.cu"))
-
-    return [
-        CUDAExtension(
-            "python_agnostic._C",
-            sources=sorted(str(s) for s in sources),
-            py_limited_api=True,
-            extra_compile_args=extra_compile_args,
-            extra_link_args=[],
-        )
-    ]
+    extensions = []
+    if torch.cuda.is_available():
+        cuda_sources = list(CSRC_DIR.glob("**/*.cu"))
+
+        extensions.append(
+            CUDAExtension(
+                "python_agnostic.cuda",
+                sources=sorted(str(s) for s in cuda_sources),
+                py_limited_api=True,
+                extra_compile_args=extra_compile_args,
+                extra_link_args=[],
+            )
+        )
+    if torch.xpu.is_available():
+        sycl_sources = list(CSRC_DIR.glob("**/*.sycl"))
+
+        extensions.append(
+            SyclExtension(
+                "python_agnostic.sycl",
+                sources=sorted(str(s) for s in sycl_sources),
+                py_limited_api=True,
+                extra_compile_args=extra_compile_args,
+                extra_link_args=[],
+            )
+        )
+    return extensions


 setup(
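
The `setup()` call itself is truncated in this view; a `get_extension()` result is typically wired in along these lines (a sketch under that assumption, not the commit's actual code):

setup(
    name="python_agnostic",  # assumed from the module names above
    ext_modules=get_extension(),
    cmdclass={"build_ext": BuildExtension},
)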
3 changes: 2 additions & 1 deletion test/run_test.py
@@ -39,6 +39,7 @@
     TEST_WITH_CROSSREF,
     TEST_WITH_ROCM,
     TEST_WITH_SLOW_GRADCHECK,
+    TEST_XPU,
 )

@@ -1078,7 +1079,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
         return return_code
     if sys.platform != "win32":
         exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
-        if TEST_CUDA:
+        if TEST_CUDA or TEST_XPU:
            exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
         for cmd, extension_dir in exts_to_build:
             return_code = shell(
5 changes: 4 additions & 1 deletion test/test_cpp_extensions_aot.py
@@ -185,7 +185,10 @@ def test_cuda_dlink_libs(self):
         test = cuda_dlink.add(a, b)
         self.assertEqual(test, ref)

-    @unittest.skipIf(not TEST_CUDA, "python_agnostic is a CUDA extension + needs CUDA")
+    @unittest.skipIf(
+        not (TEST_CUDA or TEST_XPU),
+        "python_agnostic is a CUDA/XPU extension + needs CUDA/XPU",
+    )
     @unittest.skipIf(not common.IS_LINUX, "test requires linux tools ldd and nm")
     def test_python_agnostic(self):
         # For this test, run_test.py will call `python setup.py bdist_wheel` in the
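
The test referenced above inspects the built library with the `ldd` and `nm` tools: the .so must neither link against libpython nor reference Python C-API symbols. A rough sketch of that kind of check (not the actual test body):

import subprocess

def assert_python_agnostic(so_path):
    # ldd: libpython must not appear among the dynamic dependencies
    ldd_out = subprocess.check_output(["ldd", so_path], text=True)
    assert "libpython" not in ldd_out, "should not link against libpython"
    # nm -D: no Py*/_Py* symbols, whether defined or undefined
    nm_out = subprocess.check_output(["nm", "-D", so_path], text=True)
    bad = [l for l in nm_out.splitlines() if " Py" in l or " _Py" in l]
    assert not bad, f"unexpected Python C-API symbols: {bad}"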
2 changes: 1 addition & 1 deletion third_party/xpu.txt
@@ -1 +1 @@
-306a0ffb6e0cae27c5bd9a3b9cd378048c8e00e7
+pr_1405
