forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtune_gpt2cl.c
163 lines (137 loc) · 6.16 KB
/
tune_gpt2cl.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#define TESTING
#include "train_gpt2cl.c"
// Benchmark schedule used by do_run(): each configuration runs NUM_STEPS
// training iterations, and the first SKIP_STEPS of them are warmup that is
// excluded from the averaged timing.
enum { NUM_STEPS = 6, SKIP_STEPS = 3 };
// Benchmarks one training configuration of the OpenCL GPT-2 implementation.
//
// Builds the model from "gpt2_124M.bin", reads one batch of input tokens (x)
// and target tokens (y) from "gpt2_124M_debug_state.bin", initializes OpenCL
// for that batch shape, and then times NUM_STEPS iterations of
// forward / zero_grad / backward / update with CLOCK_MONOTONIC. The first
// SKIP_STEPS iterations are warmup and are excluded from the average. A fresh
// model is built on every call, presumably so that each configuration starts
// from the same checkpoint weights.
//
// NOTE(review): the MATMUL_* environment variables exported by main() are
// assumed to be consumed inside cl_init(); confirm in train_gpt2cl.c.
//
// Params:
//   time_taken - out: on success, the mean wall-clock time per timed step in
//                milliseconds. Left unmodified on failure.
// Returns:
//   0 on success. Returns 1 if the state file is missing, malformed or
//   truncated, or if buffer allocation fails. Otherwise returns the non-zero
//   value from cl_init(). Everything acquired here is released on every path.
int do_run(double *time_taken) {
    GPT2 model;
    GPT2_CL gcl;
    FILE *state_file = NULL;
    int state_header[256];
    int *x = NULL; // input tokens, B*T entries
    int *y = NULL; // target tokens, B*T entries
    size_t n_tokens = 0;
    int B = 0;
    int T = 0;
    int ret = 1;

    // build the GPT-2 model from a checkpoint
    gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");
    int C = model.config.channels;
    int Vp = model.config.padded_vocab_size;
    int maxT = model.config.max_seq_len;

    // load the reference batch (header + inputs + targets) used for the run
    state_file = fopen("gpt2_124M_debug_state.bin", "rb");
    if (state_file == NULL) {
        printf("Error opening state file\n");
        goto fail;
    }
    if (fread(state_header, sizeof(int), 256, state_file) != 256) {
        printf("Error reading state file header\n");
        goto fail;
    }
    if (state_header[0] != 20240327) {
        printf("Bad magic state file\n");
        goto fail;
    }
    if (state_header[1] != 2) {
        printf("Bad version in state file\n");
        printf("---> HINT: try to re-run `python train_gpt2.py`\n");
        goto fail;
    }
    B = state_header[2]; // batch size, e.g. 4
    T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
    // Reject nonsensical shapes before they reach the allocator or the model.
    if (B <= 0 || T <= 0 || T > maxT) {
        printf("Bad batch/sequence size in state file: B=%d T=%d (maxT=%d)\n", B, T, maxT);
        goto fail;
    }
    // size_t arithmetic: B*T in int could overflow; calloc also checks n*size.
    n_tokens = (size_t)B * (size_t)T;
    x = calloc(n_tokens, sizeof *x);
    y = calloc(n_tokens, sizeof *y);
    if (x == NULL || y == NULL) {
        printf("Error allocating token buffers\n");
        goto fail;
    }
    if (fread(x, sizeof *x, n_tokens, state_file) != n_tokens ||
        fread(y, sizeof *y, n_tokens, state_file) != n_tokens) {
        printf("Error reading tokens from state file\n");
        goto fail;
    }
    fclose(state_file);
    state_file = NULL;

    ret = cl_init(&gcl, B, T, C, Vp);
    if (ret != 0) {
        printf("error initializing opencl\n");
    } else {
        struct timespec start, end;
        double total_time = 0.0;
        for (int step = 0; step < NUM_STEPS; step++) {
            clock_gettime(CLOCK_MONOTONIC, &start);
            gpt2_forward(&gcl, &model, x, y, B, T);
            gpt2_zero_grad(&model);
            gpt2_backward(&gcl, &model);
            gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step + 1);
            clock_gettime(CLOCK_MONOTONIC, &end);
            double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
            // consider after skip steps for warmup
            if (step >= SKIP_STEPS) {
                total_time += time_elapsed_s;
            }
        }
        *time_taken = total_time * 1000 / (NUM_STEPS - SKIP_STEPS);
    }
    // Release host buffers, then the model, then OpenCL. cl_deinit is called
    // even when cl_init failed, presumably to release partially-created
    // OpenCL objects. NOTE(review): confirm cl_deinit tolerates that state.
    free(x);
    free(y);
    gpt2_free(&model);
    cl_deinit(&gcl);
    return ret;

fail:
    // Setup failed before cl_init, so gcl was never initialized. Release only
    // what has been acquired so far (free(NULL) is a no-op).
    if (state_file != NULL) {
        fclose(state_file);
    }
    free(x);
    free(y);
    gpt2_free(&model);
    return 1;
}
// Exports `value`, formatted as a decimal string, as environment variable
// `name`, overwriting any previous value. Returns 0 on success. On failure it
// prints a diagnostic to stderr and returns -1.
static int tune_set_env_int(const char *name, int value) {
    char str[16];
    int n = snprintf(str, sizeof(str), "%d", value);
    if (n < 0 || (size_t)n >= sizeof(str) || setenv(name, str, 1) != 0) {
        fprintf(stderr, "error: could not set %s=%d\n", name, value);
        return -1;
    }
    return 0;
}

