llama: rwkv6: Use the new advanced batch splits
Signed-off-by: Molly Sophia <[email protected]>
MollySophia committed Aug 23, 2024
1 parent 9ffa40d commit c3564d8
Showing 3 changed files with 66 additions and 204 deletions.
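In brief: with the new batch-split machinery guaranteeing that each ubatch carries an equal number of contiguous tokens per sequence, `ggml_rwkv_wkv` no longer needs an explicit `state_seq` index tensor, and the dedicated `GGML_OP_RWKV_TOKEN_SHIFT` operator can be dropped entirely, since a token's sequence now follows from its position alone.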
10 changes: 1 addition & 9 deletions ggml/include/ggml.h
@@ -509,7 +509,6 @@ extern "C" {
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV,
-GGML_OP_RWKV_TOKEN_SHIFT,

GGML_OP_UNARY,

@@ -1857,14 +1856,7 @@ extern "C" {
struct ggml_tensor * r,
struct ggml_tensor * tf,
struct ggml_tensor * td,
-struct ggml_tensor * state,
-struct ggml_tensor * state_seq);
+struct ggml_tensor * state);

-GGML_API struct ggml_tensor * ggml_rwkv_token_shift(
-struct ggml_context * ctx,
-struct ggml_tensor * x_carry,
-struct ggml_tensor * x_norm,
-struct ggml_tensor * state_seq);

// custom operators

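For orientation, a minimal sketch of a call against the updated declaration; the tensor shapes follow the asserts in `ggml_rwkv_wkv` below, and all names and sizes here are illustrative, not part of the commit:

#include "ggml.h"

// Hypothetical wrapper; S = head size, H = head count, C = S*H.
// k: [S, 1, H, n_tokens]   v, r, td: [1, S, H, n_tokens]
// state: S*S*H elements per sequence, with ne[1] == n_seqs.
static struct ggml_tensor * build_wkv6(
        struct ggml_context * ctx,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * r,
        struct ggml_tensor  * tf,
        struct ggml_tensor  * td,
        struct ggml_tensor  * state) {
    // The old state_seq index tensor is gone: the op now derives each
    // token's sequence from its position within the ubatch.
    return ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
}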
156 changes: 11 additions & 145 deletions ggml/src/ggml.c
@@ -2817,7 +2817,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GET_REL_POS",
"ADD_REL_POS",
"RWKV_WKV",
"RWKV_TOKEN_SHIFT",

"UNARY",

@@ -2836,7 +2835,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2906,8 +2905,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"win_unpart(x)",
"get_rel_pos(x)",
"add_rel_pos(x)",
"rwkv_wkv(k, v, r, tf, td, s, sq)",
"rwkv_token_shift(xc, xn, sq)",
"rwkv_wkv(k, v, r, tf, td, s)",

"unary(x)",

@@ -2926,7 +2924,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -7494,39 +7492,36 @@ struct ggml_tensor * ggml_rwkv_wkv(
struct ggml_tensor * r,
struct ggml_tensor * tf,
struct ggml_tensor * td,
-struct ggml_tensor * state,
-struct ggml_tensor * state_seq) {
+struct ggml_tensor * state) {
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(r));
GGML_ASSERT(ggml_is_contiguous(tf));
GGML_ASSERT(ggml_is_contiguous(td));
GGML_ASSERT(ggml_is_contiguous(state));
-GGML_ASSERT(ggml_is_contiguous(state_seq));
-GGML_ASSERT(state_seq->type == GGML_TYPE_I32);

const int64_t S = k->ne[0];
const int64_t H = k->ne[2];
const int64_t n_tokens = k->ne[3];
-const int64_t n_kv = state_seq->ne[0];
+const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(k->ne[1] == 1);
GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
// TODO: RWKV v4 and v5
GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
+GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}

bool is_node = false;

-if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
+if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
GGML_ABORT("fatal error"); // TODO: implement backward
is_node = true;
}

// concat output and new_state
-const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
+const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

result->op = GGML_OP_RWKV_WKV;
@@ -7537,48 +7532,6 @@ struct ggml_tensor * ggml_rwkv_wkv(
result->src[3] = tf;
result->src[4] = td;
result->src[5] = state;
-result->src[6] = state_seq;

return result;
}
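The result tensor still packs the per-token output and the updated state into a single buffer; a caller would split it roughly as in the following sketch (assumed helper, not code from this commit):

#include "ggml.h"

// Sketch: split the packed result, using the shapes asserted above.
static void split_wkv_result(struct ggml_context * ctx,
                             struct ggml_tensor  * result,
                             int64_t S, int64_t H,
                             int64_t n_tokens, int64_t n_seqs,
                             struct ggml_tensor ** out,
                             struct ggml_tensor ** new_state) {
    // The first S*H*n_tokens floats are the per-token wkv output...
    *out = ggml_view_1d(ctx, result, S * H * n_tokens, 0);
    // ...the remaining S*S*H*n_seqs floats are the updated state.
    *new_state = ggml_view_1d(ctx, result, S * S * H * n_seqs,
                              S * H * n_tokens * ggml_element_size(result));
}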

-// ggml_rwkv_token_shift
-
-struct ggml_tensor * ggml_rwkv_token_shift(
-struct ggml_context * ctx,
-struct ggml_tensor * x_carry,
-struct ggml_tensor * x_norm,
-struct ggml_tensor * state_seq) {
-GGML_ASSERT(ggml_is_contiguous(x_carry));
-GGML_ASSERT(ggml_is_contiguous(x_norm));
-GGML_ASSERT(ggml_is_contiguous(state_seq));
-GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
-
-const int64_t n_embd = x_norm->ne[0];
-const int64_t n_kv = state_seq->ne[0];
-const int64_t n_tokens = state_seq->ne[1];
-{
-GGML_ASSERT(x_norm->ne[0] == n_embd);
-GGML_ASSERT(x_norm->ne[1] == n_tokens);
-GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
-}
-
-bool is_node = false;
-
-if (x_carry->grad || x_norm->grad || state_seq->grad) {
-GGML_ABORT("fatal error"); // TODO: implement backward
-is_node = true;
-}
-
-// concat output and new_state
-const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
-struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-result->op = GGML_OP_RWKV_TOKEN_SHIFT;
-result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-result->src[0] = x_carry;
-result->src[1] = x_norm;
-result->src[2] = state_seq;
-
-return result;
-}
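With `ggml_rwkv_token_shift` removed here, the token-shift logic presumably moves into the llama.cpp side of this commit (the third changed file, whose diff did not load on this page), where the contiguous, equal-length sequences provided by the new batch splits let it be expressed with ordinary view/concat ops.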
@@ -16418,7 +16371,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
const size_t T = dst->src[1]->ne[3];
const size_t C = dst->ne[0];
const size_t H = dst->src[1]->ne[2];
-const size_t n_kv = dst->src[6]->ne[0];
+const size_t n_seqs = dst->src[5]->ne[1];

float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;
@@ -16434,8 +16387,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
float * r = (float *) dst->src[2]->data;
float * time_faaaa = (float *) dst->src[3]->data;
float * time_decay = (float *) dst->src[4]->data;
-int32_t * seq_data = (int32_t *) dst->src[6]->data;
-memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));
+memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));

size_t t_stride = H * (C / H);

@@ -16448,7 +16400,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
// recursive through each token
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
-float * state_cur = state + (C / H) * C * seq_data[t * n_kv];
+float * state_cur = state + (C / H) * C * (t / (T / n_seqs));

for (size_t h = 0; h < H; h++) {
size_t h_offset = h * h_stride;
@@ -16480,15 +16432,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
}
}
}

-for (size_t t = 0; t < T; t++) {
-for (size_t kv = 1; kv < n_kv; kv++) {
-int64_t seq = seq_data[t * n_kv + kv];
-if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
-memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
-}
-}
-}
}
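The per-token `seq_data` lookup is thus replaced by pure index arithmetic: with `T` tokens split evenly across `n_seqs` contiguous sequences, token `t` belongs to sequence `t / (T / n_seqs)`, so each sequence owns exactly one state slot and the old cross-sequence state-copy loop at the end of the function becomes unnecessary. A toy check of the invariant, with made-up example values:

#include <stdio.h>

int main(void) {
    const size_t T = 8, n_seqs = 2;  // example values only
    for (size_t t = 0; t < T; t++) {
        // Token t belongs to sequence t / (T / n_seqs):
        // tokens 0..3 -> sequence 0, tokens 4..7 -> sequence 1,
        // matching the contiguous, equal-length layout the new
        // batch split guarantees for each ubatch.
        printf("token %zu -> seq %zu\n", t, t / (T / n_seqs));
    }
    return 0;
}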

static void ggml_compute_forward_rwkv_wkv(
@@ -16509,77 +16452,6 @@ static void ggml_compute_forward_rwkv_wkv(
}
}

-static void ggml_compute_forward_rwkv_token_shift_f32(
-const struct ggml_compute_params * params,
-struct ggml_tensor * dst) {
-const int64_t n_embd = dst->ne[0];
-const int64_t n_kv = dst->src[2]->ne[0];
-const int64_t n_tokens = dst->src[1]->ne[1];
-float * dst_data = (float *) dst->data;
-float * x_carry = (float *) dst->src[0]->data;
-float * x_norm = (float *) dst->src[1]->data;
-int32_t * sq_data = (int32_t *) dst->src[2]->data;
-
-if (params->ith != 0) {
-return;
-}
-
-int32_t seq_start = 0;
-int32_t seq_length = 0;
-
-for (int i1 = 0; i1 < n_kv; ++i1) {
-seq_start = -1;
-// assume that the tokens for each sequence are contiguous
-for (int i2 = 0; i2 < n_tokens; ++i2) {
-int32_t seq = sq_data[i2*n_kv];
-if (seq == i1 && seq_start < 0) {
-seq_start = i2;
-}
-
-if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
-seq_length = i2 - seq_start + (i2 == n_tokens - 1);
-break;
-}
-}
-
-if (seq_start >= 0) {
-int32_t seq = sq_data[seq_start*n_kv];
-memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
-memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
-}
-}
-
-for (int i3 = 0; i3 < n_kv; ++i3) {
-int32_t last_token_pos = 0;
-for (int i4 = 0; i4 < n_tokens; ++i4) {
-for (int i5 = 0; i5 < n_kv; ++i5) {
-if (sq_data[i4*n_kv + i5] == i3) {
-last_token_pos = i4;
-}
-}
-}
-memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
-}
-}
-
-static void ggml_compute_forward_rwkv_token_shift(
-const struct ggml_compute_params * params,
-struct ggml_tensor * dst) {
-
-const struct ggml_tensor * src0 = dst->src[0];
-
-switch (src0->type) {
-case GGML_TYPE_F32:
-{
-ggml_compute_forward_rwkv_token_shift_f32(params, dst);
-} break;
-default:
-{
-GGML_ABORT("fatal error");
-}
-}
-}

// ggml_compute_forward_map_unary

static void ggml_compute_forward_map_unary_f32(
@@ -17230,10 +17102,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rwkv_wkv(params, tensor);
} break;
-case GGML_OP_RWKV_TOKEN_SHIFT:
-{
-ggml_compute_forward_rwkv_token_shift(params, tensor);
-} break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
@@ -18305,7 +18173,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_GET_REL_POS:
case GGML_OP_ADD_REL_POS:
case GGML_OP_RWKV_WKV:
-case GGML_OP_RWKV_TOKEN_SHIFT:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
@@ -18876,7 +18743,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV:
-case GGML_OP_RWKV_TOKEN_SHIFT:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32: