Skip to content

Commit

Permalink
[METAL][CODEGEN] testcase for ramp codegen (#14331)
Browse files Browse the repository at this point in the history
This PR adds a testcase that can be tested locally to cover metal ramp codegen
  • Loading branch information
tqchen authored Mar 19, 2023
1 parent 542274d commit 56d0e3b
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import tvm
from tvm import te
import numpy as np
from tvm import topi
import unittest

from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
import tvm.script
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
Expand Down Expand Up @@ -76,6 +77,30 @@ def check_erf(dev, n, dtype):
check_erf(dev, 1, "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_ramp():
target = "metal"

@tvm.script.ir_module
class IRModule:
@T.prim_func
def main(A: T.Buffer((1, 2), "int32")):
T.func_attr({"global_symbol": "main"})
for i in T.thread_binding(1, thread="threadIdx.x"):
with T.block("block"):
tx = T.axis.spatial(1, i)
r = T.ramp(tx, 3, 2)
A[0, T.ramp(0, 1, 2)] = r

f = tvm.build(IRModule, target=target)
dev = tvm.metal()
a_nd = tvm.nd.empty((1, 2), "int32", dev)
f(a_nd)
assert tuple(a_nd.numpy()[0, :]) == (0, 3)


if __name__ == "__main__":
test_ramp()
test_metal_inf_nan()
test_metal_erf()

0 comments on commit 56d0e3b

Please sign in to comment.