-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhelper_fusedR.h
157 lines (148 loc) · 6.46 KB
/
helper_fusedR.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/**
*
* OHIO STATE UNIVERSITY SOFTWARE DISTRIBUTION LICENSE
*
* Parallel CCD++ on GPU (the “Software”) Copyright (c) 2017, The Ohio State
* University. All rights reserved.
*
* The Software is available for download and use subject to the terms and
* conditions of this License. Access or use of the Software constitutes acceptance
* and agreement to the terms and conditions of this License. Redistribution and
* use of the Software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the capitalized paragraph below.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the capitalized paragraph below in the documentation
* and/or other materials provided with the distribution.
*
* 3. The names of Ohio State University, or its faculty, staff or students may not
* be used to endorse or promote products derived from the Software without
* specific prior written permission.
*
* This software was produced with support from the National Science Foundation
* (NSF) through Award 1629548. Nothing in this work should be construed as
* reflecting the official policy or position of the Defense Department, the United
* States government, Ohio State University.
*
* THIS SOFTWARE HAS BEEN APPROVED FOR PUBLIC RELEASE, UNLIMITED DISTRIBUTION. THE
* SOFTWARE IS PROVIDED “AS IS” AND WITHOUT ANY EXPRESS, IMPLIED OR STATUTORY
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, WARRANTIES OF ACCURACY, COMPLETENESS,
* NONINFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. ACCESS OR USE OF THE SOFTWARE IS ENTIRELY AT THE USER’S RISK. IN
* NO EVENT SHALL OHIO STATE UNIVERSITY OR ITS FACULTY, STAFF OR STUDENTS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
* TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE SOFTWARE
* USER SHALL INDEMNIFY, DEFEND AND HOLD HARMLESS OHIO STATE UNIVERSITY AND ITS
* FACULTY, STAFF AND STUDENTS FROM ANY AND ALL CLAIMS, ACTIONS, DAMAGES, LOSSES,
* LIABILITIES, COSTS AND EXPENSES, INCLUDING ATTORNEYS’ FEES AND COURT COSTS,
* DIRECTLY OR INDIRECTLY ARISING OUT OF OR IN CONNECTION WITH ACCESS OR USE OF THE
* SOFTWARE.
*
*/
/**
*
* Author:
* Israt ([email protected])
*
* Contacts:
* Israt ([email protected])
* Aravind Sukumaran-Rajam ([email protected])
* P. (Saday) Sadayappan ([email protected])
*
*/
#include "cuda_fusedR.h"
#include "utils_extra.hpp"
#include <cmath>
#include <cstdio>
// Upper bound on threads per block; presumably mirrors the CUDA per-block
// thread limit. Not referenced by the launch helpers in this file.
constexpr int maxThreadsPerBlock = 1024;

// Default thread-block size; `block` below is initialised from it.
int BLOCKSIZE = 128;

// Fixed-capacity stream pool (10 + 1 entries). helper_UpdateR launches on
// stream[0..9]; create_stream() is responsible for creating the streams.
cudaStream_t stream[10 + 1]; //hard coded

// Launch configuration shared by the launch helpers. `block` stays as
// initialised here; the helpers overwrite grid.x before each launch.
dim3 block(BLOCKSIZE, 1, 1);
dim3 grid(1, 1, 1);
/**
 * Creates the CUDA streams stream[0 .. NUM_THRDS], i.e. NUM_THRDS + 1 streams.
 *
 * The global `stream` array has a fixed capacity. The highest index written
 * is clamped to the last slot of that array. Without the clamp, a NUM_THRDS
 * larger than (capacity - 1) would write past the end of the array. A warning
 * is printed when clamping happens.
 *
 * cudaStreamCreate failures are reported on stderr rather than silently
 * ignored. The affected entry may then not hold a valid stream.
 */
