From 56d0e3b7affce768ae8c7b9d17743d3e8332308d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 19 Mar 2023 14:50:44 -0400 Subject: [PATCH] [METAL][CODEGEN] testcase for ramp codegen (#14331) This PR adds a testcase that can be tested locally to cover metal ramp codegen --- .../unittest/test_target_codegen_metal.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index 002cf3c69640..45588c69cf2a 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -17,11 +17,12 @@ import tvm from tvm import te import numpy as np -from tvm import topi -import unittest + from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 from tvm.contrib import nvcc import tvm.testing +import tvm.script +from tvm.script import tir as T tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -76,6 +77,30 @@ def check_erf(dev, n, dtype): check_erf(dev, 1, "float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_metal +def test_ramp(): + target = "metal" + + @tvm.script.ir_module + class IRModule: + @T.prim_func + def main(A: T.Buffer((1, 2), "int32")): + T.func_attr({"global_symbol": "main"}) + for i in T.thread_binding(1, thread="threadIdx.x"): + with T.block("block"): + tx = T.axis.spatial(1, i) + r = T.ramp(tx, 3, 2) + A[0, T.ramp(0, 1, 2)] = r + + f = tvm.build(IRModule, target=target) + dev = tvm.metal() + a_nd = tvm.nd.empty((1, 2), "int32", dev) + f(a_nd) + assert tuple(a_nd.numpy()[0, :]) == (0, 3) + + if __name__ == "__main__": + test_ramp() test_metal_inf_nan() test_metal_erf()