From 6defe089abc7d7fab935b672d65d9100d39cef3d Mon Sep 17 00:00:00 2001
From: lancer
Date: Fri, 26 Apr 2024 22:15:33 -0700
Subject: [PATCH 1/5] v1 with overload

---
 dev/cuda/adamw.cu | 133 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 132 insertions(+), 1 deletion(-)

diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu
index 15e9048cd..cc0df813d 100644
--- a/dev/cuda/adamw.cu
+++ b/dev/cuda/adamw.cu
@@ -9,7 +9,7 @@ Compile example:
 nvcc adamw.cu -o adamw
 nvcc -O3 --use_fast_math adamw.cu -o adamw
 
-./adamw
+./adamw 1
 
 TODO(general):
 amsgrad=True
@@ -24,6 +24,7 @@ thread coarsening/ILP
 #include <math.h>
 #include <cuda_runtime.h>
 #include "common.h"
+#include <assert.h>
 
 // ----------------------------------------------------------------------------
@@ -98,6 +99,122 @@ __global__ void adamw_kernel2(float* params_memory, const float* grads_memory, f
 }
 
+// using float4 for memory coalescing based on kernel 2
+__device__ float& vec_at(float4& vec, int index) {
+    return reinterpret_cast<float*>(&vec)[index];
+}
+
+__device__ float vec_at(const float4& vec, int index) {
+    return reinterpret_cast<const float*>(&vec)[index];
+}
+
+__device__ inline float4 lerp(float4 start, float4 end, float weight) {
+    float4 result;
+    result.x = fma(weight, end.x, fma(-weight, start.x, start.x));
+    result.y = fma(weight, end.y, fma(-weight, start.y, start.y));
+    result.z = fma(weight, end.z, fma(-weight, start.z, start.z));
+    result.w = fma(weight, end.w, fma(-weight, start.w, start.w));
+    return result;
+}
+
+__device__ inline float4 operator*(const float4& a, const float4& b) {
+    float4 result;
+    result.x = a.x * b.x;
+    result.y = a.y * b.y;
+    result.z = a.z * b.z;
+    result.w = a.w * b.w;
+    return result;
+}
+
+__device__ inline float4 operator/(const float4& a, const float& b) {
+    float4 result;
+    result.x = a.x / b;
+    result.y = a.y / b;
+    result.z = a.z / b;
+    result.w = a.w / b;
+    return result;
+}
+
+__device__ inline float4 operator/(const float4& a, const float4& b) {
+    float4 result;
+    result.x = a.x / b.x;
+    result.y = a.y / b.x;
+    result.z = a.z / b.x;
+    result.w = a.w / b.x;
+    return result;
+}
+
+__device__ inline float4 operator*(const float4& a, const float& b) {
+    float4 result;
+    result.x = a.x * b;
+    result.y = a.y * b;
+    result.z = a.z * b;
+    result.w = a.w * b;
+    return result;
+}
+
+__device__ inline float4 operator*(const float& a, const float4& b) {
+    return b * a;
+}
+
+__device__ inline float4 operator+(const float4& a, const float4& b) {
+    float4 result;
+    result.x = a.x + b.x;
+    result.y = a.y + b.y;
+    result.z = a.z + b.z;
+    result.w = a.w + b.w;
+    return result;
+}
+
+__device__ inline float4 operator+(const float4& a, const float& b) {
+    float4 result;
+    result.x = a.x + b;
+    result.y = a.y + b;
+    result.z = a.z + b;
+    result.w = a.w + b;
+    return result;
+}
+
+__device__ inline float4 operator-(const float4& a, const float4& b) {
+    float4 result;
+    result.x = a.x - b.x;
+    result.y = a.y - b.y;
+    result.z = a.z - b.z;
+    result.w = a.w - b.w;
+    return result;
+}
+
+__device__ inline float4 sqrtf(const float4& a) {
+    float4 result;
+    result.x = sqrtf(a.x);
+    result.y = sqrtf(a.y);
+    result.z = sqrtf(a.z);
+    result.w = sqrtf(a.w);
+    return result;
+}
+
+__global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
+                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    assert(num_parameters % 4 == 0);
+    if (i >= num_parameters / 4) return;  // guard
+    float4 grad = grads_memory[i];
+    float4 m = m_memory[i];
+    float4 v = v_memory[i];
+
+    // update the first moment (momentum)
+    m = lerp(grad, m, beta1);
+    m_memory[i] = m;
+    // update the second moment (RMSprop)
+    v = lerp(grad * grad, v, beta2);
+    v_memory[i] = v;
+    m = m / beta1_correction;  // m_hat
+    v = v / beta2_correction;  // v_hat
+
+    params_memory[i] = params_memory[i] - learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
+
+}
+
 // ----------------------------------------------------------------------------
 // kernel launcher
@@ -121,6 +238,16 @@ void adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_m
     cudaCheck(cudaGetLastError());
 }
 
+// version 3: using float4 for memory coalescing
+void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
+                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    unsigned int block_size = 512;
+    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
+    adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,
+                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+    cudaCheck(cudaGetLastError());
+}
+
 void adamw(int kernel_num,
            float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
            float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
@@ -136,6 +263,10 @@ void adamw(int kernel_num,
            adamw_dispatch2(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
+        case 3:
+            adamw_dispatch3(params_memory, grads_memory, m_memory, v_memory, num_parameters,
+                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+            break;
         default:
            printf("Invalid kernel number\n");
            exit(1);

From c16a6a12d72f7aadbf84b7aad5726d21c253721d Mon Sep 17 00:00:00 2001
From: lancer
Date: Fri, 26 Apr 2024 23:00:31 -0700
Subject: [PATCH 2/5] Include the adamw with float4

---
 dev/cuda/adamw.cu | 164 +++++++++++++++++++++++-----------------------
 1 file changed, 81 insertions(+), 83 deletions(-)

diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu
index cc0df813d..ec6597b07 100644
--- a/dev/cuda/adamw.cu
+++ b/dev/cuda/adamw.cu
@@ -108,111 +108,94 @@ __device__ float vec_at(const float4& vec, int index) {
     return reinterpret_cast<const float*>(&vec)[index];
 }
 
-__device__ inline float4 lerp(float4 start, float4 end, float weight) {
-    float4 result;
-    result.x = fma(weight, end.x, fma(-weight, start.x, start.x));
-    result.y = fma(weight, end.y, fma(-weight, start.y, start.y));
-    result.z = fma(weight, end.z, fma(-weight, start.z, start.z));
-    result.w = fma(weight, end.w, fma(-weight, start.w, start.w));
-    return result;
+__global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
+                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    assert(num_parameters % 4 == 0);
+    if (i >= num_parameters / 4) return;  // guard
+    float4 grad = grads_memory[i];
+    float4 m = m_memory[i];
+    float4 v = v_memory[i];
+    float4 params = params_memory[i];
+
+    for (int j = 0; j < 4; ++j) {
+        const float& grad_j = vec_at(grad, j);
+        float m_j = vec_at(m, j);
+        float v_j = vec_at(v, j);
+
+        // update the first moment (momentum)
+        m_j = lerp(grad_j, m_j, beta1);
+        vec_at(m, j) = m_j;
+        // update the second moment (RMSprop)
+        v_j = lerp(grad_j * grad_j, v_j, beta2);
+        vec_at(v, j) = v_j;
+        m_j /= beta1_correction;  // m_hat
+        v_j /= beta2_correction;  // v_hat
+        vec_at(params, j) -= learning_rate * (m_j / (sqrtf(v_j) + eps) + weight_decay * vec_at(params, j));
+    }
+    m_memory[i] = m;
+    v_memory[i] = v;
+    params_memory[i] = params;
 }
 
-__device__ inline float4 operator*(const float4& a, const float4& b) {
-    float4 result;
-    result.x = a.x * b.x;
-    result.y = a.y * b.y;
-    result.z = a.z * b.z;
-    result.w = a.w * b.w;
-    return result;
+// kernel with overloaded operators for float4
+__device__ inline float4 operator+(const float4& a, const float4& b) {
+    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
 }
 
-__device__ inline float4 operator/(const float4& a, const float& b) {
-    float4 result;
-    result.x = a.x / b;
-    result.y = a.y / b;
-    result.z = a.z / b;
-    result.w = a.w / b;
-    return result;
+__device__ inline float4 operator-(const float4& a, const float4& b) {
+    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
 }
 
-__device__ inline float4 operator/(const float4& a, const float4& b) {
-    float4 result;
-    result.x = a.x / b.x;
-    result.y = a.y / b.x;
-    result.z = a.z / b.x;
-    result.w = a.w / b.x;
-    return result;
+__device__ inline float4 operator*(const float4& a, const float4& b) {
+    return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
 }
 
-__device__ inline float4 operator*(const float4& a, const float& b) {
-    float4 result;
-    result.x = a.x * b;
-    result.y = a.y * b;
-    result.z = a.z * b;
-    result.w = a.w * b;
-    return result;
+__device__ inline float4 operator/(const float4& a, const float4& b) {
+    return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
 }
 
-__device__ inline float4 operator*(const float& a, const float4& b) {
-    return b * a;
+__device__ inline float4 operator/(const float4& a, float b) {
+    return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
 }
 
-__device__ inline float4 operator+(const float4& a, const float4& b) {
-    float4 result;
-    result.x = a.x + b.x;
-    result.y = a.y + b.y;
-    result.z = a.z + b.z;
-    result.w = a.w + b.w;
-    return result;
+__device__ inline float4 operator*(const float4& a, float b) {
+    return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
 }
 
-__device__ inline float4 operator+(const float4& a, const float& b) {
-    float4 result;
-    result.x = a.x + b;
-    result.y = a.y + b;
-    result.z = a.z + b;
-    result.w = a.w + b;
-    return result;
+__device__ inline float4 operator*(float a, const float4& b) {
+    return make_float4(a * b.x, a * b.y, a * b.z, a * b.w);
 }
 
-__device__ inline float4 operator-(const float4& a, const float4& b) {
-    float4 result;
-    result.x = a.x - b.x;
-    result.y = a.y - b.y;
-    result.z = a.z - b.z;
-    result.w = a.w - b.w;
-    return result;
+__device__ inline float4 operator+(const float4& a, float b) {
+    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
 }
 
 __device__ inline float4 sqrtf(const float4& a) {
-    float4 result;
-    result.x = sqrtf(a.x);
-    result.y = sqrtf(a.y);
-    result.z = sqrtf(a.z);
-    result.w = sqrtf(a.w);
-    return result;
+    return make_float4(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z), sqrtf(a.w));
 }
 
-__global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
-                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
-    int i = blockIdx.x * blockDim.x + threadIdx.x;
-    assert(num_parameters % 4 == 0);
-    if (i >= num_parameters / 4) return;  // guard
-    float4 grad = grads_memory[i];
-    float4 m = m_memory[i];
-    float4 v = v_memory[i];
-
-    // update the first moment (momentum)
-    m = lerp(grad, m, beta1);
-    m_memory[i] = m;
-    // update the second moment (RMSprop)
-    v = lerp(grad * grad, v, beta2);
-    v_memory[i] = v;
-    m = m / beta1_correction;  // m_hat
-    v = v / beta2_correction;  // v_hat
-
-    params_memory[i] = params_memory[i] - learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
+__device__ inline float4 lerp(const float4& start, const float4& end, float weight) {
+    return make_float4(lerp(start.x, end.x, weight), lerp(start.y, end.y, weight), lerp(start.z, end.z, weight), lerp(start.w, end.w, weight));
+}
+__global__ void adamw_kernel4(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
+                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    assert(num_parameters % 4 == 0);
+    if (i >= num_parameters / 4) return;  // guard
+    float4 grad = grads_memory[i];
+    float4 m = m_memory[i];
+    float4 v = v_memory[i];
+    // update the first moment (momentum)
+    m = lerp(grad, m, beta1);
+    m_memory[i] = m;
+    // update the second moment (RMSprop)
+    v = lerp(grad * grad, v, beta2);
+    v_memory[i] = v;
+    m = m / beta1_correction;  // m_hat
+    v = v / beta2_correction;  // v_hat
+    params_memory[i] = params_memory[i] - learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
 }
 
 // ----------------------------------------------------------------------------
 // kernel launcher
@@ -248,6 +231,17 @@ void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_m
     cudaCheck(cudaGetLastError());
 }
 
+
+// version 4: using overloaded operators for float4
+void adamw_dispatch4(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
+                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    unsigned int block_size = 512;
+    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
+    adamw_kernel4<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,
+                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+    cudaCheck(cudaGetLastError());
+}
+
 void adamw(int kernel_num,
            float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
            float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
@@ -267,6 +261,10 @@ void adamw(int kernel_num,
            adamw_dispatch3(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
+        case 4:
+            adamw_dispatch4(params_memory, grads_memory, m_memory, v_memory, num_parameters,
+                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+            break;
         default:
            printf("Invalid kernel number\n");
            exit(1);

From 17638a45a41f8f5a71c40d81532987341fa34d57 Mon Sep 17 00:00:00 2001
From: lancer
Date: Fri, 26 Apr 2024 23:03:06 -0700
Subject: [PATCH 3/5] Update the kernel 3

---
 dev/cuda/adamw.cu | 74 -----------------------------------------------
 1 file changed, 74 deletions(-)

diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu
index ec6597b07..f2a6e61a3 100644
--- a/dev/cuda/adamw.cu
+++ b/dev/cuda/adamw.cu
@@ -138,65 +138,6 @@ __global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory,
     params_memory[i] = params;
 }
 
-// kernel with overloaded operators for float4
-__device__ inline float4 operator+(const float4& a, const float4& b) {
-    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
-}
-
-__device__ inline float4 operator-(const float4& a, const float4& b) {
-    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
-}
-
-__device__ inline float4 operator*(const float4& a, const float4& b) {
-    return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
-}
-
-__device__ inline float4 operator/(const float4& a, const float4& b) {
-    return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
-}
-
-__device__ inline float4 operator/(const float4& a, float b) {
-    return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
-}
-
-__device__ inline float4 operator*(const float4& a, float b) {
-    return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
-}
-
-__device__ inline float4 operator*(float a, const float4& b) {
-    return make_float4(a * b.x, a * b.y, a * b.z, a * b.w);
-}
-
-__device__ inline float4 operator+(const float4& a, float b) {
-    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
-}
-
-__device__ inline float4 sqrtf(const float4& a) {
-    return make_float4(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z), sqrtf(a.w));
-}
-
-__device__ inline float4 lerp(const float4& start, const float4& end, float weight) {
-    return make_float4(lerp(start.x, end.x, weight), lerp(start.y, end.y, weight), lerp(start.z, end.z, weight), lerp(start.w, end.w, weight));
-}
-
-__global__ void adamw_kernel4(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
-                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
-    int i = blockIdx.x * blockDim.x + threadIdx.x;
-    assert(num_parameters % 4 == 0);
-    if (i >= num_parameters / 4) return;  // guard
-    float4 grad = grads_memory[i];
-    float4 m = m_memory[i];
-    float4 v = v_memory[i];
-    // update the first moment (momentum)
-    m = lerp(grad, m, beta1);
-    m_memory[i] = m;
-    // update the second moment (RMSprop)
-    v = lerp(grad * grad, v, beta2);
-    v_memory[i] = v;
-    m = m / beta1_correction;  // m_hat
-    v = v / beta2_correction;  // v_hat
-    params_memory[i] = params_memory[i] - learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
-}
 
 // ----------------------------------------------------------------------------
 // kernel launcher
@@ -231,17 +172,6 @@ void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_m
     cudaCheck(cudaGetLastError());
 }
 
-
-// version 4: using overloaded operators for float4
-void adamw_dispatch4(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
-                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
-    unsigned int block_size = 512;
-    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
-    adamw_kernel4<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,
-                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
-    cudaCheck(cudaGetLastError());
-}
-
 void adamw(int kernel_num,
            float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
            float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
@@ -261,10 +191,6 @@ void adamw(int kernel_num,
            adamw_dispatch3(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
-        case 4:
-            adamw_dispatch4(params_memory, grads_memory, m_memory, v_memory, num_parameters,
-                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
-            break;
         default:
            printf("Invalid kernel number\n");
            exit(1);

From 5ca73164c1acd5a59087cecce0d2555b3cf3e61e Mon Sep 17 00:00:00 2001
From: lancer
Date: Sat, 27 Apr 2024 20:00:36 -0700
Subject: [PATCH 4/5] amend the kernel for more efficiency

---
 dev/cuda/adamw.cu | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu
index f2a6e61a3..d163a22d3 100644
--- a/dev/cuda/adamw.cu
+++ b/dev/cuda/adamw.cu
@@ -100,18 +100,17 @@ __global__ void adamw_kernel2(float* params_memory, const float* grads_memory, f
 
 // using float4 for memory coalescing based on kernel 2
-__device__ float& vec_at(float4& vec, int index) {
+__device__ inline float& vec_at(float4& vec, int index) {
     return reinterpret_cast<float*>(&vec)[index];
 }
 
-__device__ float vec_at(const float4& vec, int index) {
+__device__ inline float vec_at(const float4& vec, int index) {
     return reinterpret_cast<const float*>(&vec)[index];
 }
 
 __global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
                               float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
-    assert(num_parameters % 4 == 0);
     if (i >= num_parameters / 4) return;  // guard
     float4 grad = grads_memory[i];
     float4 m = m_memory[i];
@@ -165,6 +164,7 @@ void adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_m
 // version 3: using float4 for memory coalescing
 void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                      float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    assert(num_parameters % 4 == 0);
     unsigned int block_size = 512;
     unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
     adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,

From 25dc803d9f66306e4efd5f0544cbc2e416f447fd Mon Sep 17 00:00:00 2001
From: lancer
Date: Sat, 27 Apr 2024 20:29:58 -0700
Subject: [PATCH 5/5] amend the grid size

---
 dev/cuda/adamw.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu
index d163a22d3..e96927c6a 100644
--- a/dev/cuda/adamw.cu
+++ b/dev/cuda/adamw.cu
@@ -166,7 +166,7 @@ void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_m
                      float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
     assert(num_parameters % 4 == 0);
     unsigned int block_size = 512;
-    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
+    unsigned int num_blocks = ceil_div(num_parameters / 4, (long) block_size);
     adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,
                                               learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
     cudaCheck(cudaGetLastError());
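
Reviewer note (appended for context, not part of any commit above): lerp(grad, m, beta1) is just the fused-multiply-add form of the usual first-moment update, beta1 * m + (1 - beta1) * grad, and kernel 3's per-lane loop from patch 2 sidesteps the float4 operator overloads introduced in patch 1, whose component-wise operator/ divided every lane by b.x. The following standalone, host-only sketch checks that identity and one full AdamW step against the scalar formulas; the file name, the lerp_host helper, and the numeric values are made up for illustration, and it should build with any C++ compiler or with nvcc.

// adamw_lerp_check.cc -- standalone host-side sketch, not part of the patch series.
// Checks that the lerp-based moment update used by the float4 kernels matches the
// textbook AdamW formulas for a single (hypothetical) parameter at step t = 1.
#include <cstdio>
#include <cmath>

// same fused form as the device-side lerp: start + weight * (end - start)
static float lerp_host(float start, float end, float weight) {
    return fmaf(weight, end, fmaf(-weight, start, start));
}

int main() {
    float param = 0.5f, grad = 0.1f, m = 0.0f, v = 0.0f;  // hypothetical optimizer state
    const float lr = 1e-3f, beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f, weight_decay = 0.0f;
    const int t = 1;
    const float beta1_correction = 1.0f - powf(beta1, (float) t);
    const float beta2_correction = 1.0f - powf(beta2, (float) t);

    // kernel's form: m = lerp(grad, m, beta1), v = lerp(grad*grad, v, beta2)
    float m_lerp = lerp_host(grad, m, beta1);
    float v_lerp = lerp_host(grad * grad, v, beta2);

    // textbook form: m = beta1*m + (1-beta1)*grad, v = beta2*v + (1-beta2)*grad^2
    float m_ref = beta1 * m + (1.0f - beta1) * grad;
    float v_ref = beta2 * v + (1.0f - beta2) * grad * grad;

    // bias correction and the parameter update, mirroring the kernel's last line
    float m_hat = m_lerp / beta1_correction;
    float v_hat = v_lerp / beta2_correction;
    float updated = param - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);

    printf("m: lerp=%.8f ref=%.8f\n", m_lerp, m_ref);
    printf("v: lerp=%.8f ref=%.8f\n", v_lerp, v_ref);
    printf("param after one AdamW step: %.8f\n", updated);
    return 0;
}

With the series applied, the float4 path would presumably be exercised the same way as the usage line updated in patch 1, e.g. nvcc -O3 --use_fast_math adamw.cu -o adamw followed by ./adamw 3.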