forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cinn(py-dsl): add runtime module to python dsl (PaddlePaddle#58009)
拆分新特性:CINN Python DSL, 主PR和单测见:PaddlePaddle#56393 此PR只负责 给python dsl封装cinn ir的Runtime
- Loading branch information
Showing
9 changed files
with
480 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import cinn | ||
from cinn import framework | ||
from cinn.backends import Compiler | ||
|
||
|
||
class Module: | ||
def __init__(self, llir_module, target, fn_name, arg_names): | ||
self.arg_names = arg_names | ||
self.fn_name = fn_name | ||
self.compiler = Compiler.create(target) | ||
self.compiler.build(llir_module) | ||
self._instruction = framework.Instruction( | ||
target, None, [], arg_names, fn_name | ||
) | ||
|
||
def __call__(self, *args): | ||
name2pod = {} | ||
for i, name in enumerate(self.arg_names): | ||
if isinstance(args[i], cinn.runtime.data_array.DataArray): | ||
name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i].data) | ||
else: | ||
name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i]) | ||
|
||
self._instruction.run(self.compiler, self.fn_name, name2pod) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from test.cinn.utils.testing import assert_llir_equal | ||
|
||
from cinn import ir, to_cinn_llir | ||
from cinn.runtime.data_array import DataArray | ||
from cinn.schedule import IRSchedule as sch | ||
|
||
|
||
def test_bind_reduce(): | ||
@to_cinn_llir | ||
def reduce_sum(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))): | ||
for i1 in range(1): | ||
for j1 in range(4): | ||
for k1 in range(256): | ||
with ir.ScheduleBlockContext("init") as init: | ||
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) | ||
B[vi, vj, vk] = 0.0 | ||
for l1 in range(512): | ||
with ir.ScheduleBlockContext("B"): | ||
sch.bind(i1, "blockIdx.x") | ||
sch.bind(j1, "threadIdx.y") | ||
sch.bind(k1, "threadIdx.x") | ||
vi1, vj1, vk1, vl1 = ir.AxisMap( | ||
"SSSR", [i1, j1, k1, l1] | ||
) | ||
B[vi1, vj1, vk1] = ( | ||
B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1] | ||
) | ||
|
||
@to_cinn_llir | ||
def reduce_sum_expected( | ||
A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256)) | ||
): | ||
for i1 in range(1): | ||
for j1 in range(4): | ||
for k1 in range(256): | ||
with ir.ScheduleBlockContext("init") as init: | ||
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) | ||
B[vi, vj, vk] = 0.0 | ||
for l1 in range(512): | ||
with ir.ScheduleBlockContext("B"): | ||
vi1, vj1, vk1, vl1 = ir.AxisMap( | ||
"SSSR", [i1, j1, k1, l1] | ||
) | ||
B[vi1, vj1, vk1] = ( | ||
B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1] | ||
) | ||
sch.bind(init.i1, "blockIdx.x") | ||
sch.bind(init.j1, "threadIdx.y") | ||
sch.bind(init.k1, "threadIdx.x") | ||
|
||
assert_llir_equal(reduce_sum, reduce_sum_expected) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_bind_reduce() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from test.cinn.utils.testing import assert_llir_equal | ||
|
||
from cinn import ir, to_cinn_llir | ||
from cinn.runtime.data_array import DataArray | ||
from cinn.schedule import IRSchedule as sch | ||
|
||
|
||
# Current Python DSL cannot express the parallel `for`, | ||
# only checks that it can be converted correctly | ||
def test_elementwise_parallel(): | ||
@to_cinn_llir | ||
def elementwise_add( | ||
X: DataArray((128, 128)), | ||
Y: DataArray((128, 128)), | ||
A: DataArray((128, 128)), | ||
): | ||
for i in range(128): | ||
for j in range(128): | ||
with ir.ScheduleBlockContext("A") as A_block: | ||
i1, j1 = ir.AxisMap("SS", [i, j]) | ||
A[i1, j1] = X[i1, j1] * 2.0 | ||
for i in range(128): | ||
for j in range(128): | ||
with ir.ScheduleBlockContext("Y"): | ||
i1, j1 = ir.AxisMap("SS", [i, j]) | ||
Y[i1, j1] = A[i1, j1] + 2.0 | ||
sch.parallel(A_block.i) | ||
|
||
assert_llir_equal(elementwise_add, elementwise_add) | ||
|
||
|
||
# Current Python DSL cannot express the vectorize `for`, | ||
# only checks that it can be converted correctly | ||
def test_elementwise_vectorize(): | ||
@to_cinn_llir | ||
def elementwise_add( | ||
X: DataArray((128, 128)), | ||
Y: DataArray((128, 128)), | ||
A: DataArray((128, 128)), | ||
): | ||
for i in range(128): | ||
for j in range(128): | ||
with ir.ScheduleBlockContext("A") as A_block: | ||
i1, j1 = ir.AxisMap("SS", [i, j]) | ||
A[i1, j1] = X[i1, j1] * 2.0 | ||
for i in range(128): | ||
for j0 in range(32): | ||
for j1 in range(4): | ||
with ir.ScheduleBlockContext("Y") as Y_block: | ||
i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1]) | ||
Y[i1, j1] = A[i1, j1] + 2.0 | ||
sch.vectorize(Y_block.j1, 1) | ||
|
||
assert_llir_equal(elementwise_add, elementwise_add) | ||
|
||
|
||
# Current Python DSL cannot express the unroll `for`, | ||
# only checks that it can be converted correctly | ||
def test_elementwise_unroll(): | ||
@to_cinn_llir | ||
def elementwise_add( | ||
X: DataArray((128, 128)), | ||
Y: DataArray((128, 128)), | ||
A: DataArray((128, 128)), | ||
): | ||
for i in range(128): | ||
for j in range(128): | ||
with ir.ScheduleBlockContext("A") as A_block: | ||
i1, j1 = ir.AxisMap("SS", [i, j]) | ||
A[i1, j1] = X[i1, j1] * 2.0 | ||
for i in range(128): | ||
for j0 in range(32): | ||
for j1 in range(4): | ||
with ir.ScheduleBlockContext("Y") as Y_block: | ||
i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1]) | ||
Y[i1, j1] = A[i1, j1] + 2.0 | ||
sch.unroll(Y_block.j1) | ||
|
||
assert_llir_equal(elementwise_add, elementwise_add) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_elementwise_parallel() | ||
test_elementwise_vectorize() | ||
test_elementwise_unroll() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from cinn import ir, to_cinn_llir | ||
from cinn.runtime.data_array import DataArray | ||
from cinn.schedule import IRSchedule as sch | ||
|
||
|
||
def test_matmul(): | ||
@to_cinn_llir | ||
def matmul( | ||
A: DataArray((128, 128)), | ||
B: DataArray((128, 128)), | ||
C: DataArray((128, 128)), | ||
): | ||
for i0 in range(128): | ||
for i1 in range(128): | ||
with ir.ScheduleBlockContext("init"): | ||
vi, vj = ir.AxisMap("SS", [i0, i1]) | ||
C[vi, vj] = 0.0 | ||
for i2_outer in range(4): | ||
for i2_inner_outer in range(8): | ||
for i2_inner_inner in range(4): | ||
with ir.ScheduleBlockContext( | ||
"compute" | ||
) as Compute_block: | ||
vi, vj, vk = ir.AxisMap( | ||
"SSR", | ||
[ | ||
i0, | ||
i1, | ||
i2_outer * 32 | ||
+ i2_inner_outer * 4 | ||
+ i2_inner_inner, | ||
], | ||
) | ||
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) | ||
sch.rfactor(Compute_block.i2_inner_inner, 0) | ||
|
||
# TODO(6clc): rfactor schedule rasie Error Message: iter_value not support complex reduce bindings | ||
# assert_llir_equal(matmul, matmul) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_matmul() |
Oops, something went wrong.