This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
Copy pathtensor_util-inl.h
92 lines (84 loc) · 2.88 KB
/
tensor_util-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tensor_util-inl.h
* \brief commonly utilized tensor operator CPU kernels
*/
#ifndef MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_H_
#define MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_H_
#include <mxnet/base.h>
#include <mxnet/operator.h>
namespace mxnet {
namespace op {
/*!
 * \brief kernel that writes 1 into the flag slot of every index listed in row_idx.
 */
struct MarkRowFlgKernel {
  /*!
   * \brief set the flag addressed by the tid-th entry of row_idx to 1
   * \param tid global thread id; selects which entry of row_idx is read
   * \param row_flg flag array for indices; row_flg[row_idx[tid]] is overwritten with 1
   * \param row_idx row index array storing indices of rows
   */
  template <typename IType, typename DType>
  MSHADOW_XINLINE static void Map(int tid, DType* row_flg, const IType* row_idx) {
    // The stored index is converted to dim_t before being used as an offset.
    row_flg[static_cast<nnvm::dim_t>(row_idx[tid])] = 1;
  }
};
/*!
 * \brief kernel that populates the row index array of an rsp tensor.
 * One thread handles one tensor row.
 */
struct FillRspRowIdxKernel {
  /*!
   * \brief record row `tid` in row_idx when the prefix sum increases at that row
   * \param tid global thread id; also the row being examined
   * \param row_idx row index array to store indices of non-zero rows
   * \param row_flg_sum inclusive prefix sum array over 0/1 marked row flag array
   * \param num_rows rsp tensor number of rows (shape)
   */
  template <typename RType>
  MSHADOW_XINLINE static void Map(int tid,
                                  RType* row_idx,
                                  const nnvm::dim_t* row_flg_sum,
                                  const nnvm::dim_t num_rows) {
    // Threads beyond the last row have nothing to do.
    if (tid >= num_rows) {
      return;
    }
    // Prefix sum just before this row; row 0 has no predecessor, so it starts at 0.
    nnvm::dim_t slot = 0;
    if (tid != 0) {
      slot = row_flg_sum[tid - 1];
    }
    // The sum grew at this row, so the row is written at position `slot`.
    if (slot < row_flg_sum[tid]) {
      row_idx[slot] = static_cast<RType>(tid);
    }
  }
};
/*!
 * \brief the kernel to generate a lookup table for positions of row ids
 */
struct MarkLookupTable {
  /*!
   * \brief store position i in the table slot addressed by the i-th row id
   * \param i thread id; also the position written into the table
   * \param out output table
   * \param data the input row id in sorted order
   */
  template <typename IType, typename DType>
  MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
    // The row id is converted to dim_t before being used as the table offset.
    const nnvm::dim_t row_id = static_cast<nnvm::dim_t>(data[i]);
    out[row_id] = i;
  }
};
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_H_