From 745e4bdf00856b9ba11c9deeacb55708c2a5426f Mon Sep 17 00:00:00 2001 From: majing Date: Thu, 11 Jul 2024 07:22:53 -0700 Subject: [PATCH 01/10] Add aten::histogram and variant Signed-off-by: majing --- src/ATen/native/xpu/Histogram.cpp | 200 +++++++++++++++++ src/ATen/native/xpu/sycl/HistogramKernels.h | 23 ++ .../native/xpu/sycl/HistogramddKernels.cpp | 212 ++++++++++++++++++ test/xpu/xpu_test_utils.py | 1 + yaml/xpu_functions.yaml | 4 + 5 files changed, 440 insertions(+) create mode 100644 src/ATen/native/xpu/Histogram.cpp create mode 100644 src/ATen/native/xpu/sycl/HistogramKernels.h create mode 100644 src/ATen/native/xpu/sycl/HistogramddKernels.cpp diff --git a/src/ATen/native/xpu/Histogram.cpp b/src/ATen/native/xpu/Histogram.cpp new file mode 100644 index 000000000..59260e56f --- /dev/null +++ b/src/ATen/native/xpu/Histogram.cpp @@ -0,0 +1,200 @@ +#include +#include +#include +#include +#include + +namespace at { + +/* Checks properties of input tensors input, bins, and weight. + */ +void histogramdd_check_inputs( + const Tensor& input, + const Tensor& bins, + const std::optional& weight) { + if (weight.has_value) { + TORCH_CHECK( + weight->device() == input.device(), + "weight and input need to be on the same device.") + } + auto input_dtype = input.dtype(); + auto bins_dtype = bins.dtype(); + TORCH_CHECK( + input_dtype == bins_dtype, + "torch.histogramdd: input tensor and bins tensors should", + " have the same dtype, but got input with dtype ", + input_dtype, + " and bins with dtype ", + bins_dtype); + + const int64_t bins_dim = bins.dim(); + TORCH_CHECK( + bins_dim == 1, + "torch.histogramdd: bins tensor should have one dimension,", + " but got ", + bins_dim, + " dimensions in the bin tensor"); + + const int64_t numel = bins.numel(); + TORCH_CHECK( + numel > 0, + "torch.histogramdd: bins tensor should have at least 1 element,", + " but got ", + numel, + " elements in the bin tensor"); + + if (weight.has_value()) { + TORCH_CHECK( + input.dtype() == weight.value().dtype(), + "torch.histogramdd: if weight tensor is provided, ", + "input tensor and weight tensor should have the same dtype, ", + "but got input(", + input.dtype(), + ")", + ", and weight(", + weight.value().dtype(), + ")"); + + /* If a weight tensor is provided, we expect its shape to match that of + * the input tensor excluding its innermost dimension N. + */ + auto input_sizes = input.sizes().vec(); + + auto weight_sizes = weight.value().sizes().vec(); + if (weight_sizes.empty()) { + // correctly handle scalars + weight_sizes = {1}; + } + + TORCH_CHECK( + input_sizes == weight_sizes, + "torch.histogramdd: if weight tensor is provided it should have", + " the same shape as the input tensor excluding its innermost ", + "dimension, but got input with shape ", + input.sizes(), + " and weight ", + "with shape ", + weight.value().sizes()); + } +} + +/* Checks properties of output tensors hist and bin_edges, then resizes them. 
+ */ +void histogramdd_prepare_out( + const Tensor& input, + int64_t bin_ct, + const Tensor& hist, + const Tensor& bin_edges) { + TORCH_CHECK( + input.dtype() == hist.dtype(), + "torch.histogram: input tensor and hist tensor should", + " have the same dtype, but got input ", + input.dtype(), + " and hist ", + hist.dtype()); + + TORCH_CHECK( + input.dtype() == bin_edges.dtype(), + "torch.histogram: input tensor and bin_edges tensor should", + " have the same dtype, but got input ", + input.dtype(), + " and bin_edges ", + bin_edges.dtype()); + + TORCH_CHECK( + bin_ct > 0, "torch.histogram(): bins must be > 0, but got ", bin_ct); + + at::native::resize_output(bin_edges, {bin_ct + 1}); + + at::native::resize_output(hist, {bin_ct}); +} + +void histogramdd_prepare_out( + const Tensor& input, + const Tensor& bins, + const Tensor& hist, + const Tensor& bin_edges) { + int64_t bin_ct = bins.numel() - 1; + histogramdd_prepare_out(input, bin_ct, hist, bin_edges); +} + +static Tensor& histogramdd_out( + const Tensor& self, + const Tensor& bins, + const std::optional& weight, + bool density, + Tensor& hist, + Tensor& bin_edges) { + globalContext().alertNotDeterministic("histogram_bin_xpu"); + histogramdd_check_inputs(self, bins, weight); + histogramdd_prepare_out(self, bins, hist, bin_edges); + + bin_edges.copy_(bins); + + at::native::xpu::histogramdd_kernel(self, weight, density, hist, bin_edges); + return hist; +} + +std::tuple XPUNativeFunctions::histogram_out( + const Tensor& self, + const Tensor& bins, + const std::optional& weight, + bool density, + Tensor& hist, + Tensor& bin_edges) { + Tensor reshaped_self = self.reshape({self.numel()}); + std::optional reshaped_weight = weight.has_value() + ? weight.value().reshape({weight.value().numel()}) + : weight; + + histogramdd_out( + reshaped_self, bins, reshaped_weight, density, hist, bin_edges); + + return std::forward_as_tuple(hist, bin_edges); +} + +std::tuple XPUNativeFunctions::histogram( + const Tensor& self, + const Tensor& bins, + const std::optional& weight, + bool density) { + Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous); + Tensor bin_edges = at::empty({0}, bins.options(), MemoryFormat::Contiguous); + return histogram_out(self, bins, weight, density, hist, bin_edges); +} + +std::tuple XPUNativeFunctions::histogram_out( + const Tensor& self, + int64_t bin_ct, + std::optional> range, + const std::optional& weight, + bool density, + Tensor& hist, + Tensor& bin_edges) { + globalContext().alertNotDeterministic("histogram_bin_count_xpu"); + Tensor reshaped_self = self.reshape({self.numel()}); + std::optional reshaped_weight = weight.has_value() + ? 
weight.value().reshape({weight.value().numel()}) + : weight; + + histogramdd_prepare_out(reshaped_self, bin_ct, hist, bin_edges); + histogramdd_check_inputs(reshaped_self, bin_edges, reshaped_weight); + + at::native::xpu::histogramdd_linear_kernel( + reshaped_self, bin_ct, range, reshaped_weight, density, hist, bin_edges); + return std::forward_as_tuple(hist, bin_edges); +} + +std::tuple XPUNativeFunctions::histogram( + const Tensor& self, + int64_t bin_ct, + std::optional> range, + const std::optional& weight, + bool density) { + Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous); + Tensor bin_edges_out = at::empty({0}, self.options()); + return histogram_out( + self, bin_ct, range, weight, density, hist, bin_edges_out); +} + +} // namespace at \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/HistogramKernels.h b/src/ATen/native/xpu/sycl/HistogramKernels.h new file mode 100644 index 000000000..da153186a --- /dev/null +++ b/src/ATen/native/xpu/sycl/HistogramKernels.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void histogramdd_kernel( + const Tensor& self, + const std::optional& weight, + bool density, + Tensor& hist, + const Tensor& bin_edges); + +void histogramdd_linear_kernel( + const Tensor& self, + int64_t bin_ct, + std::optional> range, + const std::optional& weight, + bool density, + Tensor& hist, + Tensor& out_bin_edges); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp new file mode 100644 index 000000000..84d4ec254 --- /dev/null +++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp @@ -0,0 +1,212 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" + +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at::native::xpu { + +template +struct HistogramddKernelFunctor { + void operator()(sycl::nd_item<1> item_id) const { + int64_t wi_id = item_id.get_global_id(); + if (wi_id < input_size_ * bin_size_) { + int64_t ele_idx = wi_id / bin_size_; + int64_t bin_idx = wi_id % bin_size_; + + if (input_[ele_idx] >= bin_edges_[bin_idx] && + input_[ele_idx] < bin_edges_[bin_idx + 1]) { + scalar_t value = weight_ ? weight_[ele_idx] : (scalar_t)1; + atomicAdd((sycl_global_ptr)(hist_ + bin_idx), value); + return; + } + + if (bin_idx == 0 && input_[ele_idx] == bin_edges_[bin_size_]) { + scalar_t value = weight_ ? 
weight_[ele_idx] : (scalar_t)1; + atomicAdd((sycl_global_ptr)(hist_ + bin_size_ - 1), value); + } + } + } + + HistogramddKernelFunctor( + const scalar_t* input, + const scalar_t* bin_edges, + scalar_t* hist, + const scalar_t* weight, + int64_t input_size, + int64_t bin_size) + : input_(input), + bin_edges_(bin_edges), + hist_(hist), + weight_(weight), + input_size_(input_size), + bin_size_(bin_size) {} + + private: + const scalar_t* input_; + const scalar_t* bin_edges_; + scalar_t* hist_; + const scalar_t* weight_; + int64_t input_size_; + int64_t bin_size_; +}; + +template +void histogramdd_template( + const scalar_t* input, + const scalar_t* bin_edges, + scalar_t* hist, + const scalar_t* weight, + int64_t input_size, + int64_t bin_size) { + HistogramddKernelFunctor kfn( + input, bin_edges, hist, weight, input_size, bin_size); + const int64_t work_group_size = syclMaxWorkGroupSize(kfn); + const int64_t num_wg = + (input_size * bin_size + work_group_size - 1) / work_group_size; + sycl_kernel_submit( + num_wg * work_group_size, work_group_size, getCurrentSYCLQueue(), kfn); +} + +template +struct HistogramddLinearKernelFunctor { + void operator()(sycl::nd_item<1> item_id) const { + int64_t wi_id = item_id.get_global_id(); + if (wi_id < input_size_) { + scalar_t i_value = input_[wi_id]; + if (i_value >= leftmost_edge_ && i_value <= rightmost_edge_) { + int64_t bin = + (int64_t)(((i_value - leftmost_edge_)) * bin_size_ / (rightmost_edge_ - leftmost_edge_)); + if (bin == bin_size_) + bin -= 1; + scalar_t value = weight_ ? weight_[wi_id] : (scalar_t)1; + atomicAdd((sycl_global_ptr)(hist_ + bin), value); + } + } + } + + HistogramddLinearKernelFunctor( + const scalar_t* input, + scalar_t* hist, + const scalar_t* weight, + int64_t input_size, + int64_t bin_size, + double leftmost_edge, + double rightmost_edge) + : input_(input), + hist_(hist), + weight_(weight), + input_size_(input_size), + bin_size_(bin_size), + leftmost_edge_(leftmost_edge), + rightmost_edge_(rightmost_edge) {} + + private: + const scalar_t* input_; + scalar_t* hist_; + const scalar_t* weight_; + int64_t input_size_; + int64_t bin_size_; + double leftmost_edge_; + double rightmost_edge_; +}; + +template +void histogramdd_linear_template( + const scalar_t* input, + scalar_t* hist, + const scalar_t* weight, + int64_t input_size, + int64_t bin_size, + double leftmost_edge, + double rightmost_edge) { + HistogramddLinearKernelFunctor kfn( + input, hist, weight, input_size, bin_size, leftmost_edge, rightmost_edge); + const int64_t work_group_size = syclMaxWorkGroupSize(kfn); + const int64_t num_wg = (input_size + work_group_size - 1) / work_group_size; + sycl_kernel_submit( + num_wg * work_group_size, work_group_size, getCurrentSYCLQueue(), kfn); +} + +void histogramdd_kernel( + const Tensor& self, + const std::optional& weight, + bool density, + Tensor& hist, + const Tensor& bin_edges) { + hist.fill_(0); + // TODO: contiguous ? + + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, self.scalar_type(), "histogram_xpu", [&]() { + histogramdd_template( + self.data_ptr(), + bin_edges.data_ptr(), + hist.data_ptr(), + weight.has_value() ? 
weight->data_ptr() : nullptr, + self.numel(), + bin_edges.numel() - 1); + }); + + if (density) { + const auto hist_sum = hist.sum(); + hist.div_(hist_sum); + } +} + +void histogramdd_linear_kernel( + const Tensor& self, + int64_t bin_ct, + std::optional> range, + const std::optional& weight, + bool density, + Tensor& hist, + Tensor& out_bin_edges) { + hist.fill_(0); + + double leftmost_edge, rightmost_edge; + if (!range.has_value()) { + auto extrema = at::aminmax(self); + leftmost_edge = std::get<0>(extrema).item(); + rightmost_edge = std::get<1>(extrema).item(); + } else { + leftmost_edge = range.value()[0]; + rightmost_edge = range.value()[1]; + } + + at::linspace_out(out_bin_edges, leftmost_edge, rightmost_edge, bin_ct + 1); + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, self.scalar_type(), "histogram_linear_xpu", [&]() { + histogramdd_linear_template( + self.data_ptr(), + hist.data_ptr(), + weight.has_value() ? weight->data_ptr() : nullptr, + self.numel(), + bin_ct, + leftmost_edge, + rightmost_edge); + }); + + if (density) { + const auto hist_sum = hist.sum(); + hist.div_(hist_sum); + } +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 35c29d96b..f19a93dda 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -157,6 +157,7 @@ "renorm", "lerp", "conj_physical", + "histogram", ] diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index f103a7795..262b6e020 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -519,3 +519,7 @@ supported: - ceil - ceil_ - ceil.out + - histogram.bins_tensor + - histogram.bins_tensor_out + - histogram.bin_ct + - histogram.bin_ct_out From 0771439f785ba360c7efb8b028634863050c3d27 Mon Sep 17 00:00:00 2001 From: majing Date: Thu, 11 Jul 2024 08:00:43 -0700 Subject: [PATCH 02/10] add comments Signed-off-by: majing --- src/ATen/native/xpu/Histogram.cpp | 5 +---- src/ATen/native/xpu/sycl/HistogramddKernels.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ATen/native/xpu/Histogram.cpp b/src/ATen/native/xpu/Histogram.cpp index 59260e56f..915de051c 100644 --- a/src/ATen/native/xpu/Histogram.cpp +++ b/src/ATen/native/xpu/Histogram.cpp @@ -12,7 +12,7 @@ void histogramdd_check_inputs( const Tensor& input, const Tensor& bins, const std::optional& weight) { - if (weight.has_value) { + if (weight.has_value()) { TORCH_CHECK( weight->device() == input.device(), "weight and input need to be on the same device.") @@ -55,9 +55,6 @@ void histogramdd_check_inputs( weight.value().dtype(), ")"); - /* If a weight tensor is provided, we expect its shape to match that of - * the input tensor excluding its innermost dimension N. - */ auto input_sizes = input.sizes().vec(); auto weight_sizes = weight.value().sizes().vec(); diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp index 84d4ec254..fa71ba77b 100644 --- a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp +++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp @@ -27,6 +27,7 @@ struct HistogramddKernelFunctor { int64_t ele_idx = wi_id / bin_size_; int64_t bin_idx = wi_id % bin_size_; + // [left, right) if (input_[ele_idx] >= bin_edges_[bin_idx] && input_[ele_idx] < bin_edges_[bin_idx + 1]) { scalar_t value = weight_ ? 
weight_[ele_idx] : (scalar_t)1; @@ -34,6 +35,7 @@ struct HistogramddKernelFunctor { return; } + // For last bin, [left, right] if (bin_idx == 0 && input_[ele_idx] == bin_edges_[bin_size_]) { scalar_t value = weight_ ? weight_[ele_idx] : (scalar_t)1; atomicAdd((sycl_global_ptr)(hist_ + bin_size_ - 1), value); @@ -64,6 +66,7 @@ struct HistogramddKernelFunctor { int64_t bin_size_; }; +// For one dimension case template void histogramdd_template( const scalar_t* input, @@ -124,6 +127,7 @@ struct HistogramddLinearKernelFunctor { double rightmost_edge_; }; +// For one dimension case template void histogramdd_linear_template( const scalar_t* input, @@ -146,10 +150,9 @@ void histogramdd_kernel( const std::optional& weight, bool density, Tensor& hist, - const Tensor& bin_edges) { + const Tensor& bin_edges_) { hist.fill_(0); - // TODO: contiguous ? - + Tensor bin_edges = bin_edges_.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, kHalf, self.scalar_type(), "histogram_xpu", [&]() { histogramdd_template( From 85f42f4bdb7623373242b632b7b3e3fbf46544dd Mon Sep 17 00:00:00 2001 From: majing Date: Thu, 11 Jul 2024 22:38:50 -0700 Subject: [PATCH 03/10] Enable CI test and fix CI failures Signed-off-by: majing --- .../native/xpu/sycl/HistogramddKernels.cpp | 20 ++++++++++++++----- test/xpu/extended/test_ops_xpu.py | 6 ++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp index fa71ba77b..36c92b7f0 100644 --- a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp +++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp @@ -167,6 +167,8 @@ void histogramdd_kernel( if (density) { const auto hist_sum = hist.sum(); hist.div_(hist_sum); + Tensor bin_lengths = bin_edges.diff(); + hist.div_(bin_lengths); } } @@ -180,14 +182,20 @@ void histogramdd_linear_kernel( Tensor& out_bin_edges) { hist.fill_(0); - double leftmost_edge, rightmost_edge; - if (!range.has_value()) { + // default range for empty input + double leftmost_edge = 0., rightmost_edge = 1.; + if (range.has_value()) { + leftmost_edge = range.value()[0]; + rightmost_edge = range.value()[1]; + } else if (self.numel() > 0) { auto extrema = at::aminmax(self); leftmost_edge = std::get<0>(extrema).item(); rightmost_edge = std::get<1>(extrema).item(); - } else { - leftmost_edge = range.value()[0]; - rightmost_edge = range.value()[1]; + } + + if (leftmost_edge == rightmost_edge) { + leftmost_edge -= 0.5; + rightmost_edge += 0.5; } at::linspace_out(out_bin_edges, leftmost_edge, rightmost_edge, bin_ct + 1); @@ -206,6 +214,8 @@ void histogramdd_linear_kernel( if (density) { const auto hist_sum = hist.sum(); hist.div_(hist_sum); + Tensor bin_lengths = bin_edges.diff(); + hist.div_(bin_lengths); } } diff --git a/test/xpu/extended/test_ops_xpu.py b/test/xpu/extended/test_ops_xpu.py index 792c7c569..1d9d3df6d 100644 --- a/test/xpu/extended/test_ops_xpu.py +++ b/test/xpu/extended/test_ops_xpu.py @@ -81,6 +81,12 @@ def test_compare_cpu(self, device, dtype, op): self.proxy = Namespace.TestCommonProxy() test_common_test_fn = get_wrapped_fn(Namespace.TestCommonProxy.test_compare_cpu) test_common_test_fn(self.proxy, device, dtype, op) + # for CUDA doesn't support operators + elif (op.name in ["histogram",]): + if dtype in op.dtypes: + self.proxy = Namespace.TestCommonProxy() + test_common_test_fn = get_wrapped_fn(Namespace.TestCommonProxy.test_compare_cpu) + test_common_test_fn(self.proxy, device, dtype, op) else: pytest.skip(f"{op.name} has not supported {dtype} 
yet both for cpu and xpu")

From 1b74ad137466bf50774bf5a29ede89fa7d819adb Mon Sep 17 00:00:00 2001
From: "Ma, Jing1"
Date: Fri, 12 Jul 2024 08:49:21 +0000
Subject: [PATCH 04/10] fix bug

Signed-off-by: Ma, Jing1
---
 src/ATen/native/xpu/sycl/HistogramddKernels.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp
index 36c92b7f0..7f7ba9ede 100644
--- a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp
+++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp
@@ -214,7 +214,7 @@ void histogramdd_linear_kernel(
   if (density) {
     const auto hist_sum = hist.sum();
     hist.div_(hist_sum);
-    Tensor bin_lengths = bin_edges.diff();
+    Tensor bin_lengths = out_bin_edges.diff();
     hist.div_(bin_lengths);
   }
 }

From 20bb9b33f9e8fcfd3c102fa240e2070595b7ed55 Mon Sep 17 00:00:00 2001
From: "Ma, Jing1"
Date: Wed, 17 Jul 2024 06:46:51 +0000
Subject: [PATCH 05/10] Add skip cases

Signed-off-by: Ma, Jing1
---
 test/xpu/extended/run_test_with_skip.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py
index 6c8968510..c705713cb 100644
--- a/test/xpu/extended/run_test_with_skip.py
+++ b/test/xpu/extended/run_test_with_skip.py
@@ -167,6 +167,10 @@
     "test_compare_cpu_isin_xpu",
     "test_operator_isin_xpu_float32",
     "test_view_replay_isin_xpu_float32",
+
+    # test case doesn't make sense, will file an issue to track it.
+    "test_compare_cpu_histogram_xpu_float32",
+    "test_compare_cpu_histogram_xpu_float64",
 )

From 399a7af3d18d06c0581bbb6d0642b507a96ca788 Mon Sep 17 00:00:00 2001
From: "Ma, Jing1"
Date: Thu, 18 Jul 2024 01:12:53 +0000
Subject: [PATCH 06/10] add skip case

Signed-off-by: Ma, Jing1
---
 test/xpu/run_test_with_skip.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index e1dc4788a..92b84e360 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -796,6 +796,8 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_noncontiguous_samples_nn_functional_local_response_norm_xpu_int64",
     # torch.complex32 - "sinh_cpu" not implemented for 'ComplexHalf'
     "test_dtypes_cosh_xpu",
+    # CUDA doesn't support operator
+    "test_dtypes_histogram_xpu",
 )
 res += launch_test("test_ops_xpu.py", skip_list)

From 967cba89a5b793bdb76ca6c61f3759c3a45b2232 Mon Sep 17 00:00:00 2001
From: majing
Date: Mon, 22 Jul 2024 02:44:43 +0000
Subject: [PATCH 07/10] add skip cases missing by rebase

Signed-off-by: majing
---
 test/xpu/extended/run_test_with_skip.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py
index d06ed3be2..912c388b4 100644
--- a/test/xpu/extended/run_test_with_skip.py
+++ b/test/xpu/extended/run_test_with_skip.py
@@ -167,6 +167,10 @@
     "test_compare_cpu_std_mean_xpu_bfloat16",
     "test_compare_cpu_sub_xpu_float16",
     "test_compare_cpu_var_mean_xpu_bfloat16",
+
+    # test case doesn't make sense, will file an issue to track it.
+ "test_compare_cpu_histogram_xpu_float32", + "test_compare_cpu_histogram_xpu_float64", ) From 0063a0016e709e3bc8fc55456e2c2e910fc1b844 Mon Sep 17 00:00:00 2001 From: majing Date: Mon, 22 Jul 2024 09:53:12 +0000 Subject: [PATCH 08/10] fixed review comments Signed-off-by: majing --- src/ATen/native/xpu/Histogram.cpp | 2 -- src/ATen/native/xpu/sycl/HistogramddKernels.cpp | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/Histogram.cpp b/src/ATen/native/xpu/Histogram.cpp index 915de051c..de49952b2 100644 --- a/src/ATen/native/xpu/Histogram.cpp +++ b/src/ATen/native/xpu/Histogram.cpp @@ -122,7 +122,6 @@ static Tensor& histogramdd_out( bool density, Tensor& hist, Tensor& bin_edges) { - globalContext().alertNotDeterministic("histogram_bin_xpu"); histogramdd_check_inputs(self, bins, weight); histogramdd_prepare_out(self, bins, hist, bin_edges); @@ -168,7 +167,6 @@ std::tuple XPUNativeFunctions::histogram_out( bool density, Tensor& hist, Tensor& bin_edges) { - globalContext().alertNotDeterministic("histogram_bin_count_xpu"); Tensor reshaped_self = self.reshape({self.numel()}); std::optional reshaped_weight = weight.has_value() ? weight.value().reshape({weight.value().numel()}) diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp index 7f7ba9ede..be888e4b4 100644 --- a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp +++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp @@ -151,6 +151,7 @@ void histogramdd_kernel( bool density, Tensor& hist, const Tensor& bin_edges_) { + globalContext().alertNotDeterministic("histogramdd_kernel_xpu"); hist.fill_(0); Tensor bin_edges = bin_edges_.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND2( @@ -180,6 +181,7 @@ void histogramdd_linear_kernel( bool density, Tensor& hist, Tensor& out_bin_edges) { + globalContext().alertNotDeterministic("histogramdd_linear_kernel_xpu"); hist.fill_(0); // default range for empty input From 963cd91eb29186da71bd3b9e4894110922b8f72f Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Mon, 22 Jul 2024 18:33:19 +0800 Subject: [PATCH 09/10] Update test/xpu/extended/run_test_with_skip.py --- test/xpu/extended/run_test_with_skip.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py index ec52ecf8f..ca72e2015 100644 --- a/test/xpu/extended/run_test_with_skip.py +++ b/test/xpu/extended/run_test_with_skip.py @@ -164,6 +164,7 @@ "test_compare_cpu_var_mean_xpu_bfloat16", # test case doesn't make sense, will file an issue to track it. 
+    # https://github.com/pytorch/pytorch/issues/130916
     "test_compare_cpu_histogram_xpu_float32",
     "test_compare_cpu_histogram_xpu_float64",
 )

From 05bef8aad82134bcc5c3906b9af075a461bb96bc Mon Sep 17 00:00:00 2001
From: majing
Date: Wed, 24 Jul 2024 02:04:57 +0000
Subject: [PATCH 10/10] updated comments

Signed-off-by: majing
---
 test/xpu/run_test_with_skip.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 29e0e7df5..96cbe4c1a 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -783,7 +783,8 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     # torch.complex32 - "sinh_cpu" not implemented for 'ComplexHalf'
     "test_dtypes_cosh_xpu",
-    # CUDA doesn't support operator
+    # implemented aten::histogram to align MPS operators coverage, CUDA doesn't support
+    # but test_dtypes infrastructure leverage CUDA supported datatypes
     "test_dtypes_histogram_xpu",
     # The following dtypes worked in forward but are not listed by the OpInfo: {torch.float16}.
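
A quick usage sketch for the overloads registered in xpu_functions.yaml above (histogram.bins_tensor and histogram.bin_ct). This is illustrative only, not part of the patch series; it assumes a torch build that includes these XPU kernels and a machine with an available XPU device.

    import torch

    x = torch.randn(1000, device="xpu")
    w = torch.rand(1000, device="xpu")

    # bin_ct overload: uniform bins computed from an explicit range
    hist, edges = torch.histogram(
        x, bins=10, range=(-3.0, 3.0), weight=w, density=True)

    # bins_tensor overload: caller-supplied bin edges
    custom_edges = torch.linspace(-3.0, 3.0, steps=11, device="xpu")
    hist2, edges2 = torch.histogram(x, bins=custom_edges)

Because both kernels accumulate with atomic adds (hence the alertNotDeterministic calls added in patch 08), bit-exact results across runs are not guaranteed.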