-
Notifications
You must be signed in to change notification settings - Fork 357
/
Copy pathfused_rope.cu
366 lines (336 loc) · 19.9 KB
/
fused_rope.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_rope.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
namespace transformer_engine {
// Forward rotary rotation of one (sequence position, batch entry) slice covering all h heads.
//
// For every head and every channel i in [0, d2), with f = freqs[s_id * d2 + i] and
// half = d2 / 2:
//   dst[i] = src[i] * cos(f) + r * sin(f), where
//   r = -src[i + half]   if i + half < d2,
//   r =  src[i - half]   otherwise.
// Channels [d2, d) are copied through unchanged.
//
// Element (head, channel) is read from src[offset_block + head * stride_h + channel * stride_d]
// and written to dst[offset_block_dst + head * o_stride_h + channel * o_stride_d].
// threadIdx.x strides over channels, threadIdx.y strides over heads.
// Loads are widened to float, all math is in float, and the result is converted back to
// scalar_t on store.
// NOTE(review): d2 is presumably even; odd d2 is not handled specially here -- confirm callers.
template <typename scalar_t>
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
                                         const int s_id, const int offset_block,
                                         const int offset_block_dst, const int h, const int d,
                                         const int d2, const int stride_h, const int stride_d,
                                         const int o_stride_h, const int o_stride_d) {
  const int half = d2 / 2;
  const float *freqs_row = freqs + s_id * d2;

  // Rotary channels [0, d2).
#pragma unroll
  for (int ch = threadIdx.x; ch < d2; ch += blockDim.x) {
    float cos_f, sin_f;
    sincosf(freqs_row[ch], &sin_f, &cos_f);
    // Lower-half channels pair with the channel `half` ahead (negated); the remaining
    // channels pair with the channel `half` behind.
    const bool lower_half = ch + half < d2;
    const int partner_step = (lower_half ? half : half - d2) * stride_d;
#pragma unroll
    for (int head = threadIdx.y; head < h; head += blockDim.y) {
      const int in_idx = offset_block + head * stride_h + ch * stride_d;
      const int out_idx = offset_block_dst + head * o_stride_h + ch * o_stride_d;
      const float x = src[in_idx];
      const float partner = static_cast<float>(src[in_idx + partner_step]);
      const float x_rot = lower_half ? -partner : partner;
      dst[out_idx] = x * cos_f + x_rot * sin_f;
    }
  }

  // Pass-through channels [d2, d) are copied unchanged.
  if (d <= d2) return;
#pragma unroll
  for (int head = threadIdx.y; head < h; head += blockDim.y) {
    const int in_head = offset_block + head * stride_h;
    const int out_head = offset_block_dst + head * o_stride_h;
#pragma unroll
    for (int ch = d2 + threadIdx.x; ch < d; ch += blockDim.x) {
      dst[out_head + ch * o_stride_d] = src[in_head + ch * stride_d];
    }
  }
}
// Backward (gradient) pass of the rotary rotation for one (sequence position, batch entry)
// slice covering all h heads. `src` holds the incoming gradient g; `dst` receives the
// gradient with respect to the forward input.
//
// For every head and channel i in [0, d2), with f(k) = freqs[s_id * d2 + k] and half = d2 / 2:
//   dst[i] = g[i] * cos(f(i)) + g[i + half] * sin(f(i + half))   if i + half < d2,
//   dst[i] = g[i] * cos(f(i)) - g[i - half] * sin(f(i - half))   otherwise.
// This is the transpose of the forward rotation dst[i] = x[i] cos f(i) + r(x)[i] sin f(i).
// Channels [d2, d) are copied through unchanged.
//
// Element (head, channel) is read from src[offset_block + head * stride_h + channel * stride_d]
// and written to dst[offset_block_dst + head * o_stride_h + channel * o_stride_d].
// threadIdx.x strides over channels, threadIdx.y strides over heads. Math is in float.
template <typename scalar_t>
__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
                                          const int s_id, const int offset_block,
                                          const int offset_block_dst, const int h, const int d,
                                          const int d2, const int stride_h, const int stride_d,
                                          const int o_stride_h, const int o_stride_d) {
  const int half = d2 / 2;
  const float *freqs_row = freqs + s_id * d2;

  // Rotary channels [0, d2).
#pragma unroll
  for (int ch = threadIdx.x; ch < d2; ch += blockDim.x) {
    const bool lower_half = ch + half < d2;
    const float cos_f = cosf(freqs_row[ch]);
    // The sine term uses the partner channel's angle and is negated for upper-half channels.
    float sin_f;
    if (lower_half) {
      sin_f = sinf(freqs_row[ch + half]);
    } else {
      sin_f = -sinf(freqs_row[ch + half - d2]);
    }
    const int partner_step = (lower_half ? half : half - d2) * stride_d;
#pragma unroll
    for (int head = threadIdx.y; head < h; head += blockDim.y) {
      const int in_idx = offset_block + head * stride_h + ch * stride_d;
      const int out_idx = offset_block_dst + head * o_stride_h + ch * o_stride_d;
      const float g = src[in_idx];
      const float g_partner = src[in_idx + partner_step];
      dst[out_idx] = g * cos_f + g_partner * sin_f;
    }
  }

  // Pass-through channels [d2, d) are copied unchanged.
  if (d <= d2) return;
#pragma unroll
  for (int head = threadIdx.y; head < h; head += blockDim.y) {
    const int in_head = offset_block + head * stride_h;
    const int out_head = offset_block_dst + head * o_stride_h;
#pragma unroll
    for (int ch = d2 + threadIdx.x; ch < d; ch += blockDim.x) {
      dst[out_head + ch * o_stride_d] = src[in_head + ch * stride_d];
    }
  }
}
// Forward RoPE kernel for a strided [s, b, h, d]-indexed tensor.
// Each block handles one (sequence position, batch entry) pair:
// blockIdx.x is the sequence position and blockIdx.y is the batch index.
// The position doubles as the row of `freqs` that is used.
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
                                          const int h, const int d, const int d2,
                                          const int stride_s, const int stride_b,
                                          const int stride_h, const int stride_d,
                                          const int o_stride_s, const int o_stride_b,
                                          const int o_stride_h, const int o_stride_d) {
  const int pos = blockIdx.x;
  const int batch = blockIdx.y;
  const int src_base = pos * stride_s + batch * stride_b;
  const int dst_base = pos * o_stride_s + batch * o_stride_b;
  fused_rope_block_forward(src, freqs, dst, pos, src_base, dst_base, h, d, d2, stride_h,
                           stride_d, o_stride_h, o_stride_d);
}
// Backward RoPE kernel for a strided [s, b, h, d]-indexed gradient tensor.
// Each block handles one (sequence position, batch entry) pair:
// blockIdx.x is the sequence position and blockIdx.y is the batch index.
// The position doubles as the row of `freqs` that is used.
template <typename scalar_t>
__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
                                           const int h, const int d, const int d2,
                                           const int stride_s, const int stride_b,
                                           const int stride_h, const int stride_d,
                                           const int o_stride_s, const int o_stride_b,
                                           const int o_stride_h, const int o_stride_d) {
  const int pos = blockIdx.x;
  const int batch = blockIdx.y;
  const int src_base = pos * stride_s + batch * stride_b;
  const int dst_base = pos * o_stride_s + batch * o_stride_b;
  fused_rope_block_backward(src, freqs, dst, pos, src_base, dst_base, h, d, d2, stride_h,
                            stride_d, o_stride_h, o_stride_d);
}
// Forward RoPE kernel for packed (THD) sequences described by cu_seqlens.
//
// blockIdx.y selects the sequence and blockIdx.x the token position within it. The local
// token range of sequence k is [cu_seqlens[k] / cp_size, cu_seqlens[k + 1] / cp_size).
// Blocks past the end of a shorter sequence exit immediately.
//
// With cp_size > 1 the local tokens are split into two halves of length len / 2.
// The first half takes frequency positions starting at cp_rank * len / 2.
// The second half takes positions starting at (2 * cp_size - 1 - cp_rank) * len / 2.
// The local length is asserted to be even (device assert; inactive under NDEBUG).
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
                                              const float *freqs, scalar_t *dst, const int cp_size,
                                              const int cp_rank, const int h, const int d,
                                              const int d2, const int stride_t, const int stride_h,
                                              const int stride_d, const int o_stride_t,
                                              const int o_stride_h, const int o_stride_d) {
  const int local_pos = blockIdx.x;
  const int seq_idx = blockIdx.y;

  const int seq_begin = cu_seqlens[seq_idx] / cp_size;
  const int seq_end = cu_seqlens[seq_idx + 1] / cp_size;
  const int token = seq_begin + local_pos;
  if (token >= seq_end) return;  // padding block for a shorter sequence

  int freq_pos = local_pos;
  if (cp_size > 1) {
    const int local_len = seq_end - seq_begin;
    assert(local_len % 2 == 0);
    const int half_len = local_len / 2;
    freq_pos = (local_pos < half_len)
                   ? local_pos + cp_rank * local_len / 2
                   : local_len * cp_size - (cp_rank + 1) * local_len / 2 + local_pos - half_len;
  }

  fused_rope_block_forward(src, freqs, dst, freq_pos, token * stride_t, token * o_stride_t, h, d,
                           d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
// Backward RoPE kernel for packed (THD) sequences described by cu_seqlens.
//
// blockIdx.y selects the sequence and blockIdx.x the token position within it. The local
// token range of sequence k is [cu_seqlens[k] / cp_size, cu_seqlens[k + 1] / cp_size).
// Blocks past the end of a shorter sequence exit immediately.
//
// With cp_size > 1 the local tokens are split into two halves of length len / 2.
// The first half takes frequency positions starting at cp_rank * len / 2.
// The second half takes positions starting at (2 * cp_size - 1 - cp_rank) * len / 2.
// The local length is asserted to be even (device assert; inactive under NDEBUG).
template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens,
                                               const float *freqs, scalar_t *dst, const int cp_size,
                                               const int cp_rank, const int h, const int d,
                                               const int d2, const int stride_t, const int stride_h,
                                               const int stride_d, const int o_stride_t,
                                               const int o_stride_h, const int o_stride_d) {
  const int local_pos = blockIdx.x;
  const int seq_idx = blockIdx.y;

  const int seq_begin = cu_seqlens[seq_idx] / cp_size;
  const int seq_end = cu_seqlens[seq_idx + 1] / cp_size;
  const int token = seq_begin + local_pos;
  if (token >= seq_end) return;  // padding block for a shorter sequence

  int freq_pos = local_pos;
  if (cp_size > 1) {
    const int local_len = seq_end - seq_begin;
    assert(local_len % 2 == 0);
    const int half_len = local_len / 2;
    freq_pos = (local_pos < half_len)
                   ? local_pos + cp_rank * local_len / 2
                   : local_len * cp_size - (cp_rank + 1) * local_len / 2 + local_pos - half_len;
  }

  fused_rope_block_backward(src, freqs, dst, freq_pos, token * stride_t, token * o_stride_t, h,
                            d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
// Host-side launcher for fused_rope_forward_kernel.
//
// Launches an (s, b) grid, one block per (sequence position, batch entry), with
// THREADS_PER_WARP x warps_per_block threads per block on `stream`, then checks for launch
// errors. Strides are element offsets into scalar_t arrays.
//
// Empty inputs (s == 0 or b == 0) return immediately. A zero grid dimension is an invalid
// launch configuration, so launching anyway would turn a legitimate no-op into a CUDA error.
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output,
                                 const int s, const int b, const int h, const int d, const int d2,
                                 const int stride_s, const int stride_b, const int stride_h,
                                 const int stride_d, const int o_stride_s, const int o_stride_b,
                                 const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
  if (s == 0 || b == 0) {
    return;  // nothing to rotate
  }
  // threadIdx.y strides over heads, so more warps means fewer head iterations per thread.
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(s, b);
  const dim3 threads(THREADS_PER_WARP, warps_per_block);
  fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
      input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
      o_stride_b, o_stride_h, o_stride_d);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
// Host-side launcher for fused_rope_backward_kernel.
//
// Launches an (s, b) grid, one block per (sequence position, batch entry), with
// THREADS_PER_WARP x warps_per_block threads per block on `stream`, then checks for launch
// errors. Strides are element offsets into scalar_t arrays.
//
// Empty inputs (s == 0 or b == 0) return immediately. A zero grid dimension is an invalid
// launch configuration, so launching anyway would turn a legitimate no-op into a CUDA error.
template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs,
                                  scalar_t *input_grads, const int s, const int b, const int h,
                                  const int d, const int d2, const int stride_s, const int stride_b,
                                  const int stride_h, const int stride_d, const int o_stride_s,
                                  const int o_stride_b, const int o_stride_h, const int o_stride_d,
                                  cudaStream_t stream) {
  if (s == 0 || b == 0) {
    return;  // nothing to rotate
  }
  // threadIdx.y strides over heads, so more warps means fewer head iterations per thread.
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(s, b);
  const dim3 threads(THREADS_PER_WARP, warps_per_block);
  fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
      output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d,
      o_stride_s, o_stride_b, o_stride_h, o_stride_d);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
// Host-side launcher for fused_rope_thd_forward_kernel.
//
// Launches a (max_s, b) grid with one block per (token position, sequence). Blocks beyond a
// sequence's local length exit inside the kernel. Each block has
// THREADS_PER_WARP x warps_per_block threads. Launch errors are checked after the launch.
//
// Empty inputs (max_s == 0 or b == 0) return immediately. A zero grid dimension is an invalid
// launch configuration, so launching anyway would turn a legitimate no-op into a CUDA error.
template <typename scalar_t>
void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
                                     const float *freqs, scalar_t *output, const int cp_size,
                                     const int cp_rank, const int max_s, const int b, const int h,
                                     const int d, const int d2, const int stride_t,
                                     const int stride_h, const int stride_d, const int o_stride_t,
                                     const int o_stride_h, const int o_stride_d,
                                     cudaStream_t stream) {
  if (max_s == 0 || b == 0) {
    return;  // nothing to rotate
  }
  // threadIdx.y strides over heads, so more warps means fewer head iterations per thread.
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(max_s, b);
  const dim3 threads(THREADS_PER_WARP, warps_per_block);
  fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
      input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
      o_stride_t, o_stride_h, o_stride_d);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
