This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

block without timeout #41

Merged
merged 2 commits on May 4, 2022
4 changes: 2 additions & 2 deletions energon/engine/bert_pipeline_wrapper.py
@@ -90,7 +90,7 @@ def run(self, key, inputs):
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

# different threads ask for a single lock
self.lock.acquire(timeout=3)
self.lock.acquire()

sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
self.key.addOne()
@@ -128,4 +128,4 @@ def run(self, key, inputs):
self.lock.release()
return None



4 changes: 2 additions & 2 deletions energon/engine/gpt_pipeline_wrapper.py
@@ -101,7 +101,7 @@ def run(self, key, inputs):
self.fill_meta_tensor(inputs, pipe_meta)
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

self.lock.acquire(timeout=3)
self.lock.acquire()
sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
self.key.addOne()

@@ -136,4 +136,4 @@ def run(self, key, inputs):
self.lock.release()
return None



4 changes: 3 additions & 1 deletion energon/kernel/__init__.py
@@ -1,5 +1,7 @@
from .cuda_native import transpose_pad, transpose_depad, depad, scale_mask_softmax
from .cuda_native import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding

__all__ = [
"transpose_pad", "transpose_depad", "depad", "scale_mask_softmax"
"transpose_pad", "transpose_depad", "depad", "scale_mask_softmax",
"ft_build_padding_offsets", "ft_remove_padding", "ft_rebuild_padding", "ft_transpose_remove_padding", "ft_transpose_rebuild_padding"
]
1 change: 1 addition & 0 deletions energon/kernel/cuda_native/__init__.py
@@ -1,3 +1,4 @@
from .transpose_pad import transpose_pad, transpose_depad, depad
from .transpose_pad import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding
from .scale_mask_softmax import scale_mask_softmax
from .layer_norm import MixedFusedLayerNorm as LayerNorm
228 changes: 227 additions & 1 deletion energon/kernel/cuda_native/csrc/transpose_pad_fusion_kernel.cu
@@ -2,7 +2,9 @@
// transpose and padding/depadding fusion to reduce the memory move.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

// adapted from TurboTransformers
__global__ void transpose_depad_kernel(const float* src, const int batch_size,
const int seq_len,
const int64_t* seq_lens,
@@ -85,4 +87,228 @@ void transpose_pad(const float* src,
dim3 dimBlock(size_per_head);

transpose_pad_kernel<<<dimGrid, dimBlock>>>(src, batch_size, seq_len, seq_lens, head_num, size_per_head, dst);
}
}


// adapted from FasterTransformer

/* create offsets */

__global__ void build_sequence_length_padding_offset(const int* sequence_length,
const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
{
// cumulative sum of per-sequence padding offsets
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
for(int i = 0; i < batch_size; i++)
{
const int seq_len = sequence_length[i];
for(int j = 0; j < seq_len; j++)
{
tmp_mask_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
valid_word_num[0] = total_seq_len;
}

void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length,
const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
{
build_sequence_length_padding_offset<<<1, 1>>>(sequence_length,
batch_size, max_seq_len, valid_word_num, tmp_mask_offset);
}


/* remove padding from embedding layer to transformer blocks */
template<typename T>
__global__ void remove_sequence_length_padding(const T* src, T* tgt,
const int* tmp_mask_offset,
int* mask_offset,
const int n)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
mask_offset[bid] = tmp_mask_offset[bid];
const int src_seq_id = bid + mask_offset[bid];
const int tgt_seq_id = bid;

for(int i = tid; i < n; i += blockDim.x)
{
tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
}
}

template<typename T>
void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
const int* tmp_mask_offset,
int* mask_offset,
const int m, const int n)
{
// src: [batch_size*max_seq_len, hidden_dim]
// tgt: [valid_word_num, hidden_dim]
remove_sequence_length_padding<<<m, 256>>>(src, tgt, tmp_mask_offset, mask_offset, n);
}

template void remove_sequence_length_padding_kernelLauncher(const float* src, float* tgt,
const int* tmp_mask_offset,
int* mask_offset, const int m,
const int n);

template void remove_sequence_length_padding_kernelLauncher(const half* src, half* tgt,
const int* tmp_mask_offset,
int* mask_offset, const int m,
const int n);



/* add padding from transformer blocks to final output */

template<typename T>
__global__ void rebuild_sequence_length_padding(const T* src, T* tgt,
const int* mask_offset,
const int n)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int tgt_seq_id = bid + mask_offset[bid];
const int src_seq_id = bid;

for(int i = tid; i < n; i += blockDim.x)
{
tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
}
}


template<typename T>
void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
const int* mask_offset, const int m,
const int n)
{
// src: [valid_word_num, hidden_dim]
// tgt: [batch_size*max_seq_len, hidden_dim]
rebuild_sequence_length_padding<<<m, 256>>>(src, tgt, mask_offset, n);
}


template void rebuild_sequence_length_padding_kernelLauncher(const float* src, float* tgt,
const int* mask_offset, const int m,
const int n);


template void rebuild_sequence_length_padding_kernelLauncher(const half* src, half* tgt,
const int* mask_offset, const int m,
const int n);


/* FT transpose and rebuild padding */

template<typename T>
__global__
void transpose_rebuild_padding(T* Q, T* K, T* V, T* q_buf_, T* k_buf_, T* v_buf_,
const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int* mask_offset)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int bdim = blockDim.x;

const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int tgt_head_id = tid / size_per_head;
const int tgt_hidden_id = tid % size_per_head;

const int src_id = bid * bdim + tid;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + \
tgt_head_id * seq_len * size_per_head + \
tgt_seq_id * size_per_head + \
tgt_hidden_id;

q_buf_[tgt_id] = Q[src_id]; // + bias_Q[tid];
k_buf_[tgt_id] = K[src_id]; // + bias_K[tid];
v_buf_[tgt_id] = V[src_id]; // + bias_V[tid];
}

template<typename T>
void transpose_rebuild_padding_kernelLauncher(T* Q, T* K, T* V, T* q_buf, T* k_buf, T* v_buf,
const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num,
const int* mask_offset)
{
const int k = head_num*size_per_head;

if(std::is_same<T, float>::value)
{
transpose_rebuild_padding<<<valid_word_num, k>>>(Q, K, V, q_buf, k_buf, v_buf,
batch_size, seq_len, head_num, size_per_head, mask_offset);
}
else
{
transpose_rebuild_padding<<<valid_word_num, k / 2>>>((half2*)Q,
(half2*)K, (half2*)V, (half2*)q_buf, (half2*)k_buf, (half2*)v_buf,
batch_size, seq_len, head_num, size_per_head / 2, mask_offset);
}
}

template
void transpose_rebuild_padding_kernelLauncher(float* Q, float* K, float* V, float* q_buf, float* k_buf, float* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset);

template
void transpose_rebuild_padding_kernelLauncher(half* Q, half* K, half* V, half* q_buf, half* k_buf, half* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset);


/* FT transpose and remove padding */
template<typename T>
__global__
void transpose_remove_padding(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head,
const int* mask_offset)
{
// TODO: optimize this kernel?
// do remove_sequence_length_padding
const int tid = threadIdx.x; // indexes head_num * size_per_head (position within a token's hidden dim)
const int bid = blockIdx.x; // indexes the valid tokens, i.e. 0 .. valid_word_num - 1

const int src_batch_id = (bid + mask_offset[bid]) / seq_len;
const int src_seq_id = (bid + mask_offset[bid]) % seq_len;

const int dst_seq_id = bid;

const int head_id = tid / size_per_head;
const int hidden_id = tid % size_per_head;
dst[dst_seq_id * head_num * size_per_head + tid] = src[ src_batch_id * head_num * seq_len * size_per_head +
head_id * seq_len * size_per_head + src_seq_id * size_per_head + hidden_id];
}

template<typename T>
void transpose_remove_padding_kernelLauncher(T* src, T* dst, const int valid_word_num,
const int batch_size, const int seq_len,
const int head_num, const int size_per_head,
const int* mask_offset)
{
int k = head_num * size_per_head;
if (std::is_same<T, float>::value)
{
transpose_remove_padding<<<valid_word_num, k>>>(src, dst,
batch_size, seq_len, head_num, size_per_head, mask_offset);
}
else
{
transpose_remove_padding<half2><<<valid_word_num, k / 2>>>(
(half2*)src, (half2*)dst,
batch_size, seq_len, head_num, size_per_head / 2, mask_offset);
}
}

template
void transpose_remove_padding_kernelLauncher(float* src, float* dst, const int valid_word_num,
const int batch_size, const int seq_len,
const int head_num, const int size_per_head,
const int* mask_offset);

template
void transpose_remove_padding_kernelLauncher(half* src, half* dst, const int valid_word_num,
const int batch_size, const int seq_len,
const int head_num, const int size_per_head,
const int* mask_offset);
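
For orientation, here is a minimal host-side sketch (not part of this PR) of how the three FT-style padding launchers defined above are meant to chain together: build the per-token padding offsets, read valid_word_num back to the host, strip the padded rows before the transformer blocks, and re-insert them afterwards. The prototypes are copied from the launchers in this file; the driver function name, the buffer sizes, the cudaMemset of the padded output, and the device-to-host copy of valid_word_num are assumptions inferred from the kernel bodies, not from Energon's Python bindings.

#include <cuda_runtime.h>

// Prototypes copied from the launchers defined in transpose_pad_fusion_kernel.cu.
void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length,
    const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset);
template<typename T>
void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
    const int* tmp_mask_offset, int* mask_offset, const int m, const int n);
template<typename T>
void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt,
    const int* mask_offset, const int m, const int n);

// Hypothetical driver: all pointers are device pointers.
void depad_roundtrip_sketch(const float* d_padded,   // [batch_size * max_seq_len, hidden_dim]
                            float* d_compact,        // [batch_size * max_seq_len, hidden_dim] scratch
                            float* d_repadded,       // [batch_size * max_seq_len, hidden_dim] output
                            const int* d_seq_lens,   // [batch_size] valid length per sequence
                            int batch_size, int max_seq_len, int hidden_dim)
{
    int *d_valid_word_num, *d_tmp_mask_offset, *d_mask_offset;
    cudaMalloc(&d_valid_word_num, sizeof(int));
    cudaMalloc(&d_tmp_mask_offset, batch_size * max_seq_len * sizeof(int));
    cudaMalloc(&d_mask_offset,     batch_size * max_seq_len * sizeof(int));

    // 1. Cumulative padding offset per valid token, plus the total valid token count.
    build_sequence_length_padding_offset_kernelLauncher(
        d_seq_lens, batch_size, max_seq_len, d_valid_word_num, d_tmp_mask_offset);

    // remove/rebuild take m = valid_word_num as their grid size, so read it back.
    int valid_word_num = 0;
    cudaMemcpy(&valid_word_num, d_valid_word_num, sizeof(int), cudaMemcpyDeviceToHost);

    // 2. Drop padded rows: [batch_size * max_seq_len, hidden] -> [valid_word_num, hidden].
    remove_sequence_length_padding_kernelLauncher<float>(
        d_padded, d_compact, d_tmp_mask_offset, d_mask_offset, valid_word_num, hidden_dim);

    // ... run the transformer blocks on the compact layout here ...

    // 3. Re-insert padded rows. The kernel only writes valid rows, so zero the rest first.
    cudaMemset(d_repadded, 0,
               (size_t)batch_size * max_seq_len * hidden_dim * sizeof(float));
    rebuild_sequence_length_padding_kernelLauncher<float>(
        d_compact, d_repadded, d_mask_offset, valid_word_num, hidden_dim);

    cudaFree(d_valid_word_num);
    cudaFree(d_tmp_mask_offset);
    cudaFree(d_mask_offset);
}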