Skip to content

Commit

Permalink
Merge branch 'main' into majing/histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
fengyuan14 authored Jul 24, 2024
2 parents 05bef8a + f716a58 commit c0795fe
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/ATen/native/xpu/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ static void cum_ops_meta(
maybe_wrap_dim(dim, self.dim());

ScalarType out_dtype;

if (result.defined()) {
out_dtype = dtype.value_or(result.scalar_type());
at::xpu::resize_out(
Expand Down
10 changes: 10 additions & 0 deletions src/ATen/native/xpu/Repeat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <ATen/ATen.h>
#include <ATen/native/xpu/sycl/RepeatKernel.h>
#include <ATen/xpu/XPUNativeFunctions.h>
namespace at {
// Thin dispatch shim: forwards the repeats tensor (and the caller-supplied
// optional output size, when available) to the XPU SYCL implementation.
Tensor XPUNativeFunctions::repeat_interleave(
    const Tensor& repeats,
    c10::optional<int64_t> output_size) {
  auto result = at::native::xpu::repeat_interleave_kernel(repeats, output_size);
  return result;
}
} // namespace at
3 changes: 0 additions & 3 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"nanmedian",
"nanmedian.dim_values",
"nansum",
"norm.out",
"nextafter.out",
"ormqr",
"_pdist_backward",
"_pdist_forward",
Expand All @@ -250,7 +248,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"prod",
"prod.int_out",
"put_",
"repeat_interleave.Tensor",
"replication_pad1d_backward.grad_input",
"replication_pad1d.out",
"replication_pad2d_backward",
Expand Down
79 changes: 79 additions & 0 deletions src/ATen/native/xpu/sycl/RepeatKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <ATen/ATen.h>
#include <ATen/native/Repeat.h>
#include <ATen/native/xpu/sycl/RepeatKernel.h>
#include <comm/SYCLContext.h>
namespace at::native::xpu {
template <typename index_t>
struct RepeatInterleaveKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
auto rep_ptr = rep_data_;
auto cum_ptr = cum_data_;
auto res_ptr = res_data_;

for (int64_t i = item.get_global_id(0); i < size_;
i += item.get_global_range()[0]) {
int64_t end = cum_ptr[i];
int64_t repeat = rep_ptr[i];
int64_t start = end - repeat;
for (int64_t j = start; j < end; j++) {
res_ptr[j] = i;
}
}
}
RepeatInterleaveKernelFunctor(
const index_t* rep_data,
const int64_t* cum_data,
index_t* res_data,
int64_t size,
int64_t result_size)
: rep_data_(rep_data),
cum_data_(cum_data),
res_data_(res_data),
size_(size),
result_size_(result_size) {}

private:
const index_t* rep_data_;
const int64_t* cum_data_;
index_t* res_data_;
int64_t size_;
int64_t result_size_;
};

template <typename index_t>
static void compute_xpu(
const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
if (size == 0)
return;

auto kfn = RepeatInterleaveKernelFunctor<index_t>(
repeat_ptr,
cumsum_ptr,
result_ptr,
size,
result_size);

int64_t wg_size = syclMaxWorkGroupSize(kfn);
int64_t local_range = size < wg_size ? size : wg_size;
int64_t global_range = ((size + local_range - 1) / local_range) * local_range;

auto queue = getCurrentSYCLQueue();
sycl_kernel_submit(global_range, local_range, queue, kfn);
}

// Expands `repeat` into an index tensor in which each index i appears
// repeat[i] times, dispatching on the index dtype of `repeat`.
// `output_size` is forwarded to the shared ATen helper — presumably the
// precomputed total output length; confirm against repeat_interleave_common.
Tensor repeat_interleave_kernel(
    const Tensor& repeat,
    c10::optional<int64_t> output_size) {
  Tensor result;
  AT_DISPATCH_INDEX_TYPES(
      repeat.scalar_type(), "repeat_interleave_xpu", [&]() {
        result = repeat_interleave_common<index_t, compute_xpu<index_t>>(
            repeat, output_size);
      });
  return result;
}
} // namespace at::native::xpu
9 changes: 9 additions & 0 deletions src/ATen/native/xpu/sycl/RepeatKernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once
#include <ATen/ATen.h>
namespace at::native::xpu {

// Expands `repeats` into an index tensor in which each index i appears
// repeats[i] times. `output_size`, if provided, is presumably the
// precomputed total length of the result — confirm against the
// implementation in RepeatKernel.cpp.
Tensor repeat_interleave_kernel(
    const Tensor& repeats,
    c10::optional<int64_t> output_size);

} // namespace at::native::xpu
4 changes: 1 addition & 3 deletions src/ATen/native/xpu/sycl/SortingKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,7 @@ void segmented_sort_pairs_(
int num_elements) {
constexpr int scaling_coef = sizeof(key_t) * sizeof(value_t) >= 64
? 2
: 1; // Attempt to reduce register pressure. The result will be incorrect
// when using too many local variables (registers).
// https://github.com/intel/torch-xpu-ops/issues/626
: 1; // Attempt to reduce register pressure for performance.
if (num_elements > 4096 / scaling_coef) {
// Considering register pressure, we use a problem size of 4096 to delineate
// the boundary between single tile sort and group sort.
Expand Down
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@
"argmin",
"conj_physical",
"histogram",
"repeat_interleave",
"fmax",
"fmin",
"floor",
Expand Down
1 change: 1 addition & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ supported:
- histogram.bins_tensor_out
- histogram.bin_ct
- histogram.bin_ct_out
- repeat_interleave.Tensor
- norm.ScalarOpt_dim_dtype
- norm.dtype_out
- norm.ScalarOpt_dim
Expand Down

0 comments on commit c0795fe

Please sign in to comment.