Skip to content

Commit

Permalink
metal : add TODOs for rest of ops
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 12, 2024
1 parent 964206a commit 63bab93
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1481,10 +1481,10 @@ static void ggml_metal_encode_node(
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&min length:sizeof(min) atIndex:2];
[encoder setBytes:&max length:sizeof(max) atIndex:3];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&min length:sizeof(min) atIndex:2];
[encoder setBytes:&max length:sizeof(max) atIndex:3];

const int64_t n = ggml_nelements(dst);

Expand Down Expand Up @@ -1656,6 +1656,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -1731,6 +1732,8 @@ static void ggml_metal_encode_node(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

// TODO: add ggml_metal_kargs struct
// TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
Expand All @@ -1747,6 +1750,7 @@ static void ggml_metal_encode_node(
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];

[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
Expand All @@ -1763,6 +1767,7 @@ static void ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
}

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand All @@ -1787,6 +1792,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
Expand Down Expand Up @@ -1857,6 +1863,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
Expand Down Expand Up @@ -2595,6 +2602,7 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("not implemented");
}

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
Expand Down Expand Up @@ -2664,6 +2672,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -2853,6 +2862,7 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error");
};

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -2893,6 +2903,7 @@ static void ggml_metal_encode_node(

const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -2927,6 +2938,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -2963,6 +2975,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
Expand All @@ -2984,6 +2997,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -3022,6 +3036,7 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error");
};

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand All @@ -3040,6 +3055,7 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down Expand Up @@ -3517,6 +3533,7 @@ static void ggml_metal_encode_node(
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;

// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
Expand Down

0 comments on commit 63bab93

Please sign in to comment.