Skip to content

Commit

Permalink
[inductor] Use int64_t as index type for all platforms 4 (pytorch#133892
Browse files Browse the repository at this point in the history
)

It is parallel PR to pytorch#133819 , and it is append change for @jansel 's comments.
1. For `torch/_inductor/codegen/cpp_wrapper_cpu.py`, revert to origin code to append LL on MacOS and Windows: pytorch@bdc14ad
2. For `torch/_inductor/codegen/cpp_utils.py`, append LL on MacOS and Windows for large constants. And fix its UTs: pytorch@3a56b76

------------------------------
Another solution for pytorch#133615, use `int64_t` as the index type for all platforms.

### Development notes:
The mentioned PR (pytorch#133615) fixes the index type not matching the `parse_arg` argument types. As reviewed with @jansel, Jason thinks we need to unify `INDEX_TYPE` across all platforms.
The current code is cumbersome:
```python
INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long"
```

So, I made some attempts to unify `INDEX_TYPE` as either `long` or `int64_t`.
For use `long` as index type: pytorch#133768
For use `int64_t` as index type: pytorch#133782

After that, we discussed which type to select as the final solution.
![image](https://github.com/user-attachments/assets/b23fa577-2d40-4bd6-b934-fb7994fe0bb0)

The `long` type has different definitions and sizes on different OSs and with different compilers. So, @jansel made the decision that we need to select `int64_t` for all platforms, and I continued my work based on pytorch#133782.

As pytorch#133782 still has two issues:
1. std::min/std::max could not match function overloads by argument types. It was fixed and validated in PR: pytorch#133812
2. CUDA TestMemoryPlanning::test_cpp_wrapper issue caused by a wrong index type. It is fixed in this PR.

So, we made final solution in this PR.

### Changes:
**1. Use `int64_t` type as index type for all OSs: `Windows`, `Linux` and `MacOS`.**
**2. Use static_cast<int64_t>(`constant`) to convert constant to `div_floor_integer` with args type(`int64_t`).**
**3. Update `parse_arg` function signature to `int64_t`, which follow the index type.**
**4. Append double L(`LL`) to constants on Windows and MacOS, because their int64_t is long long.**
**5. Fix `std::min/std::max` type mismatch by static_cast to `INDEX_TYPE`.**
**6. Fix UTs, containing: cuda `TestMemoryPlanning::test_cpp_wrapper`, and `test_indexing.py`.**

Pull Request resolved: pytorch#133892
Approved by: https://github.com/jansel
  • Loading branch information
xuhancn authored and pytorchmergebot committed Aug 20, 2024
1 parent 3caf3ba commit fbf3fc2
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 17 deletions.
38 changes: 30 additions & 8 deletions test/inductor/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import os
import sys
import unittest

import sympy
Expand Down Expand Up @@ -262,14 +263,19 @@ def test_print_pow(self):
cpu_cases = common_cases + [
(
sympy.Pow(s1 + s2, 2),
lambda c, L: "static_cast<long>((bar + foo)*(bar + foo))",
lambda c, L: "static_cast<int64_t>((bar + foo)*(bar + foo))",
)
]
for expr, result in gpu_cases:
self.assertEqual(texpr(expr), result(1, ""))
self.assertEqual(pexpr(expr), result(1, ""))
for expr, result in cpu_cases:
self.assertEqual(cexpr(expr), result(1.0, "L")) # 1.0 for FP div
self.assertEqual(
cexpr(expr),
result(1.0, "LL")
if sys.platform in ["darwin", "win32"]
else result(1.0, "L"),
) # 1.0 for FP div

def test_print_floor(self):
for integer in [True, False]:
Expand All @@ -278,7 +284,7 @@ def test_print_floor(self):
if integer:
self.assertEqual(pexpr(expr), "math.floor((1/2)*s1)")
self.assertEqual(
cexpr(expr), "static_cast<long>(std::floor((1.0/2.0)*s1))"
cexpr(expr), "static_cast<int64_t>(std::floor((1.0/2.0)*s1))"
)
else:
self.assertExpectedInline(pexpr(expr), """math.floor((1/2)*s1)""")
Expand All @@ -295,7 +301,7 @@ def test_print_ceil(self):
if integer:
self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""")
self.assertExpectedInline(
cexpr(expr), """static_cast<long>(std::ceil((1.0/2.0)*s1))"""
cexpr(expr), """static_cast<int64_t>(std::ceil((1.0/2.0)*s1))"""
)
else:
self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""")
Expand Down Expand Up @@ -325,13 +331,19 @@ def test_print_floor_div(self):
s2 = sympy.Symbol("s2", integer=True)
expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(s1 // s2)")
self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)")
self.assertEqual(
cexpr(expr),
"c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",
)

s1 = sympy.Symbol("s1", integer=True)
s2 = sympy.S(-1)
expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(-1)*s1")
self.assertEqual(cexpr(expr), "(-1L)*s1")
self.assertEqual(cexpr(expr), "(-1LL)*s1") if sys.platform in [
"darwin",
"win32",
] else "(-1L)*s1"

def test_print_Min_Max(self):
cases = (
Expand All @@ -344,14 +356,24 @@ def test_print_Min_Max(self):
self.assertEqual(
texpr(expr), f"((-2) * ((-2) {cmp}= (x)) + (x) * ((x) {cmp} (-2)))"
)
self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)")
self.assertEqual(
cexpr(expr),
f"std::{s}(static_cast<int64_t>(-2LL), static_cast<int64_t>(x))"
if sys.platform in ["darwin", "win32"]
else f"std::{s}(static_cast<int64_t>(-2L), static_cast<int64_t>(x))",
)

expr = f(x, 2 * x, 3 * x)
self.assertEqual(
texpr(expr),
f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))", # noqa: B950 line too long
)
self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})")
self.assertEqual(
cexpr(expr),
f"std::{s}({{x, 2LL*x, 3LL*x}})"
if sys.platform in ["darwin", "win32"]
else f"std::{s}({{x, 2L*x, 3L*x}})",
)