void create_stream() {
    // Number of slots in the global stream array (not a decayed pointer here).
    const int capacity = static_cast<int>(sizeof(stream) / sizeof(stream[0]));
    const int requested = NUM_THRDS;

    int last = requested;
    if (last > capacity - 1) {
        fprintf(stderr,
                "create_stream: NUM_THRDS=%d exceeds stream capacity; "
                "creating only %d streams\n",
                requested, capacity);
        last = capacity - 1;
    }

    // NUM_THRDS worker streams plus one extra stream at index NUM_THRDS,
    // matching the original loop + trailing cudaStreamCreate.
    for (int i = 0; i <= last; ++i) {
        const cudaError_t err = cudaStreamCreate(&stream[i]);
        if (err != cudaSuccess) {
            fprintf(stderr, "create_stream: cudaStreamCreate(stream[%d]) failed: %s\n",
                    i, cudaGetErrorString(err));
        }
    }
}
/**
 * Compile-time loop over row groups LB, LB+1, ..., UB-1.
 *
 * Constructing RANK_LOOP<LB, UB> does three things, in order:
 *   1. If count[LB] > 0, launches updateR_gen<false, 2^LB-ish, LB> on
 *      stream[LB] for the count[LB] entries of rowGroupPtr that start at
 *      offset `sum`.
 *   2. Advances `sum` by count[LB].
 *   3. Constructs RANK_LOOP<LB + 1, UB> for the remaining groups.
 * The RANK_LOOP<LB, LB> specialization terminates the recursion.
 *
 * Grid sizing: TMP_power<2, LB>::value threads (presumably 2^LB) per entry,
 * in blocks of 128 threads, no dynamic shared memory.
 *
 * sum    in/out offset into rowGroupPtr of group LB's first entry.
 * count  per-group entry counts, read on the host (must be host memory).
 * Side effect: overwrites the global grid.x.
 */
template<unsigned LB, unsigned UB>
struct RANK_LOOP {
    RANK_LOOP() = delete;
    RANK_LOOP(int& sum, const int * __restrict__ d_R_colPtr, int *d_row_lim,
            unsigned *d_R_rowIdx, DTYPE *d_R_val, const DTYPE *d_Wt,
            const DTYPE *d_Ht, int m, int n, bool add, int *rowGroupPtr,
            int *count, DTYPE lambda, DTYPE *d_gArrV, DTYPE *d_hArrV,
            DTYPE *v_new, DTYPE *Wt_p, DTYPE *Ht_p, int t) {
        static_assert(LB < UB, "Lower Bound should be less than Upper bound");

        // Threads assigned to each rowGroupPtr entry of this group.
        constexpr unsigned threadsPerRow = TMP_power<2, LB>::value;
        constexpr unsigned threadsPerBlock = 128;

        const int rowsInGroup = count[LB];
        if (rowsInGroup > 0) {
            // ceil(threadsPerRow * rowsInGroup / threadsPerBlock) blocks.
            grid.x = (threadsPerRow * rowsInGroup + threadsPerBlock - 1)
                    / threadsPerBlock;
            updateR_gen<false, threadsPerRow, LB>
                    <<<grid, threadsPerBlock, 0, stream[LB]>>>(
                    d_R_colPtr, d_row_lim, d_R_rowIdx, d_R_val, d_Wt, d_Ht,
                    m, n, add, rowGroupPtr + sum, rowsInGroup, lambda,
                    d_gArrV, d_hArrV, v_new, Wt_p, Ht_p, t);
        }
        sum += rowsInGroup;

        // Next group; recursion ends at the RANK_LOOP<UB, UB> specialization.
        RANK_LOOP<LB + 1, UB>(sum, d_R_colPtr, d_row_lim, d_R_rowIdx, d_R_val,
                d_Wt, d_Ht, m, n, add, rowGroupPtr, count, lambda, d_gArrV,
                d_hArrV, v_new, Wt_p, Ht_p, t);
    }
};
/**
 * Recursion terminator for RANK_LOOP.
 *
 * When the lower bound reaches the upper bound, every group in the range has
 * already been handled. This constructor therefore launches nothing and
 * leaves `sum` untouched. Its signature mirrors the primary template so that
 * the recursive construction in RANK_LOOP<UB - 1, UB> resolves here.
 */
