[WIP] Working Grouped gemm with group ID #48

Draft
wants to merge 24 commits into main

Commits (24)
1825ef8
Cutlass grouped gemm files
ElizaWszola Dec 6, 2024
5fd48e5
runs, bad result
ElizaWszola Dec 9, 2024
d5942cf
A little closer to working
ElizaWszola Dec 10, 2024
c570c69
Working for identical sizes
ElizaWszola Dec 11, 2024
6ed63f2
Grouped gemm working
ElizaWszola Dec 17, 2024
e2b1fc0
Small cleanup
ElizaWszola Dec 17, 2024
dd163f5
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 8, 2025
acfd3ef
Benchmark grouped cutlass against bfloat16 torch.mm
ElizaWszola Jan 13, 2025
c6231b6
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 13, 2025
f1a5666
Start working on fused moe cutlass implementation
ElizaWszola Jan 17, 2025
6414e31
Working halfway
ElizaWszola Jan 20, 2025
67e2dd4
working mul test but the topk_weights are not yet included in kernel
ElizaWszola Jan 23, 2025
6523529
cleaned up cutlass moe test, fixes
ElizaWszola Jan 23, 2025
b302d98
benchmark fused
ElizaWszola Jan 23, 2025
342d1a4
pass input as one tensor with an array of offsets rather than a list …
ElizaWszola Jan 24, 2025
7549e3d
Using tensors rather than tensor lists works with test_cutlass test
ElizaWszola Jan 28, 2025
64c2a68
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 28, 2025
1ea7874
cleanup, add import
ElizaWszola Jan 28, 2025
d608164
working fused op
ElizaWszola Jan 29, 2025
286f6c8
benchmark, create strides directly on device, small name refactor
ElizaWszola Jan 29, 2025
b6867bb
works with cuda graphs
ElizaWszola Jan 31, 2025
df04bc0
move stride tensor creation outside c++ code, cleanup
ElizaWszola Jan 31, 2025
88c7134
cleanup benchmark
ElizaWszola Jan 31, 2025
02e1d4e
profile
ElizaWszola Feb 4, 2025
4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -299,7 +299,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
255 changes: 255 additions & 0 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -0,0 +1,255 @@
from typing import List, Tuple

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
fused_experts,
fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [16, 32, 64, 128, 256, 512]

PER_ACT_TOKEN_OPTS = [False] #[False, True]
PER_OUT_CH_OPTS = [False] #[False, True]
TOPKS = [2, 6]


def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def bench_run(results: List[benchmark.Measurement], model: str,
num_experts: int, topk: int, per_act_token: bool,
per_out_ch: bool, mkn: Tuple[int, int, int]):
label = "Quant Matmul"

sub_label = ("{}, num_experts={}, per_act_token={} per_out_ch={}, "
"MKN=({})".format(model, num_experts, per_act_token,
per_out_ch, mkn))

print(f"Testing: {sub_label}")

(m, k, n) = mkn

dtype = torch.half

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

a_q, a_scale = ops.scaled_fp8_quant(a)

w1_q = torch.empty((num_experts, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2_q = torch.empty((num_experts, k, n),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1),
device="cuda",
dtype=torch.float32)

for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)

score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a_scale: torch.Tensor, num_repeats: int):
for _ in range(num_repeats):
fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)

def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
m: int, n: int, k: int, num_experts: int,
num_repeats: int):
for _ in range(num_repeats):
cutlass_moe(a, a_scale, w1, w2, w1_scale, w2_scale, topk_weights,
topk_ids, m, n, k, num_experts)

def run_from_graph(a_q: torch.Tensor, a_scale: torch.Tensor,
w1_q: torch.Tensor, w2_q: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
m: int, n: int, k: int, e: int):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
topk_weights, topk_ids, m, n, k, e)

def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
graph.replay()
torch.cuda.synchronize()

stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
run_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
topk_weights, topk_ids, m, n, k, num_experts)
torch.cuda.synchronize()

min_run_time = 5
num_warmup = 5

globals = {
# Baseline params
"a": a,
"w1": w1,
"w2": w2,
"score": score,
"topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params
"a_q": a_q,
"a_scale": a_scale,
"w1_q": w1_q,
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"m": m,
"n": n,
"k": k,
"num_experts": num_experts,
# Cutlass cuda graph params
"graph": graph,
# Gen params
"topk_weights": topk_weights,
"topk_ids": topk_ids,
# Kernels
"run_triton_moe": run_triton_moe,
"run_cutlass_moe": run_cutlass_moe,
"replay_graph": replay_graph,
}

# Warmup
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
w1_scale, w2_scale, a_scale, num_warmup)

results.append(
benchmark.Timer(
stmt=
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, 1)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time))

# Warmup
run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
topk_ids, m, n, k, num_experts, num_warmup)

results.append(
benchmark.Timer(
stmt=
"run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts, 1)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="grouped_gemm_moe",
).blocked_autorange(min_run_time=min_run_time))

# Warmup
replay_graph(graph, num_warmup)

results.append(
benchmark.Timer(
stmt="replay_graph(graph, 1)",
globals=globals,
label=label,
sub_label=sub_label,
description="grouped_gemm_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))


def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")

results: List[benchmark.Measurement] = []

for model in args.models:
for layer in WEIGHT_SHAPES_MOE[model]:
num_experts = layer[0]
size_k = layer[1]
size_n = layer[2]

if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue

if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue

for per_act_token in PER_ACT_TOKEN_OPTS:
for per_out_ch in PER_OUT_CH_OPTS:
for topk in TOPKS:
                        for size_m in args.batch_sizes:
mkn = (size_m, size_k, size_n)
bench_run(results, model, num_experts, topk,
per_act_token, per_out_ch, mkn)

compare = benchmark.Compare(results)
compare.print()


if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(),
)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-act-token",
nargs="+",
type=int,
default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

args = parser.parse_args()
main(args)
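
For reference, the path the benchmark times as "grouped_gemm_moe" condenses to the sketch below. It is assembled from the benchmark's own setup code; the problem size is an arbitrary example (the deepseekv2-lite layer shape), and the cutlass_moe call mirrors run_cutlass_moe above.

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
                                                            fused_topk)

# Arbitrary small problem size for illustration.
m, k, n, num_experts, topk = 16, 2048, 1408, 64, 2
dtype = torch.half

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

# Quantize the activations once and each expert's weights separately,
# then transpose the weights as the grouped kernel expects.
a_q, a_scale = ops.scaled_fp8_quant(a)
w1_q = torch.empty((num_experts, 2 * n, k), device="cuda",
                   dtype=torch.float8_e4m3fn)
w2_q = torch.empty((num_experts, k, n), device="cuda",
                   dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
for e in range(num_experts):
    w1_q[e], w1_scale[e] = ops.scaled_fp8_quant(w1[e])
    w2_q[e], w2_scale[e] = ops.scaled_fp8_quant(w2[e])
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)

# Routing, then the fused CUTLASS grouped-GEMM MoE op exercised by this PR.
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
out = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                  topk_weights, topk_ids, m, n, k, num_experts)
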
16 changes: 16 additions & 0 deletions benchmarks/kernels/benchmark_shapes.py
@@ -73,3 +73,19 @@
[7168, 8192],
],
}

WEIGHT_SHAPES_MOE = {
"nm-testing/Mixtral-8x7B-Instruct-v0.1": [
[8, 4096, 28672],
[8, 14336, 4096],
],
"nm-testing/deepseekv2-lite": [
[64, 2048, 1408],
],
"ibm-granite/granite-3.0-1b-a400m": [
[32, 1024, 1024],
],
"ibm-granite/granite-3.0-3b-a800m": [
[40, 1024, 1536],
],
}
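
Each WEIGHT_SHAPES_MOE entry is read by the benchmark's main() as [num_experts, k, n]; a minimal sketch of that unpacking, with the batch size chosen as an arbitrary example:

from benchmark_shapes import WEIGHT_SHAPES_MOE

# How benchmark_grouped_gemm_cutlass.py consumes one entry.
num_experts, size_k, size_n = WEIGHT_SHAPES_MOE["nm-testing/deepseekv2-lite"][0]
size_m = 16                       # one of DEFAULT_BATCH_SIZES
mkn = (size_m, size_k, size_n)    # passed to bench_run as (m, k, n)
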
1 change: 1 addition & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -118,6 +118,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);

// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(