Skip to content

Commit

Permalink
Merge pull request #265 from karpathy/feature/load_bf16
Browse files Browse the repository at this point in the history
load bf16 directly, and some "quality of life" handling of fp32/fp16/bf16 precisions
  • Loading branch information
karpathy authored Apr 28, 2024
2 parents 12da2c1 + 9d70d9a commit d95b8d8
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 170 deletions.
16 changes: 15 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ else
endif
endif

# Precision settings, default to bf16 but ability to override
# Usage: `make PRECISION=FP32 train_gpt2cu` (valid values: FP32, FP16, BF16).
PRECISION ?= BF16
VALID_PRECISIONS := FP32 FP16 BF16
# $(filter ...) keeps PRECISION only if it appears in VALID_PRECISIONS;
# an empty result means the user passed something unrecognized -> hard error.
ifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),)
$(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS))
endif
# Map the chosen precision onto a single -DENABLE_* preprocessor flag,
# consumed by the CUDA sources (see PFLAGS use in the train_gpt2cu rule).
ifeq ($(PRECISION), FP32)
PFLAGS = -DENABLE_FP32
else ifeq ($(PRECISION), FP16)
PFLAGS = -DENABLE_FP16
else
PFLAGS = -DENABLE_BF16
endif

# PHONY means these targets will always be executed
.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu

Expand All @@ -108,7 +122,7 @@ test_gpt2: test_gpt2.c
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@

train_gpt2cu: train_gpt2.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@

train_gpt2fp32cu: train_gpt2_fp32.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
Expand Down
3 changes: 2 additions & 1 deletion profile_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ For example, I have NVIDIA Nsight Compute installed on my Mac, and I rsync
the profile.ncu-rep from a cloud box to local to pretty view.
*/

#define ENABLE_BF16
#define TESTING
#include "train_gpt2.cu"

Expand Down Expand Up @@ -51,7 +52,7 @@ int main() {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");

int B = 4;
int T = 1024;
Expand Down
220 changes: 140 additions & 80 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
#define ENABLE_BF16
#define TESTING
#include "train_gpt2.cu"

// poor man's tensor checker
int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) {
int print_upto = 5;
// a is the calculated tensor, b is the reference tensor
int print_upto = 10;
int ok = 1;
float max_diff = 0.0f;
float max_rel_error = 0.0f;
float max_a = 0.0f;
float max_b = 0.0f;
printf("%s\n", label);
for (int i = 0; i < n; i++) {
if (fabsf(a[i] - b[i]) <= threshold) {
float diff = fabsf(a[i] - b[i]);
if (diff > max_diff) {
max_diff = diff;
float denom = fabsf(b[i]);
max_rel_error = (denom == 0.0f) ? 0.0f : diff / denom;
max_a = a[i];
max_b = b[i];
}
if (diff <= threshold) {
if (i < print_upto) { printf("OK "); }
} else {
if (i < print_upto) { printf("NOT OK "); }
Expand All @@ -17,13 +31,58 @@ int check_tensor(float *a, float *b, int n, const char* label, float threshold=1
}
// print the final result
if (ok) {
printf("TENSOR OK\n");
printf("TENSOR OK, max diff: %e, with rel error: %e (calculated=%f, ref=%f)\n",
max_diff, max_rel_error, max_a, max_b);
} else {
printf("TENSOR NOT OK\n");
printf("TENSOR NOT OK, max diff: %e, with rel error: %e (calculated=%f, ref=%f)\n",
max_diff, max_rel_error, max_a, max_b);
}
return ok;
}

// the same tensors as in the train file, but in float, which are used as reference
typedef struct {
float* wte; // (V, C)
float* wpe; // (maxT, C)
float* ln1w; // (L, C)
float* ln1b; // (L, C)
float* qkvw; // (L, 3*C, C)
float* qkvb; // (L, 3*C)
float* attprojw; // (L, C, C)
float* attprojb; // (L, C)
float* ln2w; // (L, C)
float* ln2b; // (L, C)
float* fcw; // (L, 4*C, C)
float* fcb; // (L, 4*C)
float* fcprojw; // (L, C, 4*C)
float* fcprojb; // (L, C)
float* lnfw; // (C)
float* lnfb; // (C)
} FloatParameterTensors;
static_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");

// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU
float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) {
// calculate the total number of parameters
size_t num_parameters = 0;
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
num_parameters += param_sizes[i];
}
// everything is float so number of bytes to allocate is a simple multiplication
float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float));
float** ptrs[] = {
&params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
&params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
&params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
};
float* params_memory_iterator = params_memory;
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
*(ptrs[i]) = params_memory_iterator;
params_memory_iterator += param_sizes[i];
}
return params_memory;
}