template<unsigned LB>
struct RANK_LOOP<LB, LB> {
    RANK_LOOP() = delete;
    RANK_LOOP(int& /*sum*/, const int * __restrict__ /*d_R_colPtr*/,
            int * /*d_row_lim*/, unsigned * /*d_R_rowIdx*/,
            DTYPE * /*d_R_val*/, const DTYPE * /*d_Wt*/,
            const DTYPE * /*d_Ht*/, int /*m*/, int /*n*/, bool /*add*/,
            int * /*rowGroupPtr*/, int * /*count*/, DTYPE /*lambda*/,
            DTYPE * /*d_gArrV*/, DTYPE * /*d_hArrV*/, DTYPE * /*v_new*/,
            DTYPE * /*Wt_p*/, DTYPE * /*Ht_p*/, int /*t*/) {
        // Intentionally empty: end of the compile-time group loop.
    }
};
/**
 * Launches the R-update kernels for row groups 0..9.
 *
 * rowGroupPtr is consumed in consecutive slices. Group g's slice starts at
 * offset count[0] + ... + count[g-1] and holds count[g] entries; `sum`
 * tracks that running offset.
 *
 * - Groups 0..5 are dispatched by RANK_LOOP<0, 6>, which makes one
 *   updateR_gen launch per non-empty group.
 * - Groups 6..9 all use updateR_7<false>. Each launch gets 64 threads per
 *   entry, the global `block` configuration, and
 *   2 * block.x / 32 * sizeof(DTYPE) bytes of dynamic shared memory.
 *
 * Group g is launched on stream[g] and skipped when count[g] <= 0. Launches
 * are asynchronous; this function does not synchronize.
 *
 * count is indexed on the host (count[0..9]), so it must be host memory.
 * Side effect: overwrites the global grid.x.
 */
void helper_UpdateR(int *d_R_colPtr, int *d_row_lim, unsigned *d_R_rowIdx,
        DTYPE *d_R_val, DTYPE *d_Wt, DTYPE *d_Ht, int m, int n, bool add,
        int *rowGroupPtr, int *count, DTYPE lambda, DTYPE *d_gArrV,
        DTYPE *d_hArrV, DTYPE *v_new, DTYPE *Wt_p, DTYPE *Ht_p, int t) {
    // Offset into rowGroupPtr of the current group's first entry.
    int sum = 0;

    // Groups 0..5 (RANK_LOOP iterates LB = 0..5 at compile time).
    RANK_LOOP<0, 6>(sum, d_R_colPtr, d_row_lim, d_R_rowIdx, d_R_val, d_Wt, d_Ht,
            m, n, add, rowGroupPtr, count, lambda, d_gArrV, d_hArrV, v_new,
            Wt_p, Ht_p, t);

    // Groups 6..9 share one kernel and launch configuration; only the group
    // index (count slot, stream, rowGroupPtr offset) differs between them.
    constexpr int kFirstSharedGroup = 6;
    constexpr int kLastGroup = 9;
    constexpr int kThreadsPerRow = 64;
    for (int g = kFirstSharedGroup; g <= kLastGroup; ++g) {
        if (count[g] > 0) {
            // Size the grid from block.x, the block size actually passed to
            // the launch, rather than the mutable global BLOCKSIZE. This way
            // the grid still covers every entry if BLOCKSIZE changes after
            // `block` was initialised.
            grid.x = (kThreadsPerRow * count[g] + block.x - 1) / block.x;
            updateR_7<false> <<<grid, block, 2 * block.x / 32 * sizeof(DTYPE),
                    stream[g]>>>(d_R_colPtr, d_row_lim, d_R_rowIdx, d_R_val,
                    d_Wt, d_Ht, m, n, add, rowGroupPtr + sum, count[g], lambda,
                    d_gArrV, d_hArrV, v_new, Wt_p, Ht_p, t);
            // Surface launch-configuration errors. Peek (not Get) so the
            // error state is left intact for any caller-side checks.
            const cudaError_t launchErr = cudaPeekAtLastError();
            if (launchErr != cudaSuccess) {
                fprintf(stderr,
                        "helper_UpdateR: updateR_7 launch for row group %d failed: %s\n",
                        g, cudaGetErrorString(launchErr));
            }
        }
        sum += count[g];
    }
}