instantiate_parametrized_tests(ExprPrinterTests)
Expand Down
4 changes: 2 additions & 2 deletions test/inductor/test_memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def test_cpp_wrapper(self):
result, code = run_and_get_cpp_code(compiled, *args)

FileCheck().check(
"pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast<long>(s0*s0)))), }, {1L, }"
"pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast<int64_t>(s0*s0)))), }, {1L, }"
).check_next(
"auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});"
).check(
"auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast<long>(s0*s0))),"
"auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast<int64_t>(s0*s0))),"
).run(
code
)
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2227,7 +2227,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
}
template <> inline long parse_arg<long>(PyObject* args, size_t n) {
template <> inline int64_t parse_arg<int64_t>(PyObject* args, size_t n) {
auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
if(unlikely(result == -1 && PyErr_Occurred()))
throw std::runtime_error("expected int arg");
Expand Down
14 changes: 8 additions & 6 deletions torch/_inductor/codegen/cpp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

_IS_WINDOWS = sys.platform == "win32"

INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long"
INDEX_TYPE = "int64_t"

GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])

Expand Down Expand Up @@ -222,7 +222,9 @@ def depends_on(self, itervar: sympy.Symbol):

class CppPrinter(ExprPrinter):
def _print_Integer(self, expr):
return f"{int(expr)}LL" if _IS_WINDOWS else f"{int(expr)}L"
return (
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
)

def _print_Where(self, expr):
c = self.paren(self.doprint(expr.args[0]))
Expand All @@ -236,7 +238,7 @@ def _print_ModularIndexing(self, expr):
if div != 1:
div = self.paren(self.doprint(div))
if expr.is_integer:
x = f"c10::div_floor_integer({x}, {div})"
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
else:
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
mod = self.paren(self.doprint(mod))
Expand All @@ -247,7 +249,7 @@ def _print_FloorDiv(self, expr):
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
if expr.is_integer:
return f"c10::div_floor_integer({x}, {div})"
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

def _print_floor(self, expr):
Expand Down Expand Up @@ -345,7 +347,7 @@ def _print_CeilToInt(self, expr):
def _print_Min(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::min({args[0]}, {args[1]})"
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
Expand All @@ -354,7 +356,7 @@ def _print_Min(self, expr):
def _print_Max(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::max({args[0]}, {args[1]})"
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
Expand Down

0 comments on commit fbf3fc2

Please sign in to comment.