diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu
index 15e9048cd..e96927c6a 100644
--- a/dev/cuda/adamw.cu
+++ b/dev/cuda/adamw.cu
@@ -9,7 +9,7 @@ Compile example:
 nvcc adamw.cu -o adamw
 nvcc -O3 --use_fast_math adamw.cu -o adamw
 
-./adamw
+./adamw 1
 
 TODO(general):
 amsgrad=True
@@ -24,6 +24,7 @@ thread coarsening/ILP
 #include <stdio.h>
 #include <math.h>
 #include "common.h"
+#include <assert.h>
 
 // ----------------------------------------------------------------------------
 
@@ -98,6 +99,45 @@ __global__ void adamw_kernel2(float* params_memory, const float* grads_memory, f
     params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
 }
 
+// using float4 for memory coalescing based on kernel 2
+__device__ inline float& vec_at(float4& vec, int index) {
+    return reinterpret_cast<float*>(&vec)[index];
+}
+
+__device__ inline float vec_at(const float4& vec, int index) {
+    return reinterpret_cast<const float*>(&vec)[index];
+}
+
+__global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
+                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= num_parameters / 4) return;  // guard
+    float4 grad = grads_memory[i];
+    float4 m = m_memory[i];
+    float4 v = v_memory[i];
+    float4 params = params_memory[i];
+
+    for (int j = 0; j < 4; ++j) {
+        const float& grad_j = vec_at(grad, j);
+        float m_j = vec_at(m, j);
+        float v_j = vec_at(v, j);
+
+        // update the first moment (momentum)
+        m_j = lerp(grad_j, m_j, beta1);
+        vec_at(m, j) = m_j;
+        // update the second moment (RMSprop)
+        v_j = lerp(grad_j * grad_j, v_j, beta2);
+        vec_at(v, j) = v_j;
+        m_j /= beta1_correction;  // m_hat
+        v_j /= beta2_correction;  // v_hat
+        vec_at(params, j) -= learning_rate * (m_j / (sqrtf(v_j) + eps) + weight_decay * vec_at(params, j));
+    }
+    m_memory[i] = m;
+    v_memory[i] = v;
+    params_memory[i] = params;
+}
+
+
 // ----------------------------------------------------------------------------
 // kernel launcher
 
@@ -121,6 +161,17 @@ void adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_m
     cudaCheck(cudaGetLastError());
 }
 
+// version 3: using float4 for memory coalescing
+void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
+                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
+    assert(num_parameters % 4 == 0);
+    unsigned int block_size = 512;
+    unsigned int num_blocks = ceil_div(num_parameters / 4, (long) block_size);
+    adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,
+                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+    cudaCheck(cudaGetLastError());
+}
+
 void adamw(int kernel_num,
            float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
            float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
@@ -136,6 +187,10 @@ void adamw(int kernel_num,
             adamw_dispatch2(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                             learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
             break;
+        case 3:
+            adamw_dispatch3(params_memory, grads_memory, m_memory, v_memory, num_parameters,
+                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
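
Note on the assert(num_parameters % 4 == 0) guard in adamw_dispatch3: the float4 reinterpret casts are safe for the base pointers (cudaMalloc returns memory aligned well beyond the 16 bytes float4 needs), but the vectorized kernel only covers parameter counts that are a multiple of 4. Below is a minimal sketch of one way to lift that restriction, assuming the adamw_kernel2/adamw_kernel3 signatures shown in the diff above; the helper name adamw_dispatch3_with_tail is hypothetical and not part of this diff. It runs the float4 kernel over the largest 4-aligned prefix and falls back to the scalar kernel 2 for the remaining 1-3 elements.

// Hypothetical sketch, not part of this diff: float4 body plus scalar tail.
void adamw_dispatch3_with_tail(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory,
                               long num_parameters,
                               float learning_rate, float beta1, float beta2,
                               float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    unsigned int block_size = 512;
    long vec_parameters = (num_parameters / 4) * 4;  // largest multiple of 4 <= num_parameters
    if (vec_parameters > 0) {
        // vectorized path: same launch as adamw_dispatch3, restricted to the aligned prefix
        unsigned int num_blocks = ceil_div(vec_parameters / 4, (long) block_size);
        adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory,
                                                  (float4*) m_memory, (float4*) v_memory, vec_parameters,
                                                  learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
        cudaCheck(cudaGetLastError());
    }
    long tail = num_parameters - vec_parameters;  // 0..3 leftover elements
    if (tail > 0) {
        // scalar fallback for the remainder; the offset pointers need no 16-byte alignment
        adamw_kernel2<<<1, (unsigned int) tail>>>(params_memory + vec_parameters, grads_memory + vec_parameters,
                                                  m_memory + vec_parameters, v_memory + vec_parameters, tail,
                                                  learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
        cudaCheck(cudaGetLastError());
    }
}

Either way, following the compile example in the header comment, the new kernel is selected by its number on the command line, e.g. ./adamw 3.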