// Exhaustive autotuner for the OpenCL matmul kernel parameters.
//
// Enumerates every combination of the candidate values below and exports each
// one through the MATMUL_* environment variables. Each combination is
// benchmarked with do_run(), and the fastest successful one is reported at the
// end. Runs for which do_run() returns non-zero are printed as "skipping" and
// ignored.
//
// Returns 0 when at least one configuration ran successfully. Returns 1 if
// none did, or if an environment variable could not be set (otherwise a stale
// value would be benchmarked under the wrong label).
int main(int argc, char *argv[]) {
    (void)argc; // no command-line options are used
    (void)argv;
    const int tile_size_lst[] = {4, 8, 12, 16, 24, 32, 48, 64};
    const int lmp_size_lst[] = {0, 1};
    const int vload_size_lst[] = {0, 4, 8, 16};
    const int do_preload_lst[] = {0, 1};
    const int use_mad_lst[] = {0, 1};
    const int len_tile_size_lst = sizeof(tile_size_lst) / sizeof(tile_size_lst[0]);
    const int len_lmp_size_lst = sizeof(lmp_size_lst) / sizeof(lmp_size_lst[0]);
    const int len_vload_size_lst = sizeof(vload_size_lst) / sizeof(vload_size_lst[0]);
    const int len_do_preload_lst = sizeof(do_preload_lst) / sizeof(do_preload_lst[0]);
    const int len_use_mad_lst = sizeof(use_mad_lst) / sizeof(use_mad_lst[0]);
    const int total_count = len_tile_size_lst * len_lmp_size_lst * len_vload_size_lst *
                            len_do_preload_lst * len_use_mad_lst;

    int have_best = 0; // becomes 1 once any configuration has run successfully
    double best_time_taken = 0.0;
    int best_tile_size = 0;
    int best_lmp_size = 0;
    int best_vload_size = 0;
    int best_do_preload = 0;
    int best_use_mad = 0;
    int count = 0;

    // Each variable is exported at the loop level where it changes, so the
    // environment always matches the combination being benchmarked.
    for (int ti = 0; ti < len_tile_size_lst; ti++) {
        int tile_size = tile_size_lst[ti];
        if (tune_set_env_int("MATMUL_TILE_SIZE", tile_size) != 0) return 1;
        for (int lmpi = 0; lmpi < len_lmp_size_lst; lmpi++) {
            int lmp_size = lmp_size_lst[lmpi];
            if (tune_set_env_int("MATMUL_LOCAL_MEM_PADDING_SIZE", lmp_size) != 0) return 1;
            for (int vli = 0; vli < len_vload_size_lst; vli++) {
                int vload_size = vload_size_lst[vli];
                if (tune_set_env_int("MATMUL_VLOAD_SIZE", vload_size) != 0) return 1;
                for (int dpli = 0; dpli < len_do_preload_lst; dpli++) {
                    int do_preload = do_preload_lst[dpli];
                    if (tune_set_env_int("MATMUL_DO_PRELOAD", do_preload) != 0) return 1;
                    for (int umadi = 0; umadi < len_use_mad_lst; umadi++) {
                        int use_mad = use_mad_lst[umadi];
                        if (tune_set_env_int("MATMUL_USE_MAD", use_mad) != 0) return 1;
                        printf("MATMUL_TILE_SIZE=%d MATMUL_LOCAL_MEM_PADDING_SIZE=%d MATMUL_VLOAD_SIZE=%d MATMUL_DO_PRELOAD=%d MATMUL_USE_MAD=%d\n",
                               tile_size, lmp_size, vload_size, do_preload, use_mad);
                        printf("--------------------------------------------- (%d/%d)\n", count + 1, total_count);
                        double time_taken = 0.0;
                        int ret = do_run(&time_taken);
                        if (ret == 0) {
                            printf("---------------------------------------------\n");
                            printf("time taken: %lf ms\n", time_taken);
                            if (!have_best || time_taken < best_time_taken) {
                                have_best = 1;
                                best_time_taken = time_taken;
                                best_tile_size = tile_size;
                                best_lmp_size = lmp_size;
                                best_vload_size = vload_size;
                                best_do_preload = do_preload;
                                best_use_mad = use_mad;
                            }
                        } else {
                            printf("skipping\n");
                        }
                        printf("---------------------------------------------\n");
                        count++;
                    }
                }
            }
        }
    }

    // Without this check an all-failed sweep would report a bogus "best".
    if (!have_best) {
        printf("\nno configuration completed successfully\n");
        return 1;
    }
    printf("\nbest time taken: %lf ms with combination\n", best_time_taken);
    printf("MATMUL_TILE_SIZE=%d MATMUL_LOCAL_MEM_PADDING_SIZE=%d MATMUL_VLOAD_SIZE=%d MATMUL_DO_PRELOAD=%d MATMUL_USE_MAD=%d\n",
           best_tile_size, best_lmp_size, best_vload_size, best_do_preload, best_use_mad);
    return 0;
}