-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into mengfeil/triton
- Loading branch information
Showing
34 changed files
with
3,351 additions
and
407 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#include <ATen/native/sparse/SparseStubs.h> | ||
#include <ATen/native/sparse/xpu/sycl/SparseCsrTensorMathKernels.h> | ||
#include <xpu/ATen/ops/_convert_indices_from_coo_to_csr_native.h> | ||
#include <xpu/ATen/ops/_convert_indices_from_csr_to_coo_native.h> | ||
|
||
namespace at::native { | ||
|
||
using namespace at::sparse; | ||
|
||
TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_xpu) | ||
(const Tensor& input, | ||
const int64_t size, | ||
const bool out_int32, | ||
const Tensor& result) { | ||
xpu::convert_indices_from_coo_to_csr_structured_kernel( | ||
input, size, out_int32, result); | ||
}; | ||
|
||
TORCH_IMPL_FUNC(_convert_indices_from_csr_to_coo_structured_xpu) | ||
(const Tensor& crow_indices, | ||
const Tensor& col_indices, | ||
const bool out_int32, | ||
const bool transpose, | ||
const Tensor& result) { | ||
xpu::convert_indices_from_csr_to_coo_structured_kernel( | ||
crow_indices, col_indices, out_int32, transpose, result); | ||
}; | ||
|
||
Tensor _sparse_csr_sum_xpu( | ||
const Tensor& input, | ||
IntArrayRef dims_to_sum, | ||
bool keepdim, | ||
std::optional<ScalarType> dtype) { | ||
return xpu::_sparse_csr_sum_xpu_kernel(input, dims_to_sum, keepdim, dtype); | ||
} | ||
|
||
Tensor _sparse_csr_prod_xpu( | ||
const Tensor& input, | ||
IntArrayRef dims_to_reduce, | ||
bool keepdim, | ||
std::optional<ScalarType> dtype) { | ||
return xpu::_sparse_csr_prod_xpu_kernel( | ||
input, dims_to_reduce, keepdim, dtype); | ||
} | ||
|
||
} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include <ATen/native/sparse/ParamUtils.h> | ||
#include <ATen/native/sparse/SparseStubs.h> | ||
#include <ATen/native/sparse/xpu/sycl/SparseSoftmaxKernels.h> | ||
|
||
namespace at::native { | ||
|
||
using namespace at::sparse; | ||
|
||
Tensor softmax_sparse_xpu( | ||
const Tensor& input_, | ||
const int64_t dim_, | ||
const bool half_to_float) { | ||
return xpu::softmax_sparse_xpu_kernel(input_, dim_, half_to_float); | ||
} | ||
|
||
Tensor log_softmax_sparse_xpu( | ||
const Tensor& input_, | ||
const int64_t dim_, | ||
const bool half_to_float) { | ||
return xpu::log_softmax_sparse_xpu_kernel(input_, dim_, half_to_float); | ||
} | ||
|
||
Tensor softmax_backward_sparse_xpu( | ||
const Tensor& grad_, | ||
const Tensor& output_, | ||
int64_t dim_, | ||
const Tensor& input_) { | ||
return xpu::softmax_backward_sparse_xpu_kernel(grad_, output_, dim_, input_); | ||
} | ||
|
||
Tensor log_softmax_backward_sparse_xpu( | ||
const Tensor& grad_, | ||
const Tensor& output_, | ||
int64_t dim_, | ||
const Tensor& input_) { | ||
return xpu::log_softmax_backward_sparse_xpu_kernel( | ||
grad_, output_, dim_, input_); | ||
} | ||
} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.