Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[METAL][CODEGEN] testcase for ramp codegen #14331

Merged
merged 1 commit into from
Mar 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import tvm
from tvm import te
import numpy as np
from tvm import topi
import unittest

from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
import tvm.script
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
Expand Down Expand Up @@ -76,6 +77,30 @@ def check_erf(dev, n, dtype):
check_erf(dev, 1, "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_ramp():
target = "metal"

@tvm.script.ir_module
class IRModule:
@T.prim_func
def main(A: T.Buffer((1, 2), "int32")):
T.func_attr({"global_symbol": "main"})
for i in T.thread_binding(1, thread="threadIdx.x"):
with T.block("block"):
tx = T.axis.spatial(1, i)
r = T.ramp(tx, 3, 2)
A[0, T.ramp(0, 1, 2)] = r

f = tvm.build(IRModule, target=target)
dev = tvm.metal()
a_nd = tvm.nd.empty((1, 2), "int32", dev)
f(a_nd)
assert tuple(a_nd.numpy()[0, :]) == (0, 3)


if __name__ == "__main__":
test_ramp()
test_metal_inf_nan()
test_metal_erf()