From 3f1a261a23c9976922163e833e5dd381b7987683 Mon Sep 17 00:00:00 2001
From: dujiangsu
Date: Tue, 3 May 2022 21:38:27 +0800
Subject: [PATCH 1/2] dlock without timeout

---
 energon/engine/bert_pipeline_wrapper.py | 4 ++--
 energon/engine/gpt_pipeline_wrapper.py  | 4 ++--
 requirements.txt                        | 1 +
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/energon/engine/bert_pipeline_wrapper.py b/energon/engine/bert_pipeline_wrapper.py
index 16858da..eb50278 100644
--- a/energon/engine/bert_pipeline_wrapper.py
+++ b/energon/engine/bert_pipeline_wrapper.py
@@ -90,7 +90,7 @@ def run(self, key, inputs):
         self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
 
         # different threads ask for a single lock
-        self.lock.acquire(timeout=3)
+        self.lock.acquire()
 
         sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
         self.key.addOne()
@@ -128,4 +128,4 @@ def run(self, key, inputs):
             self.lock.release()
 
             return None
-    
\ No newline at end of file
+    
diff --git a/energon/engine/gpt_pipeline_wrapper.py b/energon/engine/gpt_pipeline_wrapper.py
index cc1f208..c1e8bb5 100644
--- a/energon/engine/gpt_pipeline_wrapper.py
+++ b/energon/engine/gpt_pipeline_wrapper.py
@@ -101,7 +101,7 @@ def run(self, key, inputs):
         self.fill_meta_tensor(inputs, pipe_meta)
         self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
 
-        self.lock.acquire(timeout=3)
+        self.lock.acquire()
 
         sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
         self.key.addOne()
@@ -136,4 +136,4 @@ def run(self, key, inputs):
             self.lock.release()
 
             return None
-    
\ No newline at end of file
+    
diff --git a/requirements.txt b/requirements.txt
index b483769..0c98398 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ psutil
 packaging
 fastapi
 uvicorn==0.14
+typer
\ No newline at end of file
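Background for the change above: if self.lock is a standard threading.Lock, acquire(timeout=3) can give up after three seconds and return False, and the unconditional self.lock.release() later in run() would then raise "RuntimeError: release unlocked lock". A minimal sketch of the two behaviours with a plain threading.Lock (illustrative only, not Energon's wrapper code):

import threading

lock = threading.Lock()

def run_with_timeout():
    # acquire(timeout=3) may return False after 3 s without holding the lock;
    # an unconditional release() afterwards would raise RuntimeError.
    acquired = lock.acquire(timeout=3)
    if not acquired:
        return None
    try:
        pass  # critical section
    finally:
        lock.release()

def run_blocking():
    # Plain acquire() blocks until the lock is held, so the matching
    # release() is always legal; this is the behaviour the patch adopts.
    lock.acquire()
    try:
        pass  # critical section
    finally:
        lock.release()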
From 0ebe9b5311d83b54ad7e3b443d96226af7c5dfe4 Mon Sep 17 00:00:00 2001
From: dujiangsu
Date: Wed, 4 May 2022 14:14:30 +0800
Subject: [PATCH 2/2] padding rebuild kernels from faster

---
 energon/kernel/__init__.py                  |   4 +-
 energon/kernel/cuda_native/__init__.py      |   1 +
 .../csrc/transpose_pad_fusion_kernel.cu     | 228 +++++++++++++++++-
 .../csrc/transpose_pad_fusion_wrapper.cpp   | 141 ++++++++++-
 energon/kernel/cuda_native/transpose_pad.py |  37 +++
 requirements.txt                            |   4 +-
 tests/test_kernel/test_ft_transpose_pad.py  |  77 ++++++
 7 files changed, 480 insertions(+), 12 deletions(-)
 create mode 100644 tests/test_kernel/test_ft_transpose_pad.py

diff --git a/energon/kernel/__init__.py b/energon/kernel/__init__.py
index ddcc182..ad2057f 100644
--- a/energon/kernel/__init__.py
+++ b/energon/kernel/__init__.py
@@ -1,5 +1,7 @@
 from .cuda_native import transpose_pad, transpose_depad, depad, scale_mask_softmax
+from .cuda_native import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding
 
 __all__ = [
-    "transpose_pad", "transpose_depad", "depad", "scale_mask_softmax"
+    "transpose_pad", "transpose_depad", "depad", "scale_mask_softmax",
+    "ft_build_padding_offsets", "ft_remove_padding", "ft_rebuild_padding", "ft_transpose_remove_padding", "ft_transpose_rebuild_padding"
 ]
diff --git a/energon/kernel/cuda_native/__init__.py b/energon/kernel/cuda_native/__init__.py
index 1706505..37a1076 100644
--- a/energon/kernel/cuda_native/__init__.py
+++ b/energon/kernel/cuda_native/__init__.py
@@ -1,3 +1,4 @@
 from .transpose_pad import transpose_pad, transpose_depad, depad
+from .transpose_pad import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding
 from .scale_mask_softmax import scale_mask_softmax
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
\ No newline at end of file
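With the re-exports above, the FasterTransformer-style helpers (ft_ prefix, as opposed to the existing turbo-derived transpose_pad/transpose_depad) become importable from the package root. A short illustrative import, mirroring the __all__ list and nothing beyond what the diff exports:

from energon.kernel import (
    ft_build_padding_offsets,
    ft_remove_padding,
    ft_rebuild_padding,
    ft_transpose_remove_padding,
    ft_transpose_rebuild_padding,
)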
diff --git a/energon/kernel/cuda_native/csrc/transpose_pad_fusion_kernel.cu b/energon/kernel/cuda_native/csrc/transpose_pad_fusion_kernel.cu
index 341ae31..959c1b3 100644
--- a/energon/kernel/cuda_native/csrc/transpose_pad_fusion_kernel.cu
+++ b/energon/kernel/cuda_native/csrc/transpose_pad_fusion_kernel.cu
@@ -2,7 +2,9 @@
 // transpose and padding/depadding fusion to reduce the memory move.
 
 #include
+#include
 
 
+// from turbo (TurboTransformers)
 __global__ void transpose_depad_kernel(const float* src, const int batch_size,
                                        const int seq_len, const int64_t* seq_lens,
@@ -85,4 +87,228 @@ void transpose_pad(const float* src,
   dim3 dimBlock(size_per_head);
 
   transpose_pad_kernel<<<dimGrid, dimBlock>>>(src, batch_size, seq_len, seq_lens, head_num, size_per_head, dst);
-}
\ No newline at end of file
+}
+
+
+// from faster (FasterTransformer)
+
+/* create offsets */
+
+__global__ void build_sequence_length_padding_offset(const int* sequence_length,
+  const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
+{
+  // do cumulated sum
+  int total_seq_len = 0;
+  int cum_offset = 0;
+  int index = 0;
+  for(int i = 0; i < batch_size; i++)
+  {
+    const int seq_len = sequence_length[i];
+    for(int j = 0; j < seq_len; j++)
+    {
+      tmp_mask_offset[index] = cum_offset;
+      index++;
+    }
+    cum_offset += max_seq_len - seq_len;
+    total_seq_len += seq_len;
+  }
+  valid_word_num[0] = total_seq_len;
+}
+
+void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length,
+  const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
+{
+  build_sequence_length_padding_offset<<<1, 1>>>(sequence_length,
+    batch_size, max_seq_len, valid_word_num, tmp_mask_offset);
+}
+
+
+/* remove padding from embedding layer to transformer blocks */
+template <typename T>
+__global__ void remove_sequence_length_padding(const T* src, T* tgt,
+                                               const int* tmp_mask_offset,
+                                               int* mask_offset,
+                                               const int n)
+{
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  mask_offset[bid] = tmp_mask_offset[bid];
+  const int src_seq_id = bid + mask_offset[bid];
+  const int tgt_seq_id = bid;
+
+  for(int i = tid; i < n; i += blockDim.x)
+  {
+    tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
+  }
+}
+
+template <typename T>
+void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
+                                                   const int* tmp_mask_offset,
+                                                   int* mask_offset,
+                                                   const int m, const int n)
+{
+  // src: [batch_size*max_seq_len, hidden_dim]
+  // tgt: [valid_word_num, hidden_dim]
+  remove_sequence_length_padding<<<m, 256>>>(src, tgt, tmp_mask_offset, mask_offset, n);
+}
+
+template void remove_sequence_length_padding_kernelLauncher(const float* src, float* tgt,
+                                                            const int* tmp_mask_offset,
+                                                            int* mask_offset, const int m,
+                                                            const int n);
+
+template void remove_sequence_length_padding_kernelLauncher(const half* src, half* tgt,
+                                                            const int* tmp_mask_offset,
+                                                            int* mask_offset, const int m,
+                                                            const int n);
+
+
+
+/* add padding from transformer blocks to final output*/
+
+template <typename T>
+__global__ void rebuild_sequence_length_padding(const T* src, T* tgt,
+                                                const int* mask_offset,
+                                                const int n)
+{
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  const int tgt_seq_id = bid + mask_offset[bid];
+  const int src_seq_id = bid;
+
+  for(int i = tid; i < n; i += blockDim.x)
+  {
+    tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
+  }
+}
+
+
+template <typename T>
+void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
+                                                    const int* mask_offset, const int m,
+                                                    const int n)
+{
+  // src: [valid_word_num, hidden_dim]
+  // tgt: [batch_size*max_seq_len, hidden_dim]
+  rebuild_sequence_length_padding<<<m, 256>>>(src, tgt, mask_offset, n);
+}
+
+
+template void rebuild_sequence_length_padding_kernelLauncher(const float* src, float* tgt,
+                                                             const int* mask_offset, const int m,
+                                                             const int n);
+
+
+template void rebuild_sequence_length_padding_kernelLauncher(const half* src, half* tgt,
+                                                             const int* mask_offset, const int m,
+                                                             const int n);
+
+
+/* FT transpose and rebuild padding */
+
+template <typename T>
+__global__
+void transpose_rebuild_padding(T* Q, T* K, T* V, T* q_buf_, T* k_buf_, T* v_buf_,
+  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int* mask_offset)
+{
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  const int bdim = blockDim.x;
+
+  const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
+  const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
+  const int tgt_head_id = tid / size_per_head;
+  const int tgt_hidden_id = tid % size_per_head;
+
+  const int src_id = bid * bdim + tid;
+  const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + \
+                     tgt_head_id * seq_len * size_per_head + \
+                     tgt_seq_id * size_per_head + \
+                     tgt_hidden_id;
+
+  q_buf_[tgt_id] = Q[src_id]; // + bias_Q[tid];
+  k_buf_[tgt_id] = K[src_id]; // + bias_K[tid];
+  v_buf_[tgt_id] = V[src_id]; // + bias_V[tid];
+}
+
+template <typename T>
+void transpose_rebuild_padding_kernelLauncher(T* Q, T* K, T* V, T* q_buf, T* k_buf, T* v_buf,
+  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num,
+  const int* mask_offset)
+{
+  const int k = head_num*size_per_head;
+
+  if(std::is_same<T, float>::value)
+  {
+    transpose_rebuild_padding<<<valid_word_num, k>>>(Q, K, V, q_buf, k_buf, v_buf,
+      batch_size, seq_len, head_num, size_per_head, mask_offset);
+  }
+  else
+  {
+    transpose_rebuild_padding<<<valid_word_num, k / 2>>>((half2*)Q,
+      (half2*)K, (half2*)V, (half2*)q_buf, (half2*)k_buf, (half2*)v_buf,
+      batch_size, seq_len, head_num, size_per_head / 2, mask_offset);
+  }
+}
+
+template
+void transpose_rebuild_padding_kernelLauncher(float* Q, float* K, float* V, float* q_buf, float* k_buf, float* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset);
+
+template
+void transpose_rebuild_padding_kernelLauncher(half* Q, half* K, half* V, half* q_buf, half* k_buf, half* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset);
+
+
+/* FT transpose and remove padding */
+template <typename T>
+__global__
+void transpose_remove_padding(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head,
+  const int* mask_offset)
+{
+  // TODO: optimize this kernel?
+  // do remove_sequence_length_padding
+  const int tid = threadIdx.x; // head_num * size_per_head
+  const int bid = blockIdx.x;  // batch * seq_len or valid_word_num
+
+  const int src_batch_id = (bid + mask_offset[bid]) / seq_len;
+  const int src_seq_id = (bid + mask_offset[bid]) % seq_len;
+
+  const int dst_seq_id = bid;
+
+  const int head_id = tid / size_per_head;
+  const int hidden_id = tid % size_per_head;
+  dst[dst_seq_id * head_num * size_per_head + tid] = src[ src_batch_id * head_num * seq_len * size_per_head +
+    head_id * seq_len * size_per_head + src_seq_id * size_per_head + hidden_id];
+}
+
+template <typename T>
+void transpose_remove_padding_kernelLauncher(T* src, T* dst, const int valid_word_num,
+                                             const int batch_size, const int seq_len,
+                                             const int head_num, const int size_per_head,
+                                             const int* mask_offset)
+{
+  int k = head_num * size_per_head;
+  if (std::is_same<T, float>::value)
+  {
+    transpose_remove_padding<<<valid_word_num, k>>>(src, dst,
+      batch_size, seq_len, head_num, size_per_head, mask_offset);
+  }
+  else
+  {
+    transpose_remove_padding<<<valid_word_num, k / 2>>>(
+      (half2*)src, (half2*)dst,
+      batch_size, seq_len, head_num, size_per_head / 2, mask_offset);
+  }
+}
+
+template
+void transpose_remove_padding_kernelLauncher(float* src, float* dst, const int valid_word_num,
+                                             const int batch_size, const int seq_len,
+                                             const int head_num, const int size_per_head,
+                                             const int* mask_offset);
+
+template
+void transpose_remove_padding_kernelLauncher(half* src, half* dst, const int valid_word_num,
+                                             const int batch_size, const int seq_len,
+                                             const int head_num, const int size_per_head,
+                                             const int* mask_offset);
\ No newline at end of file
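The offset-building kernel above runs in a single thread and simply accumulates, for every valid token, how much padding precedes it, plus the total number of valid tokens. A pure-Python rendering of the same computation (illustrative only, not part of the extension):

def build_padding_offsets(sequence_length, max_seq_len):
    # What build_sequence_length_padding_offset computes: for each valid
    # (unpadded) token, the cumulative padding that precedes it, and the
    # total number of valid tokens.
    tmp_mask_offset = []
    cum_offset = 0
    for seq_len in sequence_length:
        tmp_mask_offset.extend([cum_offset] * seq_len)
        cum_offset += max_seq_len - seq_len
    valid_word_num = len(tmp_mask_offset)
    return valid_word_num, tmp_mask_offset

# Example: lengths [2, 3] padded to 4 give offsets [0, 0, 2, 2, 2]
# and valid_word_num == 5.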
diff --git a/energon/kernel/cuda_native/csrc/transpose_pad_fusion_wrapper.cpp b/energon/kernel/cuda_native/csrc/transpose_pad_fusion_wrapper.cpp
index 12e2fbe..3f5c352 100644
--- a/energon/kernel/cuda_native/csrc/transpose_pad_fusion_wrapper.cpp
+++ b/energon/kernel/cuda_native/csrc/transpose_pad_fusion_wrapper.cpp
@@ -3,7 +3,7 @@
 // #include "ATen/cuda/CUDAContext.h"
 #include
-
+#include
 #include
 
 
 void transpose_pad(const float* src,
@@ -46,12 +46,6 @@ torch::Tensor transpose_pad_wrapper(torch::Tensor src,
 }
 
 
-// const float* src, const int batch_size,
-// const int max_seq_len,
-// const int64_t* seq_len_list,
-// const int head_num, const int size_per_head,
-// float* dst
-
 torch::Tensor transpose_depad_wrapper(torch::Tensor src,
                                       int batch_size,
                                       int sum_seq,
@@ -74,7 +68,136 @@ torch::Tensor transpose_depad_wrapper(torch::Tensor src,
     return dst;
 }
 
+// from faster (FasterTransformer)
+
+/* build offsets */
+
+void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length,
+  const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset);
+
+void ft_build_padding_offsets_wrapper(torch::Tensor sequence_length,
+                                      int batch_size,
+                                      int max_seq_len,
+                                      torch::Tensor valid_word_num,
+                                      torch::Tensor tmp_mask_offset){
+    CHECK_INPUT(sequence_length);
+    CHECK_INPUT(valid_word_num);
+    CHECK_INPUT(tmp_mask_offset);
+
+    build_sequence_length_padding_offset_kernelLauncher(sequence_length.data_ptr<int>(), batch_size, max_seq_len, valid_word_num.data_ptr<int>(),
+        tmp_mask_offset.data_ptr<int>());
+}
+
+/* remove padding from embedding layer to transformer blocks */
+template <typename T>
+void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
+                                                   const int* tmp_mask_offset,
+                                                   int* mask_offset, const int m,
+                                                   const int n);
+
+torch::Tensor ft_remove_padding_wrapper(torch::Tensor src, torch::Tensor tmp_mask_offset, torch::Tensor mask_offset, int valid_word_num, int hidden_dim){
+    CHECK_INPUT(src);
+    CHECK_INPUT(tmp_mask_offset);
+    CHECK_INPUT(mask_offset);
+
+    auto options = torch::TensorOptions().dtype(src.dtype()).device(torch::kCUDA).requires_grad(false);
+    auto tgt = torch::zeros({1, valid_word_num, hidden_dim}, options);
+
+    if(src.dtype() == torch::kFloat32){
+
+        remove_sequence_length_padding_kernelLauncher((float*)src.data_ptr(), (float*)tgt.data_ptr(), (int*)tmp_mask_offset.data_ptr(), (int*)mask_offset.data_ptr(), valid_word_num, hidden_dim);
+    }
+    else
+    {
+        remove_sequence_length_padding_kernelLauncher((half*)src.data_ptr(), (half*)tgt.data_ptr(), (int*)tmp_mask_offset.data_ptr(), (int*)mask_offset.data_ptr(), valid_word_num, hidden_dim);
+    }
+    return tgt;
+}
+
+/* add padding from transformer blocks to final output*/
+template <typename T>
+void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
+                                                    const int* mask_offset, const int m,
+                                                    const int n);
+torch::Tensor ft_rebuild_padding_wrapper(torch::Tensor src, torch::Tensor mask_offset, int valid_word_num, int hidden_dim, int batch_size, int max_seq_len){
+    CHECK_INPUT(src);
+    CHECK_INPUT(mask_offset);
+
+    auto options = torch::TensorOptions().dtype(src.dtype()).device(torch::kCUDA).requires_grad(false);
+    auto tgt = torch::zeros({batch_size, max_seq_len, hidden_dim}, options);
+    // auto tgt = torch::zeros_like(src);
+
+    if(src.dtype() == torch::kFloat32){
+        rebuild_sequence_length_padding_kernelLauncher((float*)src.data_ptr(), (float*)tgt.data_ptr(), (int*)mask_offset.data_ptr(), valid_word_num, hidden_dim);
+    }
+    else
+    {
+        rebuild_sequence_length_padding_kernelLauncher((half*)src.data_ptr(), (half*)tgt.data_ptr(), (int*)mask_offset.data_ptr(), valid_word_num, hidden_dim);
+    }
+
+    return tgt;
+}
+
+/* FT transpose and rebuild padding */
+template <typename T>
+void transpose_rebuild_padding_kernelLauncher(T* Q, T* K, T* V, T* q_buf, T* k_buf, T* v_buf,
+  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num,
+  const int* mask_offset);
+
+
+void ft_transpose_rebuild_padding_wrapper(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor q_buf, torch::Tensor k_buf, torch::Tensor v_buf,
+  int batch_size, int seq_len, int head_num, int size_per_head, int valid_word_num, torch::Tensor mask_offset){
+    CHECK_INPUT(Q);
+    CHECK_INPUT(K);
+    CHECK_INPUT(V);
+    CHECK_INPUT(q_buf);
+    CHECK_INPUT(k_buf);
+    CHECK_INPUT(v_buf);
+    CHECK_INPUT(mask_offset);
+
+    if(Q.dtype() == torch::kFloat32){
+        transpose_rebuild_padding_kernelLauncher((float*)Q.data_ptr(), (float*)K.data_ptr(), (float*)V.data_ptr(), (float*)q_buf.data_ptr(), (float*)k_buf.data_ptr(), (float*)v_buf.data_ptr(),
+            batch_size, seq_len, head_num, size_per_head, valid_word_num, (int*)mask_offset.data_ptr());
+    }else
+    {
+        transpose_rebuild_padding_kernelLauncher((half*)Q.data_ptr(), (half*)K.data_ptr(), (half*)V.data_ptr(), (half*)q_buf.data_ptr(), (half*)k_buf.data_ptr(), (half*)v_buf.data_ptr(),
+            batch_size, seq_len, head_num, size_per_head, valid_word_num, (int*)mask_offset.data_ptr());
+    }
+}
+
+
+/* FT transpose and remove padding */
+template <typename T>
+void transpose_remove_padding_kernelLauncher(T* src, T* dst, const int valid_word_num,
+                                             const int batch_size, const int seq_len,
+                                             const int head_num, const int size_per_head,
+                                             const int* mask_offset);
+
+torch::Tensor ft_transpose_remove_padding_wrapper(torch::Tensor src, int valid_word_num, int batch_size, int seq_len,
+  int head_num, int size_per_head, torch::Tensor mask_offset){
+    CHECK_INPUT(src);
+    CHECK_INPUT(mask_offset);
+
+    auto options = torch::TensorOptions().dtype(src.dtype()).device(torch::kCUDA).requires_grad(false);
+    auto tgt = torch::zeros({1, valid_word_num, head_num*size_per_head}, options);
+
+    if(src.dtype() == torch::kFloat32){
+        transpose_remove_padding_kernelLauncher((float*)src.data_ptr(), (float*)tgt.data_ptr(), valid_word_num, batch_size, seq_len, head_num, size_per_head, (int*)mask_offset.data_ptr());
+    }else
+    {
+        transpose_remove_padding_kernelLauncher((half*)src.data_ptr(), (half*)tgt.data_ptr(), valid_word_num, batch_size, seq_len, head_num, size_per_head, (int*)mask_offset.data_ptr());
+    }
+
+    return tgt;
+}
+
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("transpose_pad_wrapper", &transpose_pad_wrapper, "Transpose and Padding");
-  m.def("transpose_depad_wrapper", &transpose_depad_wrapper, "Transpose and Depadding");
+  m.def("transpose_pad_wrapper", &transpose_pad_wrapper, "Turbo Transpose and Padding");
+  m.def("transpose_depad_wrapper", &transpose_depad_wrapper, "Turbo Transpose and Depadding");
+  m.def("ft_build_padding_offsets_wrapper", &ft_build_padding_offsets_wrapper, "Faster build offsets");
+  m.def("ft_remove_padding_wrapper", &ft_remove_padding_wrapper, "Faster remove padding");
+  m.def("ft_rebuild_padding_wrapper", &ft_rebuild_padding_wrapper, "Faster rebuild padding");
+  m.def("ft_transpose_remove_padding_wrapper", &ft_transpose_remove_padding_wrapper, "Faster transpose and remove padding");
+  m.def("ft_transpose_rebuild_padding_wrapper", &ft_transpose_rebuild_padding_wrapper, "Faster transpose and rebuild padding");
 }
\ No newline at end of file
diff --git a/energon/kernel/cuda_native/transpose_pad.py b/energon/kernel/cuda_native/transpose_pad.py
index 4ac1a79..1dda941 100644
--- a/energon/kernel/cuda_native/transpose_pad.py
+++ b/energon/kernel/cuda_native/transpose_pad.py
@@ -34,4 +34,41 @@ def depad(src, batch_size, seq_lens):
 
     return dst
 
 
+# From FasterTransformer
+def ft_build_padding_offsets(seq_lens, batch_size, max_seq_len, valid_word_num, tmp_mask_offset):
+    seq_lens = seq_lens.contiguous()
+    # tmp_mask_offset = tmp_mask_offset.contiguous()
+
+    energon_transpose_pad.ft_build_padding_offsets_wrapper(seq_lens, batch_size, max_seq_len, valid_word_num, tmp_mask_offset)
+
+def ft_remove_padding(src, tmp_mask_offset, mask_offset, valid_word_num, hidden_dim):
+    src = src.contiguous()
+    # tmp_mask_offset = tmp_mask_offset.contiguous()
+    # mask_offset = mask_offset.contiguous()
+
+    dst = energon_transpose_pad.ft_remove_padding_wrapper(src, tmp_mask_offset, mask_offset, valid_word_num, hidden_dim)
+    return dst
+
+def ft_rebuild_padding(src, mask_offset, valid_word_num, hidden_dim, batch_size, max_seq_len):
+    src = src.contiguous()
+    # mask_offset = mask_offset.contiguous()
+
+    dst = energon_transpose_pad.ft_rebuild_padding_wrapper(src, mask_offset, valid_word_num, hidden_dim, batch_size, max_seq_len)
+    return dst
+
+def ft_transpose_rebuild_padding(Q, K, V, q_buf, k_buf, v_buf, batch_size, seq_len, head_num, size_per_head, valid_word_num, mask_offset):
+    Q = Q.contiguous()
+    K = K.contiguous()
+    V = V.contiguous()
+    q_buf = q_buf.contiguous()
+    k_buf = k_buf.contiguous()
+    v_buf = v_buf.contiguous()
+
+    energon_transpose_pad.ft_transpose_rebuild_padding_wrapper(Q, K, V, q_buf, k_buf, v_buf, batch_size, seq_len, head_num, size_per_head, valid_word_num, mask_offset)
+
+def ft_transpose_remove_padding(src, valid_word_num, batch_size, seq_len, head_num, size_per_head, mask_offset):
+    src = src.contiguous()
+
+    dst = energon_transpose_pad.ft_transpose_remove_padding_wrapper(src, valid_word_num, batch_size, seq_len, head_num, size_per_head, mask_offset)
+    return dst
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 0c98398..ba35f20 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,6 @@ psutil
 packaging
 fastapi
 uvicorn==0.14
-typer
\ No newline at end of file
+typer
+redis
+scipy
\ No newline at end of file
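A short usage sketch of the Python helpers above, with illustrative values; the output shapes follow the torch::zeros calls in the C++ wrapper, so treat them as an assumption rather than documented API:

import torch
from energon.kernel import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding

batch_size, max_seq_len, hidden = 2, 4, 8
seq_lens = torch.tensor([2, 3], dtype=torch.int).cuda()
x = torch.rand(batch_size, max_seq_len, hidden).cuda()

tmp_mask_offset = torch.zeros(batch_size, max_seq_len, dtype=torch.int).cuda()
mask_offset = torch.zeros(batch_size, max_seq_len, dtype=torch.int).cuda()
valid_word_num = torch.zeros(1, dtype=torch.int).cuda()

ft_build_padding_offsets(seq_lens, batch_size, max_seq_len, valid_word_num, tmp_mask_offset)
n = valid_word_num[0].item()  # 5 valid tokens for lengths [2, 3]

compact = ft_remove_padding(x, tmp_mask_offset, mask_offset, n, hidden)
# compact: (1, valid_word_num, hidden), padding rows dropped
padded = ft_rebuild_padding(compact, mask_offset, n, hidden, batch_size, max_seq_len)
# padded: (batch_size, max_seq_len, hidden), padding rows restored as zeros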
diff --git a/tests/test_kernel/test_ft_transpose_pad.py b/tests/test_kernel/test_ft_transpose_pad.py
new file mode 100644
index 0000000..2432dd8
--- /dev/null
+++ b/tests/test_kernel/test_ft_transpose_pad.py
@@ -0,0 +1,77 @@
+from energon.kernel import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding
+import torch
+import pytest
+
+
+seq_lens = torch.tensor([24,127,31,65,24,127,31,65], dtype=torch.int).cuda()
+batch_size = 8
+max_padding_size = 128
+head_size = 64
+head_num = 12
+hidden_size = head_num * head_size
+
+
+def test_kernel():
+    hidden_states_q = torch.rand(batch_size, max_padding_size, hidden_size).cuda()
+    hidden_states_k = torch.rand(batch_size, max_padding_size, hidden_size).cuda()
+    hidden_states_v = torch.rand(batch_size, max_padding_size, hidden_size).cuda()
+
+
+    tmp_mask_offset = torch.zeros(batch_size, max_padding_size, dtype=torch.int).cuda()
+    mask_offset = torch.zeros(batch_size, max_padding_size, dtype=torch.int).cuda()
+    valid_word_num = torch.zeros(1, dtype=torch.int).cuda()
+
+    ft_build_padding_offsets(seq_lens, batch_size, max_padding_size, valid_word_num, tmp_mask_offset)
+    q = ft_remove_padding(hidden_states_q, tmp_mask_offset, mask_offset, valid_word_num[0].item(), hidden_size)
+    k = ft_remove_padding(hidden_states_k, tmp_mask_offset, mask_offset, valid_word_num[0].item(), hidden_size)
+    v = ft_remove_padding(hidden_states_v, tmp_mask_offset, mask_offset, valid_word_num[0].item(), hidden_size)
+
+    new_qkv_shape = q.shape[:-1] + (head_num, head_size)
+
+    q = q.view(new_qkv_shape)
+    k = k.view(new_qkv_shape)
+    v = v.view(new_qkv_shape)
+    print(q.size())
+
+    q_buf = torch.zeros(batch_size, head_num, max_padding_size, head_size).cuda()
+    k_buf = torch.zeros(batch_size, head_num, max_padding_size, head_size).cuda()
+    v_buf = torch.zeros(batch_size, head_num, max_padding_size, head_size).cuda()
+
+    ft_transpose_rebuild_padding(q, k, v, q_buf, k_buf, v_buf, batch_size, max_padding_size, head_num, head_size, valid_word_num[0].item(), mask_offset)
+
+    print(q_buf.size())
+
+    q_buf = ft_transpose_remove_padding(v_buf, valid_word_num[0].item(), batch_size, max_padding_size, head_num, head_size, mask_offset)
+
+    print(q_buf.size())
+
+    q_buf = ft_rebuild_padding(q_buf, mask_offset, valid_word_num[0].item(), hidden_size, batch_size, max_padding_size)
+
+    print(q_buf.size())
+
+
+
+
+
+    # ft_transpose_remove_padding()
+
+
+
+
+
+
+
+    # void ft_transpose_rebuild_padding_wrapper(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor q_buf, torch::Tensor k_buf, torch::Tensor v_buf,
+    #                                           int batch_size, int seq_len, int head_num, int size_per_head, int valid_word_num, torch::Tensor mask_offset){
+
+
+    # print(new_hidden_states.size())
+
+    # def ft_remove_padding(src, tmp_mask_offset, mask_offset, valid_word_num, hidden_dim):
+    # def ft_rebuild_padding(src, mask_offset, valid_word_num, hidden_dim):
+    # def ft_transpose_rebuild_padding(Q, K, V, q_buf, k_buf, v_buf, batch_size, seq_len, head_num, size_per_head, valid_word_num, mask_offset):
+    # def ft_transpose_remove_padding(src, valid_word_num, batch_size, seq_len, head_num, size_per_head, mask_offset):
+
+
+if __name__ == '__main__':
+    test_kernel()
\ No newline at end of file
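The new test only checks that the kernels run and prints shapes. A hedged extra check that could be appended inside test_kernel(), an assumption rather than part of this patch: removing and then rebuilding padding should reproduce the input on valid positions and leave the padded tail zeroed.

    # hypothetical addition at the end of test_kernel():
    n = valid_word_num[0].item()
    compact = ft_remove_padding(hidden_states_q, tmp_mask_offset, mask_offset, n, hidden_size)
    rebuilt = ft_rebuild_padding(compact, mask_offset, n, hidden_size, batch_size, max_padding_size)
    for b, length in enumerate(seq_lens.tolist()):
        # valid tokens survive the round trip unchanged
        assert torch.allclose(rebuilt[b, :length], hidden_states_q[b, :length])
        # padded positions come back as zeros from the zero-initialised target
        assert torch.count_nonzero(rebuilt[b, length:]) == 0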