diff --git a/src/ATen/native/xpu/Histogram.cpp b/src/ATen/native/xpu/Histogram.cpp
new file mode 100644
index 000000000..de49952b2
--- /dev/null
+++ b/src/ATen/native/xpu/Histogram.cpp
@@ -0,0 +1,195 @@
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/native/Resize.h>
+#include <ATen/native/xpu/sycl/HistogramKernels.h>
+#include <ATen/xpu/XPUNativeFunctions.h>
+
+namespace at {
+
+/* Checks properties of input tensors input, bins, and weight.
+ */
+void histogramdd_check_inputs(
+    const Tensor& input,
+    const Tensor& bins,
+    const std::optional<Tensor>& weight) {
+  if (weight.has_value()) {
+    TORCH_CHECK(
+        weight->device() == input.device(),
+        "weight and input need to be on the same device.");
+  }
+  auto input_dtype = input.dtype();
+  auto bins_dtype = bins.dtype();
+  TORCH_CHECK(
+      input_dtype == bins_dtype,
+      "torch.histogramdd: input tensor and bins tensor should",
+      " have the same dtype, but got input with dtype ",
+      input_dtype,
+      " and bins with dtype ",
+      bins_dtype);
+
+  const int64_t bins_dim = bins.dim();
+  TORCH_CHECK(
+      bins_dim == 1,
+      "torch.histogramdd: bins tensor should have one dimension,",
+      " but got ",
+      bins_dim,
+      " dimensions in the bin tensor");
+
+  const int64_t numel = bins.numel();
+  TORCH_CHECK(
+      numel > 0,
+      "torch.histogramdd: bins tensor should have at least 1 element,",
+      " but got ",
+      numel,
+      " elements in the bin tensor");
+
+  if (weight.has_value()) {
+    TORCH_CHECK(
+        input.dtype() == weight.value().dtype(),
+        "torch.histogramdd: if weight tensor is provided, ",
+        "input tensor and weight tensor should have the same dtype, ",
+        "but got input(",
+        input.dtype(),
+        ")",
+        ", and weight(",
+        weight.value().dtype(),
+        ")");
+
+    auto input_sizes = input.sizes().vec();
+
+    auto weight_sizes = weight.value().sizes().vec();
+    if (weight_sizes.empty()) {
+      // correctly handle scalars
+      weight_sizes = {1};
+    }
+
+    TORCH_CHECK(
+        input_sizes == weight_sizes,
+        "torch.histogramdd: if weight tensor is provided it should have",
+        " the same shape as the input tensor excluding its innermost ",
+        "dimension, but got input with shape ",
+        input.sizes(),
+        " and weight ",
+        "with shape ",
+        weight.value().sizes());
+  }
+}
+
+/* Checks properties of output tensors hist and bin_edges, then resizes them.
+ */
+void histogramdd_prepare_out(
+    const Tensor& input,
+    int64_t bin_ct,
+    const Tensor& hist,
+    const Tensor& bin_edges) {
+  TORCH_CHECK(
+      input.dtype() == hist.dtype(),
+      "torch.histogram: input tensor and hist tensor should",
+      " have the same dtype, but got input ",
+      input.dtype(),
+      " and hist ",
+      hist.dtype());
+
+  TORCH_CHECK(
+      input.dtype() == bin_edges.dtype(),
+      "torch.histogram: input tensor and bin_edges tensor should",
+      " have the same dtype, but got input ",
+      input.dtype(),
+      " and bin_edges ",
+      bin_edges.dtype());
+
+  TORCH_CHECK(
+      bin_ct > 0, "torch.histogram(): bins must be > 0, but got ", bin_ct);
+
+  at::native::resize_output(bin_edges, {bin_ct + 1});
+
+  at::native::resize_output(hist, {bin_ct});
+}
+
+void histogramdd_prepare_out(
+    const Tensor& input,
+    const Tensor& bins,
+    const Tensor& hist,
+    const Tensor& bin_edges) {
+  int64_t bin_ct = bins.numel() - 1;
+  histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
+}
+
+static Tensor& histogramdd_out(
+    const Tensor& self,
+    const Tensor& bins,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    Tensor& bin_edges) {
+  histogramdd_check_inputs(self, bins, weight);
+  histogramdd_prepare_out(self, bins, hist, bin_edges);
+
+  bin_edges.copy_(bins);
+
+  at::native::xpu::histogramdd_kernel(self, weight, density, hist, bin_edges);
+  return hist;
+}
+
+std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
+    const Tensor& self,
+    const Tensor& bins,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    Tensor& bin_edges) {
+  Tensor reshaped_self = self.reshape({self.numel()});
+  std::optional<Tensor> reshaped_weight = weight.has_value()
+      ? weight.value().reshape({weight.value().numel()})
+      : weight;
+
+  histogramdd_out(
+      reshaped_self, bins, reshaped_weight, density, hist, bin_edges);
+
+  return std::forward_as_tuple(hist, bin_edges);
+}
+
+std::tuple<Tensor, Tensor> XPUNativeFunctions::histogram(
+    const Tensor& self,
+    const Tensor& bins,
+    const std::optional<Tensor>& weight,
+    bool density) {
+  Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
+  Tensor bin_edges = at::empty({0}, bins.options(), MemoryFormat::Contiguous);
+  return histogram_out(self, bins, weight, density, hist, bin_edges);
+}
+
+std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
+    const Tensor& self,
+    int64_t bin_ct,
+    std::optional<c10::ArrayRef<double>> range,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    Tensor& bin_edges) {
+  Tensor reshaped_self = self.reshape({self.numel()});
+  std::optional<Tensor> reshaped_weight = weight.has_value()
+      ? weight.value().reshape({weight.value().numel()})
+      : weight;
+
+  histogramdd_prepare_out(reshaped_self, bin_ct, hist, bin_edges);
+  histogramdd_check_inputs(reshaped_self, bin_edges, reshaped_weight);
+
+  at::native::xpu::histogramdd_linear_kernel(
+      reshaped_self, bin_ct, range, reshaped_weight, density, hist, bin_edges);
+  return std::forward_as_tuple(hist, bin_edges);
+}
+
+std::tuple<Tensor, Tensor> XPUNativeFunctions::histogram(
+    const Tensor& self,
+    int64_t bin_ct,
+    std::optional<c10::ArrayRef<double>> range,
+    const std::optional<Tensor>& weight,
+    bool density) {
+  Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
+  Tensor bin_edges_out = at::empty({0}, self.options());
+  return histogram_out(
+      self, bin_ct, range, weight, density, hist, bin_edges_out);
+}
+
+} // namespace at
\ No newline at end of file
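The entry points above map onto the stock `at::histogram` overloads, so once this backend is registered the op is reachable from ordinary ATen C++. A minimal usage sketch, assuming an XPU-enabled PyTorch build; the tensor values, bin counts, and range are illustrative only:

```cpp
#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  // Assumes a build where the XPU backend above is registered.
  at::Tensor x = at::rand({100}, at::TensorOptions().device(at::kXPU));

  // bin_ct overload: 10 uniform bins over an explicit range [0, 1].
  std::vector<double> range = {0.0, 1.0};
  auto [hist, edges] =
      at::histogram(x, /*bins=*/10, range, /*weight=*/{}, /*density=*/true);

  // bins_tensor overload: caller-provided (possibly non-uniform) edges,
  // which must have the same dtype as the input.
  at::Tensor my_edges = at::tensor({0.0, 0.25, 0.5, 1.0}, x.options());
  auto [hist2, edges2] = at::histogram(x, my_edges);

  std::cout << hist.cpu() << "\n" << hist2.cpu() << "\n";
  return 0;
}
```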
diff --git a/src/ATen/native/xpu/sycl/HistogramKernels.h b/src/ATen/native/xpu/sycl/HistogramKernels.h
new file mode 100644
index 000000000..da153186a
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/HistogramKernels.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+namespace at::native::xpu {
+
+void histogramdd_kernel(
+    const Tensor& self,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    const Tensor& bin_edges);
+
+void histogramdd_linear_kernel(
+    const Tensor& self,
+    int64_t bin_ct,
+    std::optional<c10::ArrayRef<double>> range,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    Tensor& out_bin_edges);
+
+} // namespace at::native::xpu
\ No newline at end of file
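Before the SYCL implementation, it may help to pin down the binning contract these declarations promise: bin `i` covers `[edges[i], edges[i+1])`, except the last bin, which also includes values equal to the final edge. A small self-contained CPU reference of that rule, written for this review rather than taken from the patch:

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// CPU reference for the edges-based kernel's binning rule:
// bin i covers [edges[i], edges[i+1]); the last bin additionally
// includes values equal to edges.back().
std::vector<double> histogram_ref(
    const std::vector<double>& input,
    const std::vector<double>& edges) {
  assert(edges.size() >= 2);
  std::vector<double> hist(edges.size() - 1, 0.0);
  for (double v : input) {
    for (size_t i = 0; i + 1 < edges.size(); ++i) {
      bool in_half_open = v >= edges[i] && v < edges[i + 1];
      bool closes_last_bin = (i + 2 == edges.size()) && v == edges[i + 1];
      if (in_half_open || closes_last_bin) {
        hist[i] += 1.0;  // a weight would be added here instead
        break;
      }
    }
  }
  return hist;
}

int main() {
  auto h = histogram_ref({0.1, 0.5, 1.0}, {0.0, 0.5, 1.0});
  std::printf("%g %g\n", h[0], h[1]);  // prints "1 2": 1.0 lands in the closed last bin
}
```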
diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp
new file mode 100644
index 000000000..be888e4b4
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp
@@ -0,0 +1,227 @@
+#pragma clang diagnostic push
+#pragma GCC diagnostic push
+// Avoid SYCL compiler return-type error
+#pragma clang diagnostic ignored "-Wreturn-type"
+#pragma GCC diagnostic ignored "-Wreturn-type"
+
+#include <ATen/ATen.h>
+#include <ATen/native/xpu/sycl/Atomics.h>
+#include <ATen/native/xpu/sycl/HistogramKernels.h>
+#include <comm/SYCLContext.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/aminmax.h>
+#include <ATen/ops/linspace.h>
+#endif
+
+namespace at::native::xpu {
+
+template <typename scalar_t>
+struct HistogramddKernelFunctor {
+  void operator()(sycl::nd_item<1> item_id) const {
+    int64_t wi_id = item_id.get_global_id();
+    if (wi_id < input_size_ * bin_size_) {
+      int64_t ele_idx = wi_id / bin_size_; // one work-item per (element, bin) pair
+      int64_t bin_idx = wi_id % bin_size_;
+
+      // Bins are half-open: [left, right)
+      if (input_[ele_idx] >= bin_edges_[bin_idx] &&
+          input_[ele_idx] < bin_edges_[bin_idx + 1]) {
+        scalar_t value = weight_ ? weight_[ele_idx] : (scalar_t)1;
+        atomicAdd((sycl_global_ptr<scalar_t>)(hist_ + bin_idx), value);
+        return;
+      }
+
+      // The last bin is closed, [left, right]; guard on bin_idx == 0 so only
+      // one work-item per element performs this add.
+      if (bin_idx == 0 && input_[ele_idx] == bin_edges_[bin_size_]) {
+        scalar_t value = weight_ ? weight_[ele_idx] : (scalar_t)1;
+        atomicAdd((sycl_global_ptr<scalar_t>)(hist_ + bin_size_ - 1), value);
+      }
+    }
+  }
+
+  HistogramddKernelFunctor(
+      const scalar_t* input,
+      const scalar_t* bin_edges,
+      scalar_t* hist,
+      const scalar_t* weight,
+      int64_t input_size,
+      int64_t bin_size)
+      : input_(input),
+        bin_edges_(bin_edges),
+        hist_(hist),
+        weight_(weight),
+        input_size_(input_size),
+        bin_size_(bin_size) {}
+
+ private:
+  const scalar_t* input_;
+  const scalar_t* bin_edges_;
+  scalar_t* hist_;
+  const scalar_t* weight_;
+  int64_t input_size_;
+  int64_t bin_size_;
+};
+
+// For the one-dimensional case
+template <typename scalar_t>
+void histogramdd_template(
+    const scalar_t* input,
+    const scalar_t* bin_edges,
+    scalar_t* hist,
+    const scalar_t* weight,
+    int64_t input_size,
+    int64_t bin_size) {
+  HistogramddKernelFunctor<scalar_t> kfn(
+      input, bin_edges, hist, weight, input_size, bin_size);
+  const int64_t work_group_size = syclMaxWorkGroupSize(kfn);
+  const int64_t num_wg =
+      (input_size * bin_size + work_group_size - 1) / work_group_size;
+  sycl_kernel_submit(
+      num_wg * work_group_size, work_group_size, getCurrentSYCLQueue(), kfn);
+}
+
+template <typename scalar_t>
+struct HistogramddLinearKernelFunctor {
+  void operator()(sycl::nd_item<1> item_id) const {
+    int64_t wi_id = item_id.get_global_id();
+    if (wi_id < input_size_) {
+      scalar_t i_value = input_[wi_id];
+      if (i_value >= leftmost_edge_ && i_value <= rightmost_edge_) {
+        int64_t bin =
+            (int64_t)((i_value - leftmost_edge_) * bin_size_ /
+                      (rightmost_edge_ - leftmost_edge_));
+        if (bin == bin_size_) // value equal to the rightmost edge folds into the last bin
+          bin -= 1;
+        scalar_t value = weight_ ? weight_[wi_id] : (scalar_t)1;
+        atomicAdd((sycl_global_ptr<scalar_t>)(hist_ + bin), value);
+      }
+    }
+  }
+
+  HistogramddLinearKernelFunctor(
+      const scalar_t* input,
+      scalar_t* hist,
+      const scalar_t* weight,
+      int64_t input_size,
+      int64_t bin_size,
+      double leftmost_edge,
+      double rightmost_edge)
+      : input_(input),
+        hist_(hist),
+        weight_(weight),
+        input_size_(input_size),
+        bin_size_(bin_size),
+        leftmost_edge_(leftmost_edge),
+        rightmost_edge_(rightmost_edge) {}
+
+ private:
+  const scalar_t* input_;
+  scalar_t* hist_;
+  const scalar_t* weight_;
+  int64_t input_size_;
+  int64_t bin_size_;
+  double leftmost_edge_;
+  double rightmost_edge_;
+};
+
+// For the one-dimensional case
+template <typename scalar_t>
+void histogramdd_linear_template(
+    const scalar_t* input,
+    scalar_t* hist,
+    const scalar_t* weight,
+    int64_t input_size,
+    int64_t bin_size,
+    double leftmost_edge,
+    double rightmost_edge) {
+  HistogramddLinearKernelFunctor<scalar_t> kfn(
+      input, hist, weight, input_size, bin_size, leftmost_edge, rightmost_edge);
+  const int64_t work_group_size = syclMaxWorkGroupSize(kfn);
+  const int64_t num_wg = (input_size + work_group_size - 1) / work_group_size;
+  sycl_kernel_submit(
+      num_wg * work_group_size, work_group_size, getCurrentSYCLQueue(), kfn);
+}
+
+void histogramdd_kernel(
+    const Tensor& self,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    const Tensor& bin_edges_) {
+  globalContext().alertNotDeterministic("histogramdd_kernel_xpu");
+  hist.fill_(0);
+  Tensor bin_edges = bin_edges_.contiguous();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, self.scalar_type(), "histogram_xpu", [&]() {
+        histogramdd_template<scalar_t>(
+            self.data_ptr<scalar_t>(),
+            bin_edges.data_ptr<scalar_t>(),
+            hist.data_ptr<scalar_t>(),
+            weight.has_value() ? weight->data_ptr<scalar_t>() : nullptr,
+            self.numel(),
+            bin_edges.numel() - 1);
+      });
+
+  if (density) { // normalize: hist[i] /= total_count * bin_width[i]
+    const auto hist_sum = hist.sum();
+    hist.div_(hist_sum);
+    Tensor bin_lengths = bin_edges.diff();
+    hist.div_(bin_lengths);
+  }
+}
+
+void histogramdd_linear_kernel(
+    const Tensor& self,
+    int64_t bin_ct,
+    std::optional<c10::ArrayRef<double>> range,
+    const std::optional<Tensor>& weight,
+    bool density,
+    Tensor& hist,
+    Tensor& out_bin_edges) {
+  globalContext().alertNotDeterministic("histogramdd_linear_kernel_xpu");
+  hist.fill_(0);
+
+  // default range for empty input
+  double leftmost_edge = 0., rightmost_edge = 1.;
+  if (range.has_value()) {
+    leftmost_edge = range.value()[0];
+    rightmost_edge = range.value()[1];
+  } else if (self.numel() > 0) {
+    auto extrema = at::aminmax(self);
+    leftmost_edge = std::get<0>(extrema).item<double>();
+    rightmost_edge = std::get<1>(extrema).item<double>();
+  }
+
+  if (leftmost_edge == rightmost_edge) {
+    leftmost_edge -= 0.5;
+    rightmost_edge += 0.5;
+  }
+
+  at::linspace_out(out_bin_edges, leftmost_edge, rightmost_edge, bin_ct + 1);
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, self.scalar_type(), "histogram_linear_xpu", [&]() {
+        histogramdd_linear_template<scalar_t>(
+            self.data_ptr<scalar_t>(),
+            hist.data_ptr<scalar_t>(),
+            weight.has_value() ? weight->data_ptr<scalar_t>() : nullptr,
+            self.numel(),
+            bin_ct,
+            leftmost_edge,
+            rightmost_edge);
+      });
+
+  if (density) {
+    const auto hist_sum = hist.sum();
+    hist.div_(hist_sum);
+    Tensor bin_lengths = out_bin_edges.diff();
+    hist.div_(bin_lengths);
+  }
+}
+
+} // namespace at::native::xpu
+
+#pragma GCC diagnostic pop
+#pragma clang diagnostic pop
\ No newline at end of file
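One note on the `density` branch shared by both kernels: dividing by the count total and then by each bin's width gives the usual probability-density normalization. Spelled out, with bin edges $e_0 < \dots < e_N$ and raw (possibly weighted) counts $h_i$:

```latex
% density normalization applied by both kernels:
% divide by total mass, then by each bin's width
\hat{h}_i = \frac{h_i}{\bigl(\sum_{j=0}^{N-1} h_j\bigr)\,(e_{i+1} - e_i)}
\qquad\text{so that}\qquad
\sum_{i=0}^{N-1} \hat{h}_i \,(e_{i+1} - e_i) = 1 .
```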
diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py
index 76790d817..a05d47e05 100644
--- a/test/xpu/extended/run_test_with_skip.py
+++ b/test/xpu/extended/run_test_with_skip.py
@@ -163,6 +163,11 @@
     "test_compare_cpu_std_mean_xpu_bfloat16",
     "test_compare_cpu_sub_xpu_float16",
     "test_compare_cpu_var_mean_xpu_bfloat16",
+
+    # The test case doesn't make sense; an issue has been filed to track it:
+    # https://github.com/pytorch/pytorch/issues/130916
+    "test_compare_cpu_histogram_xpu_float32",
+    "test_compare_cpu_histogram_xpu_float64",
 )
diff --git a/test/xpu/extended/test_ops_xpu.py b/test/xpu/extended/test_ops_xpu.py
index 792c7c569..1d9d3df6d 100644
--- a/test/xpu/extended/test_ops_xpu.py
+++ b/test/xpu/extended/test_ops_xpu.py
@@ -81,6 +81,12 @@ def test_compare_cpu(self, device, dtype, op):
             self.proxy = Namespace.TestCommonProxy()
             test_common_test_fn = get_wrapped_fn(Namespace.TestCommonProxy.test_compare_cpu)
             test_common_test_fn(self.proxy, device, dtype, op)
+        # operators that CUDA doesn't support
+        elif op.name in ["histogram"]:
+            if dtype in op.dtypes:
+                self.proxy = Namespace.TestCommonProxy()
+                test_common_test_fn = get_wrapped_fn(Namespace.TestCommonProxy.test_compare_cpu)
+                test_common_test_fn(self.proxy, device, dtype, op)
         else:
             pytest.skip(f"{op.name} has not supported {dtype} yet both for cpu and xpu")
diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 0e00da513..96cbe4c1a 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -783,6 +783,10 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     # torch.complex32 - "sinh_cpu" not implemented for 'ComplexHalf'
     "test_dtypes_cosh_xpu",
 
+    # aten::histogram is implemented to align with MPS operator coverage; CUDA doesn't
+    # support it, but the test_dtypes infrastructure relies on CUDA-supported dtypes.
+    "test_dtypes_histogram_xpu",
+
     # The following dtypes worked in forward but are not listed by the OpInfo: {torch.float16}.
     # Align with CPU implementation since,
     # 1. most cases of nextafter require Half dtype.
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 96e43058a..6d989fa14 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -209,6 +209,7 @@
     "aminmax",
     "argmin",
     "conj_physical",
+    "histogram",
     "repeat_interleave",
     "fmax",
     "fmin",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 48ddc4c3f..224c0e7d8 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -660,6 +660,10 @@ supported:
   - ceil
   - ceil_
   - ceil.out
+  - histogram.bins_tensor
+  - histogram.bins_tensor_out
+  - histogram.bin_ct
+  - histogram.bin_ct_out
  - repeat_interleave.Tensor
  - norm.ScalarOpt_dim_dtype
  - norm.dtype_out
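For reference, the four `histogram.*` schema entries registered in the yaml bind to the `XPUNativeFunctions` overloads defined in `Histogram.cpp` above. A sketch of the corresponding declarations, inferred from the definitions in this patch; the enclosing class is normally produced by codegen, so this is illustrative only:

```cpp
#include <ATen/core/Tensor.h>

namespace at {
struct XPUNativeFunctions {
  // histogram.bins_tensor / histogram.bins_tensor_out
  static std::tuple<Tensor, Tensor> histogram(
      const Tensor& self, const Tensor& bins,
      const std::optional<Tensor>& weight, bool density);
  static std::tuple<Tensor&, Tensor&> histogram_out(
      const Tensor& self, const Tensor& bins,
      const std::optional<Tensor>& weight, bool density,
      Tensor& hist, Tensor& bin_edges);

  // histogram.bin_ct / histogram.bin_ct_out
  static std::tuple<Tensor, Tensor> histogram(
      const Tensor& self, int64_t bin_ct,
      std::optional<c10::ArrayRef<double>> range,
      const std::optional<Tensor>& weight, bool density);
  static std::tuple<Tensor&, Tensor&> histogram_out(
      const Tensor& self, int64_t bin_ct,
      std::optional<c10::ArrayRef<double>> range,
      const std::optional<Tensor>& weight, bool density,
      Tensor& hist, Tensor& bin_edges);
};
} // namespace at
```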