cinn(py-dsl): add runtime module to python dsl (PaddlePaddle#58009)
Split out from the new feature: CINN Python DSL. See PaddlePaddle#56393 for the main PR and unit tests.

This PR only wraps the CINN IR runtime for the Python DSL (a usage sketch follows the change summary below).
6clc authored and jiahy0825 committed Oct 26, 2023
1 parent 7c80e33 commit d3659e5
Showing 9 changed files with 480 additions and 25 deletions.
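For orientation, here is a rough end-to-end sketch of what this commit enables, pieced together from the diffs below. The kernel body mirrors the new tests; the import path of compile(), its keyword arguments, and the host target are illustrative assumptions, not APIs confirmed by this commit alone.

    import numpy as np

    from cinn import common, ir, to_cinn_llir
    from cinn.compiler.compiler import compile as cinn_compile  # assumed import path
    from cinn.runtime.data_array import DataArray


    @to_cinn_llir
    def scale(X: DataArray((128, 128)), A: DataArray((128, 128))):
        for i in range(128):
            for j in range(128):
                with ir.ScheduleBlockContext("A"):
                    i1, j1 = ir.AxisMap("SS", [i, j])
                    A[i1, j1] = X[i1, j1] * 2.0


    target = common.DefaultHostTarget()
    # compile() now returns a cinn.runtime.Module (see compiler.py below);
    # target and arg_names are forwarded through **kwargs.
    rt_fn = cinn_compile(scale, target=target, arg_names=["X", "A"])

    x = DataArray.from_numpy(np.random.rand(128, 128).astype("float32"), target)
    a = DataArray.from_numpy(np.zeros((128, 128), dtype="float32"), target)
    rt_fn(x, a)  # Module.__call__ packs each argument into a cinn_pod_value_t by name
    np.testing.assert_allclose(a.to_numpy(), x.to_numpy() * 2.0)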
14 changes: 14 additions & 0 deletions python/cinn/compiler/compiler.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cinn

from ..runtime import CinnLowerLevelIrJit
from .compute_code_generator import ComputeCodeGenerator
@@ -31,6 +32,13 @@ def ast_to_llir(fn, inputs_signature):
return llir_schedule_generator.parse()


def llir_to_runtime_module(llir_func, target, function_name, arg_names):
cinn_builder = cinn.lang.Module.Builder(function_name, target)
cinn_builder.add_function(llir_func)
llir_module = cinn_builder.build()
return cinn.runtime.Module(llir_module, target, function_name, arg_names)


def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
if isinstance(fn, CinnLowerLevelIrJit):
llir_func = ast_to_llir(fn, jit_inputs_signature)
@@ -39,3 +47,9 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):

if just_convert:
return llir_func

rt_module = llir_to_runtime_module(
llir_func, kwargs["target"], fn.__name__, kwargs["arg_names"]
)

return rt_module
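A hedged call-site sketch for the updated compile() (jit_fn, x_array, and y_array are placeholders; target and arg_names are read from **kwargs as shown above):

    # Return only the lowered LLIR function, skipping runtime-module creation.
    llir_func = compile(jit_fn, just_convert=True, target=target, arg_names=["X", "Y"])

    # Default path: build and return a callable cinn.runtime.Module.
    rt_module = compile(jit_fn, target=target, arg_names=["X", "Y"])
    rt_module(x_array, y_array)  # executes the compiled kernel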
3 changes: 2 additions & 1 deletion python/cinn/runtime/__init__.py
@@ -68,5 +68,6 @@
)

from .cinn_jit import CinnLowerLevelIrJit
from .module import Module

__all__ = ["CinnLowerLevelIrJit"]
__all__ = ["CinnLowerLevelIrJit", "Module"]
55 changes: 31 additions & 24 deletions python/cinn/runtime/data_array.py
@@ -36,32 +36,39 @@ def to_numpy(self):
"""
Convert DataArray to numpy array
"""
-        cinn_dtype_to_np_dtype = {
-            BFloat16(): "uint16",
-            BFloat16(): "bfloat16",
-            Float16(): "float16",
-            Float(32): "float32",
-            Float(64): "float64",
-            Int(8): "int8",
-            Int(16): "int16",
-            Int(32): "int32",
-            Int(64): "int64",
-            UInt(8): "uint8",
-            # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
-            # "UInt(16): uint16"
-            UInt(32): "uint32",
-            UInt(64): "uint64",
-            Bool(): "bool",
-        }
-        for cinn_dtype, np_dtype in cinn_dtype_to_np_dtype.items():
-            if isinstance(self.dtype, cinn_dtype):
-                np_arr = np.empty(self.shape, np_dtype)
-                assert np_arr.flags["C_CONTIGUOUS"]
-                self.data.copy_to(np_arr)
-                return np_arr
-        raise TypeError(f"no support {self._dtype} in CINN")
+        np_dtype = "unk"
+        if self.dtype.is_bfloat16():
+            # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
+            np_dtype = "uint16"
+        elif self.dtype.is_float16():
+            np_dtype = "float16"
+        elif self.dtype.is_float(32, common.Type.specific_type_t.UNK):
+            np_dtype = "float32"
+        elif self.dtype.is_float(64, common.Type.specific_type_t.UNK):
+            np_dtype = "float64"
+        elif self.dtype.is_int(8):
+            np_dtype = "int8"
+        elif self.dtype.is_int(16):
+            np_dtype = "int16"
+        elif self.dtype.is_int(32):
+            np_dtype = "int32"
+        elif self.dtype.is_int(64):
+            np_dtype = "int64"
+        elif self.dtype.is_uint(8):
+            np_dtype = "uint8"
+        elif self.dtype.is_uint(32):
+            np_dtype = "uint32"
+        elif self.dtype.is_uint(64):
+            np_dtype = "uint64"
+        elif self.dtype.is_bool():
+            np_dtype = "bool"
+        else:
+            raise TypeError(f"no support {self.dtype} in CINN")
+
+        np_arr = np.empty(self.shape, np_dtype)
+        assert np_arr.flags["C_CONTIGUOUS"]
+        self.data.copy_to(np_arr)
+        return np_arr

@staticmethod
def from_numpy(np_array, target=common.DefaultHostTarget()):
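A small round-trip sketch of the mapping above (from_numpy's body is truncated in this hunk; its signature shows target defaulting to common.DefaultHostTarget()):

    import numpy as np

    from cinn import common
    from cinn.runtime.data_array import DataArray

    target = common.DefaultHostTarget()
    x_np = np.arange(6, dtype="float32").reshape(2, 3)
    x = DataArray.from_numpy(x_np, target)

    # to_numpy() picks the numpy dtype from the CINN dtype, allocates a
    # C-contiguous buffer, and copies the array's data into it; bfloat16
    # data comes back as uint16, matching Paddle's convention.
    y_np = x.to_numpy()
    assert y_np.dtype == np.float32
    assert np.array_equal(x_np, y_np)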
37 changes: 37 additions & 0 deletions python/cinn/runtime/module.py
@@ -0,0 +1,37 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cinn
from cinn import framework
from cinn.backends import Compiler


class Module:
def __init__(self, llir_module, target, fn_name, arg_names):
self.arg_names = arg_names
self.fn_name = fn_name
self.compiler = Compiler.create(target)
self.compiler.build(llir_module)
self._instruction = framework.Instruction(
target, None, [], arg_names, fn_name
)

def __call__(self, *args):
name2pod = {}
for i, name in enumerate(self.arg_names):
if isinstance(args[i], cinn.runtime.data_array.DataArray):
name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i].data)
else:
name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i])

self._instruction.run(self.compiler, self.fn_name, name2pod)
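A construction sketch for the class above (llir_module, x_array, and y_array are placeholders; normally compile() builds the Module for you, as in compiler.py earlier in this commit):

    # llir_module would come from cinn.lang.Module.Builder(...).build().
    m = Module(llir_module, target, fn_name="fn", arg_names=["X", "Y"])

    # Positional args are zipped with arg_names; DataArray args are unwrapped
    # via .data, anything else is handed to cinn_pod_value_t directly.
    m(x_array, y_array)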
68 changes: 68 additions & 0 deletions test/cinn/ir/test_llir_schedule_bind.py
@@ -0,0 +1,68 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test.cinn.utils.testing import assert_llir_equal

from cinn import ir, to_cinn_llir
from cinn.runtime.data_array import DataArray
from cinn.schedule import IRSchedule as sch


def test_bind_reduce():
@to_cinn_llir
def reduce_sum(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))):
for i1 in range(1):
for j1 in range(4):
for k1 in range(256):
with ir.ScheduleBlockContext("init") as init:
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
B[vi, vj, vk] = 0.0
for l1 in range(512):
with ir.ScheduleBlockContext("B"):
sch.bind(i1, "blockIdx.x")
sch.bind(j1, "threadIdx.y")
sch.bind(k1, "threadIdx.x")
vi1, vj1, vk1, vl1 = ir.AxisMap(
"SSSR", [i1, j1, k1, l1]
)
B[vi1, vj1, vk1] = (
B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1]
)

@to_cinn_llir
def reduce_sum_expected(
A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))
):
for i1 in range(1):
for j1 in range(4):
for k1 in range(256):
with ir.ScheduleBlockContext("init") as init:
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
B[vi, vj, vk] = 0.0
for l1 in range(512):
with ir.ScheduleBlockContext("B"):
vi1, vj1, vk1, vl1 = ir.AxisMap(
"SSSR", [i1, j1, k1, l1]
)
B[vi1, vj1, vk1] = (
B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1]
)
sch.bind(init.i1, "blockIdx.x")
sch.bind(init.j1, "threadIdx.y")
sch.bind(init.k1, "threadIdx.x")

assert_llir_equal(reduce_sum, reduce_sum_expected)


if __name__ == "__main__":
test_bind_reduce()
99 changes: 99 additions & 0 deletions test/cinn/ir/test_llir_schedule_for_kind.py
@@ -0,0 +1,99 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from test.cinn.utils.testing import assert_llir_equal

from cinn import ir, to_cinn_llir
from cinn.runtime.data_array import DataArray
from cinn.schedule import IRSchedule as sch


# The current Python DSL cannot express a parallel `for`;
# this test only checks that it can be converted correctly
def test_elementwise_parallel():
@to_cinn_llir
def elementwise_add(
X: DataArray((128, 128)),
Y: DataArray((128, 128)),
A: DataArray((128, 128)),
):
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("A") as A_block:
i1, j1 = ir.AxisMap("SS", [i, j])
A[i1, j1] = X[i1, j1] * 2.0
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("Y"):
i1, j1 = ir.AxisMap("SS", [i, j])
Y[i1, j1] = A[i1, j1] + 2.0
sch.parallel(A_block.i)

assert_llir_equal(elementwise_add, elementwise_add)


# The current Python DSL cannot express a vectorized `for`;
# this test only checks that it can be converted correctly
def test_elementwise_vectorize():
@to_cinn_llir
def elementwise_add(
X: DataArray((128, 128)),
Y: DataArray((128, 128)),
A: DataArray((128, 128)),
):
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("A") as A_block:
i1, j1 = ir.AxisMap("SS", [i, j])
A[i1, j1] = X[i1, j1] * 2.0
for i in range(128):
for j0 in range(32):
for j1 in range(4):
with ir.ScheduleBlockContext("Y") as Y_block:
i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1])
Y[i1, j1] = A[i1, j1] + 2.0
sch.vectorize(Y_block.j1, 1)

assert_llir_equal(elementwise_add, elementwise_add)


# The current Python DSL cannot express an unrolled `for`;
# this test only checks that it can be converted correctly
def test_elementwise_unroll():
@to_cinn_llir
def elementwise_add(
X: DataArray((128, 128)),
Y: DataArray((128, 128)),
A: DataArray((128, 128)),
):
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("A") as A_block:
i1, j1 = ir.AxisMap("SS", [i, j])
A[i1, j1] = X[i1, j1] * 2.0
for i in range(128):
for j0 in range(32):
for j1 in range(4):
with ir.ScheduleBlockContext("Y") as Y_block:
i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1])
Y[i1, j1] = A[i1, j1] + 2.0
sch.unroll(Y_block.j1)

assert_llir_equal(elementwise_add, elementwise_add)


if __name__ == "__main__":
test_elementwise_parallel()
test_elementwise_vectorize()
test_elementwise_unroll()
57 changes: 57 additions & 0 deletions test/cinn/ir/test_llir_schedule_rfactor.py
@@ -0,0 +1,57 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from cinn import ir, to_cinn_llir
from cinn.runtime.data_array import DataArray
from cinn.schedule import IRSchedule as sch


def test_matmul():
@to_cinn_llir
def matmul(
A: DataArray((128, 128)),
B: DataArray((128, 128)),
C: DataArray((128, 128)),
):
for i0 in range(128):
for i1 in range(128):
with ir.ScheduleBlockContext("init"):
vi, vj = ir.AxisMap("SS", [i0, i1])
C[vi, vj] = 0.0
for i2_outer in range(4):
for i2_inner_outer in range(8):
for i2_inner_inner in range(4):
with ir.ScheduleBlockContext(
"compute"
) as Compute_block:
vi, vj, vk = ir.AxisMap(
"SSR",
[
i0,
i1,
i2_outer * 32
+ i2_inner_outer * 4
+ i2_inner_inner,
],
)
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
sch.rfactor(Compute_block.i2_inner_inner, 0)

    # TODO(6clc): rfactor schedule raises the error: iter_value not support complex reduce bindings
# assert_llir_equal(matmul, matmul)


if __name__ == "__main__":
test_matmul()
(The remaining 2 changed files are not shown.)
