
Commit

Merge branch 'karpathy:master' into master
vyom1611 authored Aug 14, 2024
2 parents fe68f8f + 4c84bc7 commit 7f319e8
Showing 25 changed files with 2,003 additions and 252 deletions.
10 changes: 7 additions & 3 deletions Makefile
@@ -11,9 +11,12 @@ REMOVE_FILES = rm -f
OUTPUT_FILE = -o $@
CUDA_OUTPUT_FILE = -o $@

# Default O3 CPU optimization level for NVCC (0 for fastest compile time)
FORCE_NVCC_O ?= 3

# NVCC flags
# -t=0 is short for --threads, 0 = number of CPUs on the machine
NVCC_FLAGS = -O3 -t=0 --use_fast_math -std=c++17
NVCC_FLAGS = --threads=0 -t=0 --use_fast_math -std=c++17 -O$(FORCE_NVCC_O)
NVCC_LDFLAGS = -lcublas -lcublasLt
NVCC_INCLUDES =
NVCC_LDLIBS =
@@ -45,8 +45,8 @@ endif

ifneq ($(CI),true) # if not in CI, then use the GPU query
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
ifneq ($(call file_exists_in_path, __nvcc_device_query),)
GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query)
ifneq ($(call file_exists_in_path, nvidia-smi),)
GPU_COMPUTE_CAPABILITY = $(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g')
GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
endif
endif
@@ -62,6 +65,7 @@ $(info ---------------------------------------------)

ifneq ($(OS), Windows_NT)
NVCC := $(shell which nvcc 2>/dev/null)
NVCC_LDFLAGS += -lnvidia-ml

# Function to test if the compiler accepts a given flag.
define check_and_add_flag
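Usage note (not part of the diff): because FORCE_NVCC_O is assigned with ?=, it can be overridden at invocation time, e.g. make FORCE_NVCC_O=0 for the fastest compile during development, while a plain make keeps the default -O3 host optimization. The nvidia-smi query returns the compute capability with a dot (for example 8.6 on an RTX 30-series GPU), and the sed call strips it to the dotless form (86) expected by the compute_$(GPU_COMPUTE_CAPABILITY)/sm_$(GPU_COMPUTE_CAPABILITY) arch flags, as in the dev/cuda/Makefile below.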
11 changes: 7 additions & 4 deletions dev/cuda/Makefile
@@ -9,15 +9,15 @@ ifeq ($(NVCC),)
endif

ifneq ($(CI),true) # if not in CI, then use the GPU query
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too
GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
endif
endif

# Compiler flags
ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=
CFLAGS = -O3 --use_fast_math
ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=
CFLAGS = -O3 --use_fast_math
else
CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)]
endif
@@ -30,7 +30,8 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-
$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@

# Build all targets
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute

all: $(TARGETS)
all_ptx: $(TARGETS:%=%.ptx)
all_sass: $(TARGETS:%=%.sass)
@@ -64,6 +65,8 @@ matmul_backward: matmul_backward.cu
adamw: adamw.cu
global_norm: global_norm.cu

permute: permute.cu

# NCCL communication kernels
nccl_all_reduce: nccl_all_reduce.cu
$(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce
1 change: 1 addition & 0 deletions dev/cuda/attention_backward.cu
@@ -1137,6 +1137,7 @@ int main(int argc, char **argv) {
free(dinp);
free(dpreatt);
free(datt);
free(h_dinp);
cudaCheck(cudaFree(d_inp));
cudaCheck(cudaFree(d_qkvr));
cudaCheck(cudaFree(d_preatt));
1 change: 1 addition & 0 deletions dev/cuda/attention_forward.cu
@@ -1377,6 +1377,7 @@ int main(int argc, char **argv) {
cudaCheck(cudaFree(d_preatt));
cudaCheck(cudaFree(d_att));
cudaCheck(cudaFree(d_inp));
cudaCheck(cudaFree(d_stats));
cublasDestroy(cublas_handle);

#ifdef ENABLE_CUDNN
1 change: 1 addition & 0 deletions dev/cuda/classifier_fused.cu
@@ -766,6 +766,7 @@ int main(int argc, char **argv) {
cudaCheck(cudaFree(d_logits));
cudaCheck(cudaFree(d_dlosses));
cudaCheck(cudaFree(d_targets));
cudaCheck(cudaFree(d_dlogits_no_pad));

return 0;
}
1 change: 1 addition & 0 deletions dev/cuda/nccl_all_reduce.cu
@@ -193,5 +193,6 @@ int main(int argc, char **argv) {

free(all_reduce_buffer_host);
cudaCheck(cudaFree(all_reduce_buffer));
cudaCheck(cudaFree(all_reduce_buffer_recv));
multi_gpu_config_free(&multi_gpu_config);
}
181 changes: 181 additions & 0 deletions dev/cuda/permute.cu
@@ -0,0 +1,181 @@
/*
Kernels to demonstrate permute operation.
Compile example:
nvcc -O3 permute.cu -o permute
The goal is to permute a 4D matrix from its original shape (dim1, dim2, dim3, dim4) to a new shape (dim4, dim3, dim1, dim2).
Before permutation, we need to understand how to access elements in a flattened (linear) form of the matrix.
Given:
dim1 = size of the 1st dimension
dim2 = size of the 2nd dimension
dim3 = size of the 3rd dimension
dim4 = size of the 4th dimension
For any element in a 4D matrix at position (i1, i2, i3, i4), where:
i1 is the index in dimension 1
i2 is the index in dimension 2
i3 is the index in dimension 3
i4 is the index in dimension 4
If calculating the indices i1, i2, i3, and i4 seems challenging, observe the pattern in the index calculations below; it may take some time to grasp at first, but with practice you will develop a mental model for it.
To calculate the indices, use the following formulas:
i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
i2 = (idx / (dim3 * dim4)) % dim2;
i3 = (idx / dim4) % dim3;
i4 = idx % dim4;
Pattern Explanation:
To find the index for any dimension, divide the thread ID (idx) by the product of all subsequent dimensions.
Then take the result modulo the size of the current dimension.
The linear index in a flattened 1D array is calculated as:
linear_idx = i1 * (dim2 * dim3 * dim4) + i2 * (dim3 * dim4) + i3 * dim4 + i4
This linear index uniquely identifies the position of the element in the 1D array.
To permute the matrix, we need to rearrange the indices according to the new shape.
In this case, we are permuting from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).
The new dimension ordering after permutation is as follows:
dim1 becomes the new 3rd dimension.
dim2 becomes the new 4th dimension.
dim3 becomes the new 2nd dimension.
dim4 becomes the new 1st dimension.
permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
Here's how this works:
i4 * (dim3 * dim1 * dim2): This accounts for how many complete dim3 * dim1 * dim2 blocks fit before the current i4 block.
i3 * (dim1 * dim2): This accounts for the offset within the current i4 block, specifying which i3 block we are in.
i1 * dim2: This accounts for the offset within the current i3 block, specifying which i1 block we are in.
i2: This gives the offset within the current i1 block.
Finally, we store the value at index idx of the original matrix at index permuted_idx of the permuted matrix.
--------------------------------------------------------------------------------------------------------------------------------------------------------
The same approach can be followed to permute matrices with any number of dimensions.
*/


#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include <cmath>

#include "common.h"

// CPU function to permute a 4D matrix
void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
int total_threads = dim1 * dim2 * dim3 * dim4;

for (int idx = 0; idx < total_threads; idx++) {
// Calculate the 4D indices from the linear index
int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
int i2 = (idx / (dim3 * dim4)) % dim2;
int i3 = (idx / dim4) % dim3;
int i4 = idx % dim4;

// Compute the new index for the permuted matrix
// Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
out_matrix[permuted_idx] = matrix[idx];
}
}

// CUDA kernel to permute a 4D matrix
__global__ void permute_kernel(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;

// Ensure index is within bounds
if (idx < dim1 * dim2 * dim3 * dim4) {
// Calculate the 4D indices from the linear index
int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
int i2 = (idx / (dim3 * dim4)) % dim2;
int i3 = (idx / dim4) % dim3;
int i4 = idx % dim4;

// Compute the new index for the permuted matrix
// Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
out_matrix[permuted_idx] = matrix[idx];
}
}


int main() {
int dim_1 = 24;
int dim_2 = 42;
int dim_3 = 20;
int dim_4 = 32;

// Set up the device
int deviceIdx = 0;
cudaSetDevice(deviceIdx);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, deviceIdx);
printf("Device %d: %s\n", deviceIdx, deviceProp.name);

// Allocate host memory
float* matrix = make_random_float(dim_1 * dim_2 * dim_3 * dim_4);
float* permuted_matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

// Initialize the matrix with random values

// Allocate device memory
float *d_matrix, *d_permuted_matrix;
cudaMalloc(&d_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
cudaMalloc(&d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

// Copy matrix from host to device
cudaMemcpy(d_matrix, matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyHostToDevice);

// Perform permutation on CPU
clock_t start = clock();
permute_cpu(matrix, permuted_matrix, dim_1, dim_2, dim_3, dim_4);
clock_t end = clock();
double elapsed_time_cpu = (double)(end - start) / CLOCKS_PER_SEC * 1000.0; // in ms, to match the GPU timing below

// Define block and grid sizes
dim3 blockSize(256);
int totalThreads = dim_1 * dim_2 * dim_3 * dim_4;
int gridSize = (totalThreads + blockSize.x - 1) / blockSize.x; // Compute grid size

// Launch CUDA kernel to perform permutation
permute_kernel<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
cudaDeviceSynchronize(); // Ensure kernel execution is complete

// Verify results
printf("Checking correctness...\n");
validate_result(d_permuted_matrix, permuted_matrix, "permuted_matrix", dim_1 * dim_2 * dim_3 * dim_4, 1e-5f);

printf("All results match.\n\n");
// benchmark kernel
int repeat_times = 1000;
float elapsed_time = benchmark_kernel(repeat_times, permute_kernel,
d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4
);
printf("time gpu %.4f ms\n", elapsed_time);
printf("time cpu %.4f ms\n", elapsed_time_cpu);

// Free allocated memory
free(matrix);
free(permuted_matrix);
cudaFree(d_matrix);
cudaFree(d_permuted_matrix);

return 0;
}
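To make the index arithmetic in the permute.cu header comment concrete, here is a small CPU-only C sketch (illustrative only, not part of this commit); the shapes and the spot-checked coordinates are arbitrary choices:

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    // tiny example shapes, chosen only for illustration
    int dim1 = 2, dim2 = 3, dim3 = 4, dim4 = 5;
    int n = dim1 * dim2 * dim3 * dim4;
    float* in = (float*)malloc(n * sizeof(float));
    float* out = (float*)malloc(n * sizeof(float));
    for (int i = 0; i < n; i++) { in[i] = (float)i; }

    for (int idx = 0; idx < n; idx++) {
        // recover the 4D coordinates (i1, i2, i3, i4) from the flat index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;
        // destination index in the permuted (dim4, dim3, dim1, dim2) layout
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out[permuted_idx] = in[idx];
    }

    // spot check one element: (i1, i2, i3, i4) = (1, 0, 2, 3)
    int src = 1 * (dim2 * dim3 * dim4) + 0 * (dim3 * dim4) + 2 * dim4 + 3;  // = 73
    int dst = 3 * (dim3 * dim1 * dim2) + 2 * (dim1 * dim2) + 1 * dim2 + 0;  // = 87
    assert(out[dst] == in[src]);
    printf("input index %d maps to output index %d\n", src, dst);

    free(in);
    free(out);
    return 0;
}

These are the same two formulas that permute_cpu and permute_kernel above apply per element; the GPU version simply assigns one thread per flat index instead of looping.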
1 change: 1 addition & 0 deletions dev/cuda/trimat_forward.cu
@@ -643,6 +643,7 @@ int main(int argc, char **argv) {
free(inp);
cudaCheck(cudaFree(d_out));
cudaCheck(cudaFree(d_inp));
cudaCheck(cudaFree(d_qkvr));
cublasDestroy(cublas_handle);

return 0;
2 changes: 2 additions & 0 deletions dev/data/README.md
@@ -6,3 +6,5 @@ The idea is that each dataset has a .py file here in the root of `dev/data`, and
- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it

And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these.

Note: we support "gpt-2" and "llama" (llama 3 in particular) models and the above scripts will tokenize gpt-2 by default.
40 changes: 25 additions & 15 deletions dev/data/data_common.py
@@ -23,28 +23,38 @@ def download_file(url: str, fname: str, chunk_size=1024):
bar.update(size)


def write_datafile(filename, toks):
HEADERS_INFO = {
"gpt-2": {
"magic": 20240520,
"version": 1,
"token_dtype": np.uint16,
},
"llama-3": {
"magic": 20240801,
"version": 7,
"token_dtype": np.uint32,
},
}

def write_datafile(filename, toks, model_desc="gpt-2"):
"""
Saves token data as a .bin file, for reading in C.
- First comes a header with 256 int32s
- The tokens follow, each as a uint16
- The tokens follow, each as uint16 (gpt-2) or uint32 (llama)
"""
assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
assert model_desc in ["gpt-2", "llama-3"], f"unknown model descriptor {model_desc}"
info = HEADERS_INFO[model_desc]
# construct the header
header = np.zeros(256, dtype=np.int32)
header[0] = 20240520 # magic
header[1] = 1 # version
header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
# construct the tokens numpy array, if not already
if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
# validate that no token exceeds a uint16
maxtok = 2**16
assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
toks_np = np.array(toks, dtype=np.uint16)
else:
toks_np = toks
header = np.zeros(256, dtype=np.int32) # header is always 256 int32 values
header[0] = info["magic"]
header[1] = info["version"]
header[2] = len(toks) # number of tokens after the 256*4 bytes of header
# construct the data (numpy array of tokens)
toks_np = np.array(toks, dtype=info["token_dtype"])
# write to file
print(f"writing {len(toks):,} tokens to {filename}")
num_bytes = (256 * 4) + (len(toks) * toks_np.itemsize)
print(f"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format")
with open(filename, "wb") as f:
f.write(header.tobytes())
f.write(toks_np.tobytes())
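As a companion to the header layout above, here is a minimal C sketch of how the format written by write_datafile could be parsed on the C side (an illustration of the format only, not the project's actual data loader; the function name inspect_token_file is hypothetical):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Reads the 256-int32 header of a .bin token file and reports its layout.
void inspect_token_file(const char* filename) {
    FILE* f = fopen(filename, "rb");
    if (f == NULL) { fprintf(stderr, "could not open %s\n", filename); exit(1); }
    int32_t header[256];
    if (fread(header, sizeof(int32_t), 256, f) != 256) { fprintf(stderr, "short read on %s\n", filename); exit(1); }
    int32_t magic = header[0];
    int32_t version = header[1];
    int32_t num_tokens = header[2];
    size_t bytes_per_token;
    if (magic == 20240520) {          // gpt-2 files store uint16 tokens
        bytes_per_token = sizeof(uint16_t);
    } else if (magic == 20240801) {   // llama-3 files store uint32 tokens
        bytes_per_token = sizeof(uint32_t);
    } else {
        fprintf(stderr, "unknown magic number %d\n", magic); exit(1);
    }
    // the tokens follow immediately after the 256 * 4 header bytes
    printf("%s: version %d, %d tokens, %zu bytes per token\n",
           filename, version, num_tokens, bytes_per_token);
    fclose(f);
}

The magic number selects the token width, so the same 1024-byte header works for both the gpt-2 and llama-3 formats.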
