float4 with better vectorization for adamw.cu #268

Open · wants to merge 5 commits into base: master
57 changes: 56 additions & 1 deletion dev/cuda/adamw.cu
@@ -9,7 +9,7 @@ Compile example:
nvcc adamw.cu -o adamw
nvcc -O3 --use_fast_math adamw.cu -o adamw

-./adamw
+./adamw 1

TODO(general):
amsgrad=True
@@ -24,6 +24,7 @@ thread coarsening/ILP
#include <time.h>
#include <cuda_runtime.h>
#include "common.h"
#include <cassert>


// ----------------------------------------------------------------------------
@@ -98,6 +99,45 @@ __global__ void adamw_kernel2(float* params_memory, const float* grads_memory, f
}


// based on kernel 2, but using float4 so each thread moves its four parameters
// with single 128-bit loads/stores instead of four separate 32-bit accesses
__device__ inline float& vec_at(float4& vec, int index) {
    // view the float4 as an array of 4 floats so the update loop can index lanes
    return reinterpret_cast<float*>(&vec)[index];
}

__device__ inline float vec_at(const float4& vec, int index) {
    return reinterpret_cast<const float*>(&vec)[index];
}

__global__ void adamw_kernel3(float4* params_memory, const float4* grads_memory, float4* m_memory, float4* v_memory, long num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_parameters / 4) return; // guard: each thread owns one float4, i.e. 4 parameters
    float4 grad = grads_memory[i];
    float4 m = m_memory[i];
    float4 v = v_memory[i];
    float4 params = params_memory[i];

    for (int j = 0; j < 4; ++j) {
        const float& grad_j = vec_at(grad, j);
        float m_j = vec_at(m, j);
        float v_j = vec_at(v, j);

        // update the first moment (momentum)
        m_j = lerp(grad_j, m_j, beta1);
        vec_at(m, j) = m_j;
        // update the second moment (RMSprop)
        v_j = lerp(grad_j * grad_j, v_j, beta2);
        vec_at(v, j) = v_j;
        m_j /= beta1_correction; // m_hat
        v_j /= beta2_correction; // v_hat
        vec_at(params, j) -= learning_rate * (m_j / (sqrtf(v_j) + eps) + weight_decay * vec_at(params, j));
    }
    m_memory[i] = m;
    v_memory[i] = v;
    params_memory[i] = params;
}


// ----------------------------------------------------------------------------
// kernel launcher

@@ -121,6 +161,17 @@ void adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_m
    cudaCheck(cudaGetLastError());
}

// version 3: using float4 for vectorized (128-bit) loads and stores
void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    assert(num_parameters % 4 == 0); // each thread processes exactly one float4, so the count must be a multiple of 4
    unsigned int block_size = 512;
    unsigned int num_blocks = ceil_div(num_parameters / 4, (long) block_size);
    adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, num_parameters,
                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    cudaCheck(cudaGetLastError());
}

void adamw(int kernel_num,
           float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
           float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
@@ -136,6 +187,10 @@ void adamw(int kernel_num,
            adamw_dispatch2(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
        case 3:
            adamw_dispatch3(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
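A note on the (float4*) casts in adamw_dispatch3: dereferencing a float4* is only well-defined when the pointer is 16-byte aligned. cudaMalloc returns allocations aligned to at least 256 bytes, so the base pointers in this benchmark are safe, but any future caller that offsets into a buffer would need to preserve that alignment. A minimal host-side check could make the assumption explicit; the assert_float4_aligned helper below is a sketch, not part of this PR:

#include <cassert>
#include <cstdint>
#include <cuda_runtime.h>

// sketch, not part of this PR: make the 16-byte alignment requirement of the
// (float4*) casts explicit instead of relying on cudaMalloc's guarantee
static inline void assert_float4_aligned(const void* p) {
    assert(reinterpret_cast<std::uintptr_t>(p) % alignof(float4) == 0);
}

For example, calling assert_float4_aligned(params_memory) just before the adamw_kernel3 launch would catch a misaligned slice at dispatch time rather than as a silent misaligned-address error inside the kernel.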
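Separately, the assert(num_parameters % 4 == 0) means kernel 3 only handles parameter counts that are a multiple of 4. If that restriction ever needs to go, one option is to run the float4 kernel over the aligned prefix and mop up the last 0-3 elements with the existing scalar kernel. The sketch below assumes adamw_kernel2 and ceil_div keep the signatures shown in this file; the adamw_dispatch3_with_tail name is hypothetical and not part of this PR:

// sketch, not part of this PR: float4 kernel over the bulk, scalar tail for the rest
void adamw_dispatch3_with_tail(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                               float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    long n4 = (num_parameters / 4) * 4; // largest multiple of 4 <= num_parameters
    unsigned int block_size = 512;
    if (n4 > 0) {
        unsigned int num_blocks = ceil_div(n4 / 4, (long) block_size);
        adamw_kernel3<<<num_blocks, block_size>>>((float4*) params_memory, (float4*) grads_memory, (float4*) m_memory, (float4*) v_memory, n4,
                                                  learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    }
    long tail = num_parameters - n4; // 0..3 leftover parameters
    if (tail > 0) {
        // one tiny block finishes the leftovers with the scalar kernel 2
        adamw_kernel2<<<1, (unsigned int) tail>>>(params_memory + n4, grads_memory + n4, m_memory + n4, v_memory + n4, tail,
                                                  learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    }
    cudaCheck(cudaGetLastError());
}

This keeps the fast path untouched and only pays the extra scalar launch when a tail actually exists.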