Merge pull request #41 from hpcaitech/feature/variable_len
block without timeout
MaruyamaAya authored May 4, 2022
2 parents 4c05b72 + 0ebe9b5 commit ac8245b
Showing 9 changed files with 484 additions and 15 deletions.
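The headline change ("block without timeout") replaces self.lock.acquire(timeout=3) with a plain blocking self.lock.acquire() in the pipeline wrappers below. A minimal standard-library sketch of the behavioural difference (not the engine code itself):

import threading

lock = threading.Lock()

# Old behaviour: acquire(timeout=3) returns False if the lock cannot be taken
# within 3 seconds; a caller that ignores the result may then run its critical
# section without actually holding the lock.
acquired = lock.acquire(timeout=3)
if acquired:
    lock.release()

# New behaviour: a bare acquire() blocks until the lock is available and only
# returns once the lock is held.
lock.acquire()
try:
    pass  # critical section
finally:
    lock.release()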
4 changes: 2 additions & 2 deletions energon/engine/bert_pipeline_wrapper.py
@@ -90,7 +90,7 @@ def run(self, key, inputs):
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

# different threads ask for a single lock
self.lock.acquire(timeout=3)
self.lock.acquire()

sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
self.key.addOne()
@@ -128,4 +128,4 @@ def run(self, key, inputs):
self.lock.release()
return None



4 changes: 2 additions & 2 deletions energon/engine/gpt_pipeline_wrapper.py
@@ -101,7 +101,7 @@ def run(self, key, inputs):
self.fill_meta_tensor(inputs, pipe_meta)
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

self.lock.acquire(timeout=3)
self.lock.acquire()
sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
self.key.addOne()

@@ -136,4 +136,4 @@ def run(self, key, inputs):
self.lock.release()
return None



4 changes: 3 additions & 1 deletion energon/kernel/__init__.py
@@ -1,5 +1,7 @@
from .cuda_native import transpose_pad, transpose_depad, depad, scale_mask_softmax
from .cuda_native import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding

__all__ = [
"transpose_pad", "transpose_depad", "depad", "scale_mask_softmax"
"transpose_pad", "transpose_depad", "depad", "scale_mask_softmax",
"ft_build_padding_offsets", "ft_remove_padding", "ft_rebuild_padding", "ft_transpose_remove_padding", "ft_transpose_rebuild_padding"
]
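After this change the FasterTransformer-style helpers are importable from the package root, e.g. (assuming an installed energon build with the CUDA extension compiled):

from energon.kernel import (ft_build_padding_offsets, ft_remove_padding,
                            ft_rebuild_padding, ft_transpose_remove_padding,
                            ft_transpose_rebuild_padding)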
1 change: 1 addition & 0 deletions energon/kernel/cuda_native/__init__.py
@@ -1,3 +1,4 @@
from .transpose_pad import transpose_pad, transpose_depad, depad
from .transpose_pad import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding
from .scale_mask_softmax import scale_mask_softmax
from .layer_norm import MixedFusedLayerNorm as LayerNorm
228 changes: 227 additions & 1 deletion energon/kernel/cuda_native/csrc/transpose_pad_fusion_kernel.cu
@@ -2,7 +2,9 @@
// transpose and padding/depadding fusion to reduce memory movement.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

// from TurboTransformers
__global__ void transpose_depad_kernel(const float* src, const int batch_size,
const int seq_len,
const int64_t* seq_lens,
@@ -85,4 +87,228 @@ void transpose_pad(const float* src,
dim3 dimBlock(size_per_head);

transpose_pad_kernel<<<dimGrid, dimBlock>>>(src, batch_size, seq_len, seq_lens, head_num, size_per_head, dst);
}
}


// from FasterTransformer

/* create offsets */

__global__ void build_sequence_length_padding_offset(const int* sequence_length,
const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
{
// do cumulated sum
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
for(int i = 0; i < batch_size; i++)
{
const int seq_len = sequence_length[i];
for(int j = 0; j < seq_len; j++)
{
tmp_mask_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
valid_word_num[0] = total_seq_len;
}

void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length,
const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
{
build_sequence_length_padding_offset<<<1, 1>>>(sequence_length,
batch_size, max_seq_len, valid_word_num, tmp_mask_offset);
}
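
// For reference, a pure-Python sketch (not part of this commit) of what
// build_sequence_length_padding_offset computes; build_padding_offsets is an
// illustrative name, and sequence_length is assumed to hold the valid token
// count of each sequence in the batch:
//
// def build_padding_offsets(sequence_length, max_seq_len):
//     # one entry per valid (non-padding) token: the number of padding slots
//     # that precede it in the flattened [batch, max_seq_len] layout
//     tmp_mask_offset = []
//     cum_offset = 0
//     for seq_len in sequence_length:
//         tmp_mask_offset.extend([cum_offset] * seq_len)
//         cum_offset += max_seq_len - seq_len
//     valid_word_num = len(tmp_mask_offset)
//     return valid_word_num, tmp_mask_offset
//
// # e.g. two sequences of lengths 2 and 3 padded to 4:
// # build_padding_offsets([2, 3], 4) -> (5, [0, 0, 2, 2, 2])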


/* remove padding from embedding layer to transformer blocks */
template<typename T>
__global__ void remove_sequence_length_padding(const T* src, T* tgt,
const int* tmp_mask_offset,
int* mask_offset,
const int n)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
mask_offset[bid] = tmp_mask_offset[bid];
const int src_seq_id = bid + mask_offset[bid];
const int tgt_seq_id = bid;

for(int i = tid; i < n; i += blockDim.x)
{
tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
}
}

template<typename T>
void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
const int* tmp_mask_offset,
int* mask_offset,
const int m, const int n)
{
// src: [batch_size*max_seq_len, hidden_dim]
// tgt: [valid_word_num, hidden_dim]
remove_sequence_length_padding<<<m, 256>>>(src, tgt, tmp_mask_offset, mask_offset, n);
}

template void remove_sequence_length_padding_kernelLauncher(const float* src, float* tgt,
const int* tmp_mask_offset,
int* mask_offset, const int m,
const int n);

template void remove_sequence_length_padding_kernelLauncher(const half* src, half* tgt,
const int* tmp_mask_offset,
int* mask_offset, const int m,
const int n);
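
// A NumPy sketch of the gather remove_sequence_length_padding performs
// (remove_padding is an illustrative name; layouts follow the launcher
// comments above): compact row bid comes from padded row bid + mask_offset[bid].
//
// import numpy as np
//
// def remove_padding(src, mask_offset):
//     # src: [batch_size * max_seq_len, hidden_dim]; mask_offset: [valid_word_num]
//     rows = np.arange(len(mask_offset)) + np.asarray(mask_offset)
//     return src[rows]  # tgt: [valid_word_num, hidden_dim]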



/* add padding from transformer blocks to final output*/

template<typename T>
__global__ void rebuild_sequence_length_padding(const T* src, T* tgt,
const int* mask_offset,
const int n)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int tgt_seq_id = bid + mask_offset[bid];
const int src_seq_id = bid;

for(int i = tid; i < n; i += blockDim.x)
{
tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
}
}


template<typename T>
void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
const int* mask_offset, const int m,
const int n)
{
// src: [valid_word_num, hidden_dim]
// tgt: [batch_size*max_seq_len, hidden_dim]
rebuild_sequence_length_padding<<<m, 256>>>(src, tgt, mask_offset, n);
}


template void rebuild_sequence_length_padding_kernelLauncher(const float* src, float* tgt,
const int* mask_offset, const int m,
const int n);


template void rebuild_sequence_length_padding_kernelLauncher(const half* src, half* tgt,
const int* mask_offset, const int m,
const int n);
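
// And the inverse scatter for rebuild_sequence_length_padding, again as an
// illustrative NumPy sketch (the CUDA kernel writes into a caller-provided
// buffer; zero-initialisation here is only for clarity):
//
// import numpy as np
//
// def rebuild_padding(src, mask_offset, padded_rows):
//     # src: [valid_word_num, hidden_dim] -> tgt: [padded_rows, hidden_dim]
//     tgt = np.zeros((padded_rows, src.shape[1]), dtype=src.dtype)
//     rows = np.arange(len(mask_offset)) + np.asarray(mask_offset)
//     tgt[rows] = src
//     return tgt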


/* FT transpose and rebuild padding */

template<typename T>
__global__
void transpose_rebuild_padding(T* Q, T* K, T* V, T* q_buf_, T* k_buf_, T* v_buf_,
const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int* mask_offset)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int bdim = blockDim.x;

const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int tgt_head_id = tid / size_per_head;
const int tgt_hidden_id = tid % size_per_head;

const int src_id = bid * bdim + tid;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + \
tgt_head_id * seq_len * size_per_head + \
tgt_seq_id * size_per_head + \
tgt_hidden_id;

q_buf_[tgt_id] = Q[src_id]; // + bias_Q[tid];
k_buf_[tgt_id] = K[src_id]; // + bias_K[tid];
v_buf_[tgt_id] = V[src_id]; // + bias_V[tid];
}

template<typename T>
void transpose_rebuild_padding_kernelLauncher(T* Q, T* K, T* V, T* q_buf, T* k_buf, T* v_buf,
const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num,
const int* mask_offset)
{
const int k = head_num*size_per_head;

if(std::is_same<T, float>::value)
{
transpose_rebuild_padding<<<valid_word_num, k>>>(Q, K, V, q_buf, k_buf, v_buf,
batch_size, seq_len, head_num, size_per_head, mask_offset);
}
else
{
transpose_rebuild_padding<<<valid_word_num, k / 2>>>((half2*)Q,
(half2*)K, (half2*)V, (half2*)q_buf, (half2*)k_buf, (half2*)v_buf,
batch_size, seq_len, head_num, size_per_head / 2, mask_offset);
}
}

template
void transpose_rebuild_padding_kernelLauncher(float* Q, float* K, float* V, float* q_buf, float* k_buf, float* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset);

template
void transpose_rebuild_padding_kernelLauncher(half* Q, half* K, half* V, half* q_buf, half* k_buf, half* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset);
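
// A NumPy sketch of transpose_rebuild_padding for a single tensor
// (illustrative only; the kernel handles Q, K and V at once and writes into
// pre-allocated buffers): each compact row is reshaped to
// [head_num, size_per_head] and scattered to its padded position in a
// head-major [batch_size, head_num, seq_len, size_per_head] buffer.
//
// import numpy as np
//
// def transpose_rebuild_padding(Q, mask_offset, batch_size, seq_len,
//                               head_num, size_per_head):
//     # Q: [valid_word_num, head_num * size_per_head]
//     q_buf = np.zeros((batch_size, head_num, seq_len, size_per_head), dtype=Q.dtype)
//     for bid in range(Q.shape[0]):              # one CUDA block per valid token
//         padded_pos = bid + mask_offset[bid]
//         b, s = divmod(padded_pos, seq_len)     # target batch / sequence position
//         q_buf[b, :, s, :] = Q[bid].reshape(head_num, size_per_head)
//     return q_buf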


/* FT transpose and remove padding */
template<typename T>
__global__
void transpose_remove_padding(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head,
const int* mask_offset)
{
// TODO: optimize this kernel?
// do remove_sequence_length_padding
const int tid = threadIdx.x; // head_num * size_per_head
const int bid = blockIdx.x; // batch * seq_len or valid_word_num

const int src_batch_id = (bid + mask_offset[bid]) / seq_len;
const int src_seq_id = (bid + mask_offset[bid]) % seq_len;

const int dst_seq_id = bid;

const int head_id = tid / size_per_head;
const int hidden_id = tid % size_per_head;
dst[dst_seq_id * head_num * size_per_head + tid] = src[ src_batch_id * head_num * seq_len * size_per_head +
head_id * seq_len * size_per_head + src_seq_id * size_per_head + hidden_id];
}

template<typename T>
void transpose_remove_padding_kernelLauncher(T* src, T* dst, const int valid_word_num,
const int batch_size, const int seq_len,
const int head_num, const int size_per_head,
const int* mask_offset)
{
int k = head_num * size_per_head;
if (std::is_same<T, float>::value)
{
transpose_remove_padding<<<valid_word_num, k>>>(src, dst,
batch_size, seq_len, head_num, size_per_head, mask_offset);
}
else
{
transpose_remove_padding<half2><<<valid_word_num, k / 2>>>(
(half2*)src, (half2*)dst,
batch_size, seq_len, head_num, size_per_head / 2, mask_offset);
}
}

template
void transpose_remove_padding_kernelLauncher(float* src, float* dst, const int valid_word_num,
const int batch_size, const int seq_len,
const int head_num, const int size_per_head,
const int* mask_offset);

template
void transpose_remove_padding_kernelLauncher(half* src, half* dst, const int valid_word_num,
const int batch_size, const int seq_len,
const int head_num, const int size_per_head,
const int* mask_offset);
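
// Finally, a NumPy sketch of transpose_remove_padding (illustrative only):
// the head-major padded tensor is flattened back to a compact
// [valid_word_num, head_num * size_per_head] matrix, undoing the scatter above.
//
// import numpy as np
//
// def transpose_remove_padding(src, mask_offset, seq_len):
//     # src: [batch_size, head_num, seq_len, size_per_head]
//     batch_size, head_num, _, size_per_head = src.shape
//     valid_word_num = len(mask_offset)
//     dst = np.empty((valid_word_num, head_num * size_per_head), dtype=src.dtype)
//     for bid in range(valid_word_num):          # one CUDA block per valid token
//         b, s = divmod(bid + mask_offset[bid], seq_len)
//         dst[bid] = src[b, :, s, :].reshape(-1)
//     return dst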