-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathOffsetCalculator.hpp
114 lines (99 loc) · 4.04 KB
/
OffsetCalculator.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#pragma once
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <type_traits>
#include "IntegerDivider.hpp"
/// OffsetCalculator computes, for each of NARGS operands that share the same
/// shape but may have different strides, the offset of a linear index (in
/// bytes by default, or in elements when per-operand element sizes are given).
namespace porting {
// Maximum number of tensor dimensions OffsetCalculator supports. It sizes the
// per-object sizes_/strides_ arrays and bounds the unrolled dimension loop in
// OffsetCalculator::get().
// XXX: the NVIDIA (CUDA) implementation uses 25; this port uses 12. Adjust it
// and compare the final results.
constexpr int MAX_DIMS = 12;
/// Maps a linear element index to one offset per operand, for NARGS operands
/// that share a shape (`dims` sizes) but may each have their own strides.
///
/// @tparam NARGS          number of operands.
/// @tparam index_t        integer type of the linear index and of the sizes.
/// @tparam signed_strides when true, strides are stored as the signed
///                        counterpart of index_t so that negative strides can
///                        be represented.
template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
struct OffsetCalculator {
  // We allow having negative strides to implement some operations like torch.flip
  using stride_t = std::conditional_t<signed_strides,
                                      std::make_signed_t<index_t>,
                                      index_t>;
  // The offset for each argument. Wrapper around fixed-size array.
  // On CUDA, zero sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size 1 offset to avoid compiler failure.
  // This size 1 offset is just a placeholder, and we will not use it.
  using offset_type = std::array<stride_t, std::max<int>(NARGS, 1)>;

  /// @param dims          number of dimensions; must be in [0, MAX_DIMS].
  /// @param sizes         `dims` sizes. Dimension 0 is the fastest-varying one
  ///                      when a linear index is decomposed in get().
  /// @param strides       one stride array per operand, indexed [operand][dim].
  /// @param element_sizes optional per-operand divisors. If nullptr, the strides
  ///                      are used as given (in bytes). Otherwise each operand's
  ///                      strides are divided by its element size, so offsets
  ///                      are in # of elements.
  /// NOTE(review): stride values are narrowed to stride_t with no overflow check.
  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
    assert(("tensor has too many dims" && dims <= MAX_DIMS));
    // A negative dims would never match `dim == dims` in get(), so get() would
    // read every (uninitialized) entry of sizes_/strides_.
    assert(("tensor has a negative number of dims" && dims >= 0));
    for (int i = 0; i < dims; ++i) {
      sizes_[i] = IntDivider<index_t>(sizes[i]);
      for (int arg = 0; arg < NARGS; arg++) {
        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
        strides_[i][arg] = strides[arg][i] / element_size;
      }
    }
  }

  /// Splits `linear_idx` into per-dimension coordinates, starting at dimension
  /// 0, and returns for each operand the sum over dims of coordinate * stride.
  offset_type get(index_t linear_idx) const {
    offset_type offsets;
# pragma unroll (NARGS)
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = 0;
    }
    // The trip count is the compile-time MAX_DIMS with an early exit at `dims`.
    // Presumably this lets the loop be unrolled even though `dims` is a
    // run-time value.
# pragma unroll (MAX_DIMS)
    for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
        break;
      }
      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;
# pragma unroll (NARGS)
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] += divmod.mod * strides_[dim][arg];
      }
    }
    return offsets;
  }

  // Number of active dimensions; only entries [0, dims) below are initialized.
  int dims;
  // Per-dimension sizes, wrapped as dividers for divmod() in get().
  IntDivider<index_t> sizes_[MAX_DIMS];
  // Per-dimension, per-operand strides, indexed [dim][operand].
  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
};
/// Offset calculator in which every operand's offset is the linear index
/// itself. Offsets are counted in elements, not bytes.
/// NOTE(review): presumably used when all operands are contiguous; confirm
/// with callers.
template <int NARGS, typename index_t = uint32_t>
struct TrivialOffsetCalculator {
  // Zero-length arrays are rejected on some backends (e.g. CUDA). For nullary
  // operators (NARGS == 0) a single placeholder slot is allocated and never read.
  using offset_type = std::array<index_t, std::max<int>(NARGS, 1)>;

  // Fills each of the NARGS operand slots with `linear_idx`.
  offset_type get(index_t linear_idx) const {
    offset_type result;
#pragma unroll (NARGS)
    for (int k = 0; k < NARGS; ++k) {
      result[k] = linear_idx;
    }
    return result;
  }
};
// Make an OffsetCalculator with byte offsets
/// Builds an OffsetCalculator over the first N operands of `iter`. The
/// iterator's strides are passed through unscaled (no element sizes), so the
/// resulting offsets are in bytes.
///
/// Takes `at::TensorIteratorBase`, matching make_element_offset_calculator,
/// so it accepts at::TensorIterator (which derives from the base in ATen) as
/// well as any other iterator derived from the base.
template<int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
  assert(N <= iter.ntensors());
  std::array<const int64_t*, N> strides;
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i).data();
  }
  return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
}
// Make an OffsetCalculator with element offsets
/// Builds an OffsetCalculator over the first N operands of `iter`. Each
/// operand's strides are divided by that operand's element size
/// (iter.element_size(i)), so the resulting offsets are in # of elements.
template<int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
    const at::TensorIteratorBase& iter) {
  using Calculator = OffsetCalculator<N, uint32_t, signed_strides>;
  assert(N <= iter.ntensors());
  std::array<const int64_t*, N> operand_strides;
  std::array<int64_t, N> operand_elem_sizes;
  for (int t = 0; t < N; ++t) {
    operand_strides[t] = iter.strides(t).data();
    operand_elem_sizes[t] = iter.element_size(t);
  }
  return Calculator(iter.ndim(), iter.shape().data(),
                    operand_strides.data(), operand_elem_sizes.data());
}
}