int main(int argc, char *argv[]) {

// set up the device
Expand All @@ -48,46 +107,46 @@ int main(int argc, char *argv[]) {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");

// int C = model.config.channels;
int V = model.config.vocab_size;
int maxT = model.config.max_seq_len;
// int L = model.config.num_layers;
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");
size_t V = model.config.vocab_size;
size_t maxT = model.config.max_seq_len;
size_t L = model.config.num_layers;
size_t C = model.config.channels;

// load additional information that we will use for debugging and error checking
FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
int state_header[256];
freadCheck(state_header, sizeof(int), 256, state_file);
if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); }
if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); }
if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(EXIT_FAILURE); }
if (state_header[1] != 1) { printf("Bad version in state file"); exit(EXIT_FAILURE); }
int B = state_header[2]; // batch size, e.g. 4
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
assert(0 <= T && T <= maxT);
printf("[State]\n");
printf("batch_size: %d\n", B);
printf("seq_len: %d\n", T);

ParameterTensors expected_grads; // will be read from file (from PyTorch)
ParameterTensors calculated_grads; // will be calculated by us
float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_elements, model.param_sizeof, 0);
float* calculated_grads_memory = malloc_and_point_parameters(&calculated_grads, model.param_elements, model.param_sizeof, 0);
float* converted_grads_memory = (float*)mallocCheck(model.num_parameters * sizeof(float));

// inputs and expected outputs, only used for error checking
// read reference information from the file saved from Python/PyTorch side
// 1) input x and y
int* x = (int*)mallocCheck(B * T * sizeof(int));
int* y = (int*)mallocCheck(B * T * sizeof(int));
float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));
float* expected_loss = (float*) mallocCheck(1 * sizeof(float));

// read reference information from Python
freadCheck(x, sizeof(int), B*T, state_file);
freadCheck(y, sizeof(int), B*T, state_file);
// 2) results of forward pass (logits and loss)
float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));
float* expected_loss = (float*) mallocCheck(1 * sizeof(float));
freadCheck(expected_logits, sizeof(float), B*T*V, state_file);
freadCheck(expected_loss, sizeof(float), 1, state_file);
// 3) results of backward pass (parameter gradients)
FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32
float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements);
freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);
fcloseCheck(state_file);

// this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads
void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes);
float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float));

// overall OK signal for the test
int allok = 1;

Expand All @@ -103,25 +162,32 @@ int main(int argc, char *argv[]) {
}
int logits_ok = 1;

// FP16 and lower require very high tolerances unfortunately
float accuracy_threshold = 1e-2;
// FP16 and lower require very high tolerances unfortunately. TODO look into more
float logit_accuracy_threshold = 1e-2f;
float loss_diff_threshold = 0.05f;
#if defined(ENABLE_BF16) || defined(ENABLE_F16)
accuracy_threshold = 23;
logit_accuracy_threshold = 15.0f;
#endif


