Skip to content

Commit

Permalink
Add aten::histogram and variant (#569)
Browse files Browse the repository at this point in the history
CPU/MPS only ops

---------

Signed-off-by: majing <[email protected]>
Signed-off-by: Ma, Jing1 <[email protected]>
Co-authored-by: Feng Yuan <[email protected]>
  • Loading branch information
majing921201 and fengyuan14 authored Jul 25, 2024
1 parent e1e195b commit fcaa34c
Show file tree
Hide file tree
Showing 8 changed files with 465 additions and 0 deletions.
195 changes: 195 additions & 0 deletions src/ATen/native/xpu/Histogram.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#include <ATen/Context.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/sycl/HistogramKernels.h>
#include <ATen/xpu/XPUNativeFunctions.h>

namespace at {

/* Checks properties of input tensors input, bins, and weight.
*/
void histogramdd_check_inputs(
const Tensor& input,
const Tensor& bins,
const std::optional<Tensor>& weight) {
if (weight.has_value()) {
TORCH_CHECK(
weight->device() == input.device(),
"weight and input need to be on the same device.")
}
auto input_dtype = input.dtype();
auto bins_dtype = bins.dtype();
TORCH_CHECK(
input_dtype == bins_dtype,
"torch.histogramdd: input tensor and bins tensors should",
" have the same dtype, but got input with dtype ",
input_dtype,
" and bins with dtype ",
bins_dtype);

const int64_t bins_dim = bins.dim();
TORCH_CHECK(
bins_dim == 1,
"torch.histogramdd: bins tensor should have one dimension,",
" but got ",
bins_dim,
" dimensions in the bin tensor");

const int64_t numel = bins.numel();
TORCH_CHECK(
numel > 0,
"torch.histogramdd: bins tensor should have at least 1 element,",
" but got ",
numel,
" elements in the bin tensor");

if (weight.has_value()) {
TORCH_CHECK(
input.dtype() == weight.value().dtype(),
"torch.histogramdd: if weight tensor is provided, ",
"input tensor and weight tensor should have the same dtype, ",
"but got input(",
input.dtype(),
")",
", and weight(",
weight.value().dtype(),
")");

auto input_sizes = input.sizes().vec();

auto weight_sizes = weight.value().sizes().vec();
if (weight_sizes.empty()) {
// correctly handle scalars
weight_sizes = {1};
}

TORCH_CHECK(
input_sizes == weight_sizes,
"torch.histogramdd: if weight tensor is provided it should have",
" the same shape as the input tensor excluding its innermost ",
"dimension, but got input with shape ",
input.sizes(),
" and weight ",
"with shape ",
weight.value().sizes());
}
}

/* Checks properties of output tensors hist and bin_edges, then resizes them.
*/
void histogramdd_prepare_out(
const Tensor& input,
int64_t bin_ct,
const Tensor& hist,
const Tensor& bin_edges) {
TORCH_CHECK(
input.dtype() == hist.dtype(),
"torch.histogram: input tensor and hist tensor should",
" have the same dtype, but got input ",
input.dtype(),
" and hist ",
hist.dtype());

TORCH_CHECK(
input.dtype() == bin_edges.dtype(),
"torch.histogram: input tensor and bin_edges tensor should",
" have the same dtype, but got input ",
input.dtype(),
" and bin_edges ",
bin_edges.dtype());

TORCH_CHECK(
bin_ct > 0, "torch.histogram(): bins must be > 0, but got ", bin_ct);

at::native::resize_output(bin_edges, {bin_ct + 1});

at::native::resize_output(hist, {bin_ct});
}

void histogramdd_prepare_out(
const Tensor& input,
const Tensor& bins,
const Tensor& hist,
const Tensor& bin_edges) {
int64_t bin_ct = bins.numel() - 1;
histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
}

static Tensor& histogramdd_out(
const Tensor& self,
const Tensor& bins,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& bin_edges) {
histogramdd_check_inputs(self, bins, weight);
histogramdd_prepare_out(self, bins, hist, bin_edges);

bin_edges.copy_(bins);

at::native::xpu::histogramdd_kernel(self, weight, density, hist, bin_edges);
return hist;
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
const Tensor& self,
const Tensor& bins,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& bin_edges) {
Tensor reshaped_self = self.reshape({self.numel()});
std::optional<Tensor> reshaped_weight = weight.has_value()
? weight.value().reshape({weight.value().numel()})
: weight;

histogramdd_out(
reshaped_self, bins, reshaped_weight, density, hist, bin_edges);

return std::forward_as_tuple(hist, bin_edges);
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::histogram(
const Tensor& self,
const Tensor& bins,
const std::optional<Tensor>& weight,
bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
Tensor bin_edges = at::empty({0}, bins.options(), MemoryFormat::Contiguous);
return histogram_out(self, bins, weight, density, hist, bin_edges);
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
const Tensor& self,
int64_t bin_ct,
std::optional<c10::ArrayRef<double>> range,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& bin_edges) {
Tensor reshaped_self = self.reshape({self.numel()});
std::optional<Tensor> reshaped_weight = weight.has_value()
? weight.value().reshape({weight.value().numel()})
: weight;

histogramdd_prepare_out(reshaped_self, bin_ct, hist, bin_edges);
histogramdd_check_inputs(reshaped_self, bin_edges, reshaped_weight);

at::native::xpu::histogramdd_linear_kernel(
reshaped_self, bin_ct, range, reshaped_weight, density, hist, bin_edges);
return std::forward_as_tuple(hist, bin_edges);
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::histogram(
const Tensor& self,
int64_t bin_ct,
std::optional<c10::ArrayRef<double>> range,
const std::optional<Tensor>& weight,
bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
Tensor bin_edges_out = at::empty({0}, self.options());
return histogram_out(
self, bin_ct, range, weight, density, hist, bin_edges_out);
}

} // namespace at
23 changes: 23 additions & 0 deletions src/ATen/native/xpu/sycl/HistogramKernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include <ATen/ATen.h>

namespace at::native::xpu {

void histogramdd_kernel(
const Tensor& self,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
const Tensor& bin_edges);

void histogramdd_linear_kernel(
const Tensor& self,
int64_t bin_ct,
std::optional<c10::ArrayRef<double>> range,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& out_bin_edges);

} // namespace at::native::xpu
Loading

0 comments on commit fcaa34c

Please sign in to comment.