Skip to content

Commit

Permalink
Use pytest' tmp_path in test_irsource.py (#5145)
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Nov 14, 2024
1 parent be81f0a commit 5ebd1e5
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions python/test/unit/tools/test_irsource.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import tempfile
import pathlib
import triton
from triton.compiler import IRSource
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()


def test_mlir_attribute_parsing() -> None:
def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None:
'''
Tests that MLIR attributes are parsed correctly from input ttir/ttgir.
Expand Down Expand Up @@ -37,21 +37,20 @@ def test_mlir_attribute_parsing() -> None:
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir)
f.flush()
context = ir.context()
src = IRSource(f.name, context)
temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir"
temp_file.write_text(sample_ttgir)
context = ir.context()
src = IRSource(str(temp_file), context)

# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"
# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"

# check num warps
assert src.parse_options()['num_warps'] == 8
# check num warps
assert src.parse_options()['num_warps'] == 8

sample_ttgir_vector_add = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
Expand Down Expand Up @@ -83,11 +82,10 @@ def test_mlir_attribute_parsing() -> None:
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir_vector_add)
f.flush()
context = ir.context()
src = IRSource(f.name, context)
temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir"
temp_file.write_text(sample_ttgir_vector_add)
context = ir.context()
src = IRSource(str(temp_file), context)

# now test compilation
triton.compile(f.name, target=target)
# now test compilation
triton.compile(str(temp_file), target=target)

0 comments on commit 5ebd1e5

Please sign in to comment.