float max_diff = 0.0f;
for (int i=0; i<B*T*V; i++) {
if(i < 3) {
if(i < 10) {
printf("%f %f\n", expected_logits[i], logits_cpu[i]);
}
if (fabsf(expected_logits[i] - logits_cpu[i]) >= accuracy_threshold) {
float diff = fabsf(expected_logits[i] - logits_cpu[i]);
max_diff = fmaxf(max_diff, diff);
if (diff >= logit_accuracy_threshold) {
printf("MISMATCH AT INDEX %d: ", i);
printf("%f %f\n", expected_logits[i],logits_cpu[i]);
logits_ok = 0;
break;
}
}
allok = allok && logits_ok;
if(!logits_ok) { printf("NOT "); }
printf("OK (LOGITS)\n");
printf("logit max diff: %f\n", max_diff);

// let's do 10 training iterations, following the pytorch code
float losses[10];
Expand All @@ -137,71 +203,63 @@ int main(int argc, char *argv[]) {
if (step == 0) {
// error checking at step 0 for reference activations


allok = allok && logits_ok;
free(logits_cpu_raw);
free(logits_cpu);

// compare the achieved loss
if (fabsf(model.mean_loss - *expected_loss) >= accuracy_threshold) {
if (fabsf(model.mean_loss - *expected_loss) >= loss_diff_threshold) {
printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss);
allok = 0;
} else {
printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss);
}

// and now compare the gradients on the parameters
// cudaMemcpy(calculated_grads.lnfw, model.grads.lnfw, C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.lnfb, model.grads.lnfb, C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcprojw, model.grads.fcprojw, L * C * 4*C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcprojb, model.grads.fcprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcw, model.grads.fcw, L * 4*C * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcb, model.grads.fcb, L * 4*C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln2w, model.grads.ln2w, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln2b, model.grads.ln2b, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.attprojw, model.grads.attprojw, L * C * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.attprojb, model.grads.attprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.qkvw, model.grads.qkvw, L * 3*C * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.qkvb, model.grads.qkvb, L * 3*C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln1w, model.grads.ln1w, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln1b, model.grads.ln1b, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.wte, model.grads.wte, V * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.wpe, model.grads.wpe, maxT * C * sizeof(float), cudaMemcpyDeviceToHost);
// check_tensor(calculated_grads.lnfb, expected_grads.lnfb, C, "lnfb");
// check_tensor(calculated_grads.lnfw, expected_grads.lnfw, C, "lnfw");
// check_tensor(calculated_grads.fcprojw, expected_grads.fcprojw, L * C * 4*C, "fcprojw");
// check_tensor(calculated_grads.fcprojb, expected_grads.fcprojb, L * C, "fcprojb");
// check_tensor(calculated_grads.fcw, expected_grads.fcw, L * 4*C * C, "fcw");
// check_tensor(calculated_grads.fcb, expected_grads.fcb, L * 4*C, "fcb");
// check_tensor(calculated_grads.ln2w, expected_grads.ln2w, L * C, "ln2w");
// check_tensor(calculated_grads.ln2b, expected_grads.ln2b, L * C, "ln2b");
// check_tensor(calculated_grads.attprojw, expected_grads.attprojw, L * C * C, "attprojw");
// check_tensor(calculated_grads.attprojb, expected_grads.attprojb, L * C, "attprojb");
// check_tensor(calculated_grads.qkvw, expected_grads.qkvw, L * 3*C * C, "qkvw");
// check_tensor(calculated_grads.qkvb, expected_grads.qkvb, L * 3*C, "qkvb");
// check_tensor(calculated_grads.ln1w, expected_grads.ln1w, L * C, "ln1w");
// check_tensor(calculated_grads.ln1b, expected_grads.ln1b, L * C, "ln1b");
// check_tensor(calculated_grads.wte, expected_grads.wte, V * C, "wte");
// check_tensor(calculated_grads.wpe, expected_grads.wpe, maxT * C, "wpe");

// get gradients from GPU and convert all non-FP32 gradients back to FP32 for check_tensor
cudaMemcpy(calculated_grads_memory, model.grads_memory, model.num_parameters * sizeof(floatX), cudaMemcpyDeviceToHost);
char* src_iterator = (char*)calculated_grads_memory;
float* dst_iterator = (float*)converted_grads_memory;
for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
// move the (mixed precision) grads from GPU to CPU
cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost);

// convert all gradients to float on the CPU
char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char*
float* dst_iterator = (float*)grads_memory_cpu_float; // float*
float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python
float* tensors1[NUM_PARAMETER_TENSORS];
float* tensors2[NUM_PARAMETER_TENSORS];
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
if (model.param_sizeof[i] == sizeof(float)) {
// float tensor => copy over directly
memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float));
} else {
assert(model.param_sizeof[i] == sizeof(floatX));
// low-precision tensor => convert to float
assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm
for (size_t j = 0; j < model.param_elements[i]; j++) {
dst_iterator[j] = ((floatX*)src_iterator)[j];
dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float
}
}
// for convenience record the position of comparison for reality vs. expectation
tensors1[i] = dst_iterator; // reality
tensors2[i] = exp_iterator; // expectation
// advance the iterators
src_iterator += model.param_elements[i] * model.param_sizeof[i];
dst_iterator += model.param_elements[i];
exp_iterator += model.param_elements[i];
}
// compare the gradients ona the parameters all at once
check_tensor(converted_grads_memory, expected_grads_memory, model.num_parameters, "grads");

// compare the gradients on the parameters all at once, in fp32
// I set the tolerances manually by inspecting the gradient differences for
// a few elements of each tensor. bf16 looks ok but not amazing here.
// It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure.
allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", 6e-1f);
allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", 1e-2f);
allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 9e-2); // hmm a bit high
allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 3e-2f);
allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", 3e-2f);
allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", 3e-2f);
allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", 9e-2f); // hmm a bit high
allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", 9e-2f); // hmm a bit high
allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", 9e-2f); // hmm a bit high
allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", 3e-2f);
allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", 0.1f); // hmm bit higher
allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", 3e-2f);
allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", 0.1f); // hmm bit higher
allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", 3e-2f);
allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", 0.12f); // hmm bit higher
allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", 3e-2f);
}

gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1);
Expand All @@ -227,7 +285,7 @@ int main(int argc, char *argv[]) {

// compare
for (int i = 0; i < 10; i++) {
if (fabsf(losses[i] - expected_losses[i]) >= accuracy_threshold) {
if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) {
printf("LOSS MISMATCH AT STEP %d: %f %f\n", i, losses[i], expected_losses[i]);
allok = 0;
} else {
Expand All @@ -241,11 +299,13 @@ int main(int argc, char *argv[]) {
// free everything
free(x);
free(y);
free(logits_cpu_raw);
free(logits_cpu);
free(expected_logits);
free(expected_loss);
free(expected_grads_memory);
free(calculated_grads_memory);
free(converted_grads_memory);
free(grads_memory_cpu);
free(grads_memory_cpu_float);
gpt2_free(&model);
cudaCheck(cudaFree(cublaslt_workspace));
cublasCheck(cublasDestroy(cublas_handle));
Expand Down
Loading

0 comments on commit d95b8d8

Please sign in to comment.