// Host-side launcher for fused_rope_thd_backward_kernel.
//
// Launches a (max_s, b) grid with one block per (token position, sequence). Blocks beyond a
// sequence's local length exit inside the kernel. Each block has
// THREADS_PER_WARP x warps_per_block threads. Launch errors are checked after the launch.
//
// Empty inputs (max_s == 0 or b == 0) return immediately. A zero grid dimension is an invalid
// launch configuration, so launching anyway would turn a legitimate no-op into a CUDA error.
template <typename scalar_t>
void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
                                      const float *freqs, scalar_t *input_grads, const int cp_size,
                                      const int cp_rank, const int max_s, const int b, const int h,
                                      const int d, const int d2, const int stride_t,
                                      const int stride_h, const int stride_d, const int o_stride_t,
                                      const int o_stride_h, const int o_stride_d,
                                      cudaStream_t stream) {
  if (max_s == 0 || b == 0) {
    return;  // nothing to rotate
  }
  // threadIdx.y strides over heads, so more warps means fewer head iterations per thread.
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(max_s, b);
  const dim3 threads(THREADS_PER_WARP, warps_per_block);
  fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
      output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h,
      stride_d, o_stride_t, o_stride_h, o_stride_d);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
// Dispatches the forward RoPE launcher on input's dtype.
// `freqs` is always reinterpreted as fp32, independent of the input dtype; its dtype is not
// checked here.
void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s,
                        const int b, const int h, const int d, const int d2, const int stride_s,
                        const int stride_b, const int stride_h, const int stride_d,
                        const int o_stride_s, const int o_stride_b, const int o_stride_h,
                        const int o_stride_d, cudaStream_t stream) {
  const float *freqs_ptr = reinterpret_cast<const float *>(freqs.data.dptr);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, scalar_t,
      fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr), freqs_ptr,
                                  reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
                                  stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
                                  o_stride_h, o_stride_d, stream););
}
// Dispatches the backward RoPE launcher on output_grads' dtype.
// `freqs` is always reinterpreted as fp32, independent of the gradient dtype; its dtype is not
// checked here.
void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads,
                         const int s, const int b, const int h, const int d, const int d2,
                         const int stride_s, const int stride_b, const int stride_h,
                         const int stride_d, const int o_stride_s, const int o_stride_b,
                         const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
  const float *freqs_ptr = reinterpret_cast<const float *>(freqs.data.dptr);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      output_grads.data.dtype, scalar_t,
      fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
                                   freqs_ptr, reinterpret_cast<scalar_t *>(input_grads->data.dptr),
                                   s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
                                   o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream););
}
// Dispatches the THD forward RoPE launcher on input's dtype.
// cu_seqlens is reinterpreted as int32 and freqs as fp32, independent of the input dtype.
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
                            Tensor *output, const int cp_size, const int cp_rank, const int max_s,
                            const int b, const int h, const int d, const int d2, const int stride_t,
                            const int stride_h, const int stride_d, const int o_stride_t,
                            const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
  const int *cu_seqlens_ptr = reinterpret_cast<const int *>(cu_seqlens.data.dptr);
  const float *freqs_ptr = reinterpret_cast<const float *>(freqs.data.dptr);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, scalar_t,
      fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
                                      cu_seqlens_ptr, freqs_ptr,
                                      reinterpret_cast<scalar_t *>(output->data.dptr), cp_size,
                                      cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d,
                                      o_stride_t, o_stride_h, o_stride_d, stream););
}
// Dispatches the THD backward RoPE launcher on output_grads' dtype.
// cu_seqlens is reinterpreted as int32 and freqs as fp32, independent of the gradient dtype.
void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
                             const Tensor &freqs, Tensor *input_grads, const int cp_size,
                             const int cp_rank, const int max_s, const int b, const int h,
                             const int d, const int d2, const int stride_t, const int stride_h,
                             const int stride_d, const int o_stride_t, const int o_stride_h,
                             const int o_stride_d, cudaStream_t stream) {
  const int *cu_seqlens_ptr = reinterpret_cast<const int *>(cu_seqlens.data.dptr);
  const float *freqs_ptr = reinterpret_cast<const float *>(freqs.data.dptr);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      output_grads.data.dtype, scalar_t,
      fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
                                       cu_seqlens_ptr, freqs_ptr,
                                       reinterpret_cast<scalar_t *>(input_grads->data.dptr),
                                       cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h,
                                       stride_d, o_stride_t, o_stride_h, o_stride_d, stream););
}
} // end namespace transformer_engine
// C API entry point: unwraps the opaque NVTETensor handles and runs the forward RoPE.
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
                             const int s, const int b, const int h, const int d, const int d2,
                             const int stride_s, const int stride_b, const int stride_h,
                             const int stride_d, const int o_stride_s, const int o_stride_b,
                             const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_rope_forward);
  using namespace transformer_engine;
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  const Tensor &freqs_tensor = *reinterpret_cast<const Tensor *>(freqs);
  Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
  fused_rope_forward(input_tensor, freqs_tensor, output_tensor, s, b, h, d, d2, stride_s,
                     stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h,
                     o_stride_d, stream);
}
// C API entry point: unwraps the opaque NVTETensor handles and runs the backward RoPE.
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
                              NVTETensor input_grads, const int s, const int b, const int h,
                              const int d, const int d2, const int stride_s, const int stride_b,
                              const int stride_h, const int stride_d, const int o_stride_s,
                              const int o_stride_b, const int o_stride_h, const int o_stride_d,
                              cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_rope_backward);
  using namespace transformer_engine;
  const Tensor &grad_out_tensor = *reinterpret_cast<const Tensor *>(output_grads);
  const Tensor &freqs_tensor = *reinterpret_cast<const Tensor *>(freqs);
  Tensor *grad_in_tensor = reinterpret_cast<Tensor *>(input_grads);
  fused_rope_backward(grad_out_tensor, freqs_tensor, grad_in_tensor, s, b, h, d, d2, stride_s,
                      stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h,
                      o_stride_d, stream);
}
// C API entry point: unwraps the opaque NVTETensor handles and runs the THD forward RoPE.
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
                                 const NVTETensor freqs, NVTETensor output, const int cp_size,
                                 const int cp_rank, const int max_s, const int b, const int h,
                                 const int d, const int d2, const int stride_t, const int stride_h,
                                 const int stride_d, const int o_stride_t, const int o_stride_h,
                                 const int o_stride_d, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_rope_thd_forward);
  using namespace transformer_engine;
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  const Tensor &cu_seqlens_tensor = *reinterpret_cast<const Tensor *>(cu_seqlens);
  const Tensor &freqs_tensor = *reinterpret_cast<const Tensor *>(freqs);
  Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
  fused_rope_thd_forward(input_tensor, cu_seqlens_tensor, freqs_tensor, output_tensor, cp_size,
                         cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
                         o_stride_h, o_stride_d, stream);
}
// C API entry point: unwraps the opaque NVTETensor handles and runs the THD backward RoPE.
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
                                  const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
                                  const int cp_rank, const int max_s, const int b, const int h,
                                  const int d, const int d2, const int stride_t, const int stride_h,
                                  const int stride_d, const int o_stride_t, const int o_stride_h,
                                  const int o_stride_d, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_rope_thd_backward);
  using namespace transformer_engine;
  const Tensor &grad_out_tensor = *reinterpret_cast<const Tensor *>(output_grads);
  const Tensor &cu_seqlens_tensor = *reinterpret_cast<const Tensor *>(cu_seqlens);
  const Tensor &freqs_tensor = *reinterpret_cast<const Tensor *>(freqs);
  Tensor *grad_in_tensor = reinterpret_cast<Tensor *>(input_grads);
  fused_rope_thd_backward(grad_out_tensor, cu_seqlens_tensor, freqs_tensor, grad_in_tensor,
                          cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d,
                          o_stride_t, o_stride_h, o_stride_d, stream);
}