From 0f87e8c6d33fe7c55839643566dcc8dd73e1fab5 Mon Sep 17 00:00:00 2001 From: 6clc Date: Fri, 22 Sep 2023 11:03:23 +0800 Subject: [PATCH] all to one --- paddle/cinn/pybind/CMakeLists.txt | 3 +- paddle/cinn/pybind/bind.cc | 3 + paddle/cinn/pybind/bind.h | 1 + paddle/cinn/pybind/framework.cc | 19 ++ paddle/cinn/pybind/ir/ir_context.h | 4 +- paddle/cinn/pybind/lang.cc | 16 +- paddle/cinn/pybind/runtime.cc | 63 ++++- paddle/cinn/pybind/schedule.cc | 151 +++++++++++ paddle/cinn/pybind/utils.cc | 5 + python/cinn/__init__.py | 3 +- python/cinn/compiler/__init__.py | 17 ++ python/cinn/compiler/compiler.py | 55 ++++ .../cinn/compiler/compute_code_generator.py | 251 ++++++++++++++++++ python/cinn/compiler/expr_executor.py | 158 +++++++++++ .../cinn/compiler/schedule_code_generator.py | 178 +++++++++++++ python/cinn/compiler/utils.py | 75 ++++++ python/cinn/framework.py | 1 + python/cinn/ir/__init__.py | 20 ++ python/cinn/ir/ir.py | 28 ++ python/cinn/ir/ir_context.py | 81 ++++++ .../cinn/{runtime.py => runtime/__init__.py} | 9 +- python/cinn/runtime/cinn_jit.py | 115 ++++++++ python/cinn/runtime/data_array.py | 61 +++++ python/cinn/runtime/module.py | 37 +++ python/cinn/runtime/utils.py | 35 +++ python/cinn/schedule.py | 15 ++ python/setup_cinn.py.in | 3 +- test/cinn/CMakeLists.txt | 36 +++ test/cinn/ir/test_llir_schedule_bind.py | 46 ++++ .../ir/test_llir_schedule_cache_read_write.py | 73 +++++ test/cinn/ir/test_llir_schedule_compute_at.py | 111 ++++++++ .../ir/test_llir_schedule_compute_inline.py | 95 +++++++ test/cinn/ir/test_llir_schedule_for_kind.py | 94 +++++++ test/cinn/ir/test_llir_schedule_fuse_split.py | 132 +++++++++ test/cinn/ir/test_llir_schedule_reorder.py | 80 ++++++ test/cinn/ir/test_llir_schedule_rfactor.py | 58 ++++ test/cinn/ir/test_llir_schedule_sequence.py | 70 +++++ test/cinn/runtime/test_launch.py | 55 ++++ test/cinn/runtime/test_reduce_cuda.py | 59 ++++ test/cinn/utils/testing.py | 28 ++ 40 files changed, 2334 insertions(+), 10 deletions(-) create mode 100644 paddle/cinn/pybind/schedule.cc create mode 100644 python/cinn/compiler/__init__.py create mode 100644 python/cinn/compiler/compiler.py create mode 100644 python/cinn/compiler/compute_code_generator.py create mode 100644 python/cinn/compiler/expr_executor.py create mode 100644 python/cinn/compiler/schedule_code_generator.py create mode 100644 python/cinn/compiler/utils.py create mode 100644 python/cinn/ir/ir.py create mode 100644 python/cinn/ir/ir_context.py rename python/cinn/{runtime.py => runtime/__init__.py} (88%) create mode 100644 python/cinn/runtime/cinn_jit.py create mode 100644 python/cinn/runtime/data_array.py create mode 100644 python/cinn/runtime/module.py create mode 100644 python/cinn/runtime/utils.py create mode 100644 python/cinn/schedule.py create mode 100644 test/cinn/ir/test_llir_schedule_bind.py create mode 100644 test/cinn/ir/test_llir_schedule_cache_read_write.py create mode 100644 test/cinn/ir/test_llir_schedule_compute_at.py create mode 100644 test/cinn/ir/test_llir_schedule_compute_inline.py create mode 100644 test/cinn/ir/test_llir_schedule_for_kind.py create mode 100644 test/cinn/ir/test_llir_schedule_fuse_split.py create mode 100644 test/cinn/ir/test_llir_schedule_reorder.py create mode 100644 test/cinn/ir/test_llir_schedule_rfactor.py create mode 100644 test/cinn/ir/test_llir_schedule_sequence.py create mode 100644 test/cinn/runtime/test_launch.py create mode 100644 test/cinn/runtime/test_reduce_cuda.py create mode 100644 test/cinn/utils/testing.py diff --git a/paddle/cinn/pybind/CMakeLists.txt 
b/paddle/cinn/pybind/CMakeLists.txt
index c00a64614f6430..33dc27860f9473 100755
--- a/paddle/cinn/pybind/CMakeLists.txt
+++ b/paddle/cinn/pybind/CMakeLists.txt
@@ -12,7 +12,8 @@ set(srcs
   pe.cc
   frontend.cc
   framework.cc
-  utils.cc)
+  utils.cc
+  schedule.cc)

 if(WITH_CUDA)
   message(STATUS "Compile core_api with CUDA support")
diff --git a/paddle/cinn/pybind/bind.cc b/paddle/cinn/pybind/bind.cc
index bf1285957e245d..4c20f22b973cfe 100644
--- a/paddle/cinn/pybind/bind.cc
+++ b/paddle/cinn/pybind/bind.cc
@@ -41,6 +41,8 @@ PYBIND11_MODULE(core_api, m) {
       "framework", "namespace cinn::hlir::framework, CINN framework");
   py::module utils =
       m.def_submodule("utils", "namespace cinn::utils, CINN framework");
+  py::module schedule = m.def_submodule(
+      "schedule", "namespace cinn::ir::schedule, CINN Schedule");

   BindRuntime(&runtime);
   BindCommon(&common);
@@ -53,6 +55,7 @@ PYBIND11_MODULE(core_api, m) {
   BindFrontend(&frontend);
   BindFramework(&framework);
   BindUtils(&utils);
+  BindSchedule(&schedule);
 }

 }  // namespace cinn::pybind
diff --git a/paddle/cinn/pybind/bind.h b/paddle/cinn/pybind/bind.h
index cb56cae0096cfd..63e5f8a074e552 100644
--- a/paddle/cinn/pybind/bind.h
+++ b/paddle/cinn/pybind/bind.h
@@ -52,5 +52,6 @@ void BindPE(pybind11::module *m);
 void BindFrontend(pybind11::module *m);
 void BindFramework(pybind11::module *m);
 void BindUtils(pybind11::module *m);
+void BindSchedule(pybind11::module *m);

 }  // namespace cinn::pybind
diff --git a/paddle/cinn/pybind/framework.cc b/paddle/cinn/pybind/framework.cc
index 0fbfe16dfe0cfa..752ac5003f43a2 100644
--- a/paddle/cinn/pybind/framework.cc
+++ b/paddle/cinn/pybind/framework.cc
@@ -20,6 +20,7 @@
 #include "paddle/cinn/common/cinn_value.h"
 #include "paddle/cinn/frontend/interpreter.h"
 #include "paddle/cinn/hlir/framework/graph_compiler.h"
+#include "paddle/cinn/hlir/framework/instruction.h"
 #include "paddle/cinn/hlir/framework/node.h"
 #include "paddle/cinn/hlir/framework/op.h"
 #include "paddle/cinn/hlir/framework/op_strategy.h"
@@ -211,5 +212,23 @@ void BindFramework(pybind11::module *m) {
           CINN_NOT_IMPLEMENTED
         }
       });
+
+  py::class_<Instruction> instruction(*m, "Instruction");
+  instruction
+      .def(py::init<const common::Target &,
+                    Scope *,
+                    const std::vector<std::string> &,
+                    const std::vector<std::string> &,
+                    const std::string &>())
+      .def("run",
+           [](Instruction &self,
+              backends::Compiler &compiler,
+              const std::string fn_name,
+              std::map<std::string, cinn_pod_value_t> &name_to_pod) {
+             auto fn_ptr = compiler.Lookup(fn_name);
+             self.Finalize();
+             self.SetLoweredFunc(fn_ptr);
+             self.Run(&name_to_pod);
+           });
 }
 }  // namespace cinn::pybind
diff --git a/paddle/cinn/pybind/ir/ir_context.h b/paddle/cinn/pybind/ir/ir_context.h
index c96c423bb071e0..89b65512e26664 100644
--- a/paddle/cinn/pybind/ir/ir_context.h
+++ b/paddle/cinn/pybind/ir/ir_context.h
@@ -45,7 +45,7 @@ class IRContextNode : public common::Object {
 };

 /**
- * The lifecycle of RAII resource management for IRContextNode
+ * The life cycle of RAII resource management for IRContextNode
  * is determined at the Python.
  */
 class IRContext {
@@ -215,7 +215,7 @@ class IRBuilderNode : public common::Object {
 };

 /**
- * The lifecycle of RAII resource management for IRBuilderNode
+ * The life cycle of RAII resource management for IRBuilderNode
  * is determined at the Python.
  */
 class IRBuilder {
diff --git a/paddle/cinn/pybind/lang.cc b/paddle/cinn/pybind/lang.cc
index 3ff7b4d318e5fe..ce0a480221643f 100644
--- a/paddle/cinn/pybind/lang.cc
+++ b/paddle/cinn/pybind/lang.cc
@@ -20,12 +20,15 @@
 #include "paddle/cinn/backends/codegen_c.h"
 #include "paddle/cinn/common/target.h"
 #include "paddle/cinn/ir/module.h"
+#include "paddle/cinn/ir/schedule/ir_schedule.h"
+#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
 #include "paddle/cinn/ir/tensor.h"
 #include "paddle/cinn/lang/buffer.h"
 #include "paddle/cinn/lang/builtin.h"
 #include "paddle/cinn/lang/compute.h"
 #include "paddle/cinn/lang/lower.h"
 #include "paddle/cinn/lang/placeholder.h"
+#include "paddle/cinn/optim/transform_gpu_forloop.h"
 #include "paddle/cinn/pybind/bind.h"
 #include "paddle/cinn/pybind/bind_utils.h"
@@ -148,7 +151,18 @@ void BindModule(py::module *m) {
   py::class_<ir::Module::Builder> builder(module, "Builder");
   builder.def(py::init<const std::string &, const Target &>())
-      .def("add_function", &ir::Module::Builder::AddFunction)
+      .def("add_function",
+           [](ir::Module::Builder &self, ir::LoweredFunc func) {
+             // TODO(6clc): optimize this by registering backend passes
+             if (self.GetTargetArch() == Target::Arch::NVGPU) {
+#ifdef CINN_WITH_CUDA
+               auto func_expr = Expr(func);
+               ir::SetCudaAxisInfo(&func_expr);
+               optim::OptimizeExprGPU(&(func->body));
+#endif
+             }
+             self.AddFunction(func);
+           })
       .def("add_buffer", &ir::Module::Builder::AddBuffer)
       .def("build", &ir::Module::Builder::Build);
 }
diff --git a/paddle/cinn/pybind/runtime.cc b/paddle/cinn/pybind/runtime.cc
index c69513a250384f..906e6846c1c68e 100644
--- a/paddle/cinn/pybind/runtime.cc
+++ b/paddle/cinn/pybind/runtime.cc
@@ -21,10 +21,18 @@
 #include <memory>
 #include <string>

+#include "paddle/cinn/common/common.h"
 #include "paddle/cinn/pybind/bind.h"
 #include "paddle/cinn/runtime/cinn_runtime.h"
 #include "paddle/cinn/runtime/flags.h"

+#ifdef CINN_WITH_CUDA
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include "paddle/cinn/backends/cuda_util.h"
+#endif
+
 namespace py = pybind11;
 namespace cinn::pybind {
 namespace {
@@ -66,6 +74,49 @@ cinn_buffer_t *CreateBufferFromNumpy(py::array data,
   return buffer;
 }

+// TODO(6clc): Implement a GPU device like cinn_x86_device_impl.cc in runtime.
+cinn_buffer_t *CreateBufferFromNumpy(
+    py::array data,
+    common::Target target = common::DefaultHostTarget(),
+    int align = 0) {
+  if (target == common::DefaultHostTarget()) {
+    return CreateBufferFromNumpy(data, cinn_x86_device);
+  } else if (target.arch == Target::Arch::NVGPU) {
+#ifdef CINN_WITH_CUDA
+    std::vector<int> shape;
+    std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
+    auto *buffer = new cinn_buffer_t();
+    buffer->device = cinn_nvgpu_device;
+    buffer->memory_size = data.nbytes();
+    CUDA_CALL(cudaMalloc(&buffer->memory, data.nbytes()));
+    CUDA_CALL(cudaMemcpy(
+        buffer->memory, data.data(), data.nbytes(), cudaMemcpyHostToDevice));
+    return buffer;
+#else
+    LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
+#endif
+  } else {
+    CINN_NOT_IMPLEMENTED
+  }
+}
+
+void BufferCopyTo(const cinn_buffer_t &buffer, py::array array) {
+  void *array_data = array.mutable_data();
+  if (buffer.device == cinn_x86_device) {
+    std::memcpy(array_data, buffer.memory, array.nbytes());
+  } else if (buffer.device == cinn_nvgpu_device) {
+#ifdef CINN_WITH_CUDA
+    CUDA_CALL(cudaMemcpy(
+        array_data, buffer.memory, array.nbytes(), cudaMemcpyDeviceToHost));
+#else
+    LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
+#endif
+
+  } else {
+    CINN_NOT_IMPLEMENTED
+  }
+}
+
 py::array BufferHostMemoryToNumpy(cinn_buffer_t &buffer) {  // NOLINT
   py::dtype dt;
   if (buffer.type == cinn_int32_t()) {
@@ -162,6 +213,7 @@ void BindCinnRuntime(py::module *m) {
       .value("cinn_x86_device", cinn_x86_device)
       .value("cinn_opencl_device", cinn_opencl_device)
       .value("cinn_arm_device", cinn_arm_device)
+      .value("cinn_nvgpu_device", cinn_nvgpu_device)
       .export_values();

   py::enum_<cinn_buffer_kind_t> cinn_buffer_kind(*m, "cinn_buffer_kind_t");
@@ -220,10 +272,17 @@ void BindCinnRuntime(py::module *m) {
       .def("set_flag", &cinn_buffer_t::set_flag)
       // Python methods
       .def("numpy", &BufferHostMemoryToNumpy)
-      .def(py::init(&CreateBufferFromNumpy),
+      .def(py::init(py::overload_cast<py::array, cinn_device_kind_t, int>(
+               &CreateBufferFromNumpy)),
            arg("data"),
            arg("device"),
-           arg("align") = 0);
+           arg("align") = 0)
+      .def(py::init(py::overload_cast<py::array, common::Target, int>(
+               &CreateBufferFromNumpy)),
+           arg("data"),
+           arg("target"),
+           arg("align") = 0)
+      .def("copy_to", &BufferCopyTo);

   m->def("cinn_x86_device_interface", &cinn_x86_device_interface)
       .def("cinn_buffer_load_float32", &cinn_buffer_load_float32)
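
The two `cinn_buffer_t` constructors and the new `copy_to` method bound above give numpy arrays a full round trip through CINN-managed memory, on the host or (with CUDA) on the device. A minimal sketch of the intended Python-side use, assuming a build with this patch installed as the `cinn` package:

    import numpy as np
    from cinn import common, runtime

    # Build a cinn_buffer_t from a numpy array; the Target overload added in
    # this hunk dispatches to cudaMalloc/cudaMemcpy for NVGPU targets.
    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    buf = runtime.cinn_buffer_t(x, common.DefaultHostTarget())

    # copy_to writes the buffer back into a preallocated, C-contiguous array.
    out = np.empty_like(x)
    buf.copy_to(out)
    assert (out == x).all()
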
diff --git a/paddle/cinn/pybind/schedule.cc b/paddle/cinn/pybind/schedule.cc
new file mode 100644
index 00000000000000..6ce57fd3ce4ceb
--- /dev/null
+++ b/paddle/cinn/pybind/schedule.cc
@@ -0,0 +1,151 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <pybind11/functional.h>
+#include <pybind11/numpy.h>
+#include <pybind11/operators.h>
+#include <pybind11/stl.h>
+
+#include "paddle/cinn/ir/schedule/ir_schedule.h"
+
+namespace py = pybind11;
+
+namespace cinn::pybind {
+
+void BindSchedule(py::module *m) {
+  py::class_<ir::IRSchedule> ir_schedule(*m, "IRSchedule");
+  ir_schedule
+      .def(py::init<const ir::ModuleExpr &,
+                    utils::LinearRandomEngine::StateType,
+                    bool,
+                    utils::ErrorMessageLevel>(),
+           py::arg("modexpr"),
+           py::arg("rand_seed") = -1,
+           py::arg("debug_flag") = false,
+           py::arg("err_msg_level") = utils::ErrorMessageLevel::kGeneral)
+      .def_static(
+          "make",
+          [](ir::LoweredFunc &ir_func) {
+            ir::ModuleExpr *module_expr = new ir::ModuleExpr({ir_func->body});
+            auto scheduler = std::make_unique<ir::IRSchedule>(*module_expr);
+            return scheduler;
+          })
+      .def("fuse",
+           py::overload_cast<const std::vector<Expr> &>(&ir::IRSchedule::Fuse))
+      .def("split",
+           py::overload_cast<const Expr &, const std::vector<int> &>(
+               &ir::IRSchedule::Split),
+           py::arg("loop"),
+           py::arg("factors"))
+      .def("compute_at",
+           py::overload_cast<const Expr &, const Expr &, bool>(
+               &ir::IRSchedule::ComputeAt),
+           py::arg("block"),
+           py::arg("loop"),
+           py::arg("keep_unit_loops") = false)
+      .def("simple_compute_at",
+           py::overload_cast<const Expr &, const Expr &>(
+               &ir::IRSchedule::SimpleComputeAt),
+           py::arg("block"),
+           py::arg("loop"))
+      .def("reverse_compute_at",
+           py::overload_cast<const Expr &, const Expr &, bool>(
+               &ir::IRSchedule::ReverseComputeAt),
+           py::arg("block"),
+           py::arg("loop"),
+           py::arg("keep_unit_loops") = false)
+      .def("cache_read",
+           py::overload_cast<const Expr &, int, const std::string &>(
+               &ir::IRSchedule::CacheRead))
+      .def("cache_write",
+           py::overload_cast<const Expr &, int, const std::string &>(
+               &ir::IRSchedule::CacheWrite))
+      .def("sync_threads",
+           py::overload_cast<const Expr &, bool>(&ir::IRSchedule::SyncThreads),
+           py::arg("ir_node"),
+           py::arg("after_node") = true)
+      .def("set_buffer",
+           py::overload_cast<Expr &, const std::string &, bool>(
+               &ir::IRSchedule::SetBuffer),
+           py::arg("block"),
+           py::arg("memory_type"),
+           py::arg("fixed") = false)
+      .def("reorder",
+           py::overload_cast<const std::vector<Expr> &>(
+               &ir::IRSchedule::Reorder))
+      .def("parallel",
+           py::overload_cast<const Expr &>(&ir::IRSchedule::Parallel))
+      .def("vectorize",
+           py::overload_cast<const Expr &, int>(&ir::IRSchedule::Vectorize))
+      .def("unroll", py::overload_cast<const Expr &>(&ir::IRSchedule::Unroll))
+
+      .def("compute_inline",
+           py::overload_cast<const Expr &>(&ir::IRSchedule::ComputeInline))
+      .def("reverse_compute_inline",
+           py::overload_cast<const Expr &>(
+               &ir::IRSchedule::ReverseComputeInline))
+      .def("bind", &ir::IRSchedule::Bind)
+      .def("copy_transform_and_loop_info",
+           py::overload_cast<const Expr &, const Expr &>(
+               &ir::IRSchedule::CopyTransformAndLoopInfo))
+      .def("rfactor",
+           py::overload_cast<const Expr &, int>(&ir::IRSchedule::Rfactor))
+      .def("annotate",
+           py::overload_cast<const Expr &,
+                             const std::string &,
+                             const ir::attr_t &>(&ir::IRSchedule::Annotate))
+      .def("unannotate",
+           py::overload_cast<Expr &, const std::string &>(
+               &ir::IRSchedule::Unannotate))
+      .def("flatten_loops",
+           py::overload_cast<const std::vector<Expr> &, const bool>(
+               &ir::IRSchedule::FlattenLoops),
+           py::arg("loops"),
+           py::arg("force_flat") = false)
+      .def("sample_perfect_tile",
+           py::overload_cast<const Expr &, int, int, const std::vector<int> &>(
+               &ir::IRSchedule::SamplePerfectTile),
+           py::arg("loop"),
+           py::arg("n"),
+           py::arg("max_innermost_factor"),
+           py::arg("decision") = std::vector<int>())
+      .def("sample_categorical",
+           py::overload_cast<const std::vector<int> &,
+                             const std::vector<float> &,
+                             const std::vector<int> &>(
+               &ir::IRSchedule::SampleCategorical),
+           py::arg("candidates"),
+           py::arg("probs"),
+           py::arg("decision") = std::vector<int>())
+      .def("get_module",
+           py::overload_cast<>(&ir::IRSchedule::GetModule, py::const_))
+      .def("get_root_block", &ir::IRSchedule::GetRootBlock)
+      .def("get_block",
+           py::overload_cast<const std::string &>(&ir::IRSchedule::GetBlock,
+                                                  py::const_))
+      .def("get_all_blocks",
+           py::overload_cast<>(&ir::IRSchedule::GetAllBlocks, py::const_))
+      .def("get_loops",
+           py::overload_cast<const std::string &>(&ir::IRSchedule::GetLoops,
+                                                  py::const_))
+      .def("get_name2loops_dict",
+           [](const ir::IRSchedule &self, const std::string &block_name) {
+             std::vector<ir::Expr> loops = self.GetLoops(block_name);
+             std::map<std::string, ir::Expr> name2loops;
+             for (const ir::Expr &loop : loops) {
+               name2loops[loop.As<ir::For>()->loop_var->name] = loop;
+             }
+             return name2loops;
+           });
+}
+}  // namespace cinn::pybind
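
These bindings mirror the C++ IRSchedule API one method per primitive, so schedules can be driven directly from Python. A sketch of the call pattern, following the py::arg names declared above (it assumes the lowered function has a schedule block named "C" with exactly two loops):

    from cinn.schedule import IRSchedule

    def tile(llir_func):
        # IRSchedule.make wraps the lowered function's body in a ModuleExpr.
        sch = IRSchedule.make(llir_func)
        i, j = sch.get_loops("C")       # assumes exactly two loops
        i0, i1 = sch.split(i, [4, -1])  # -1 lets CINN infer the factor
        sch.fuse(sch.get_loops("C"))    # re-query loops after the split
        return sch.get_module()
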
diff --git a/paddle/cinn/pybind/utils.cc b/paddle/cinn/pybind/utils.cc
index 7e49247e85a82d..1f48e79b4f31bb 100644
--- a/paddle/cinn/pybind/utils.cc
+++ b/paddle/cinn/pybind/utils.cc
@@ -13,7 +13,9 @@
 // limitations under the License.

 #include "paddle/cinn/pybind/bind.h"
+#include "paddle/cinn/utils/error.h"
 #include "paddle/cinn/utils/profiler.h"
+#include "paddle/cinn/utils/random_engine.h"

 namespace py = pybind11;

@@ -69,6 +71,9 @@ void BindUtils(py::module *m) {
           "type",
           [](HostEvent &self) -> const EventType & { return self.type_; },
           [](HostEvent &self, const EventType &v) { self.type_ = v; });
+
+  py::class_<utils::LinearRandomEngine>(*m, "LinearRandomEngine");
+  py::class_<utils::ErrorMessageLevel>(*m, "ErrorMessageLevel");
 }

 }  // namespace pybind
diff --git a/python/cinn/__init__.py b/python/cinn/__init__.py
index 9411b774e38360..55ab35e7e56242 100644
--- a/python/cinn/__init__.py
+++ b/python/cinn/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .version import full_version as __version__
+from .runtime.cinn_jit import to_cinn_llir
 import os

 cinndir = os.path.dirname(os.path.abspath(__file__))
@@ -189,4 +191,3 @@
     reduce_mul,
     reduce_sum,
 )
-from .version import full_version as __version__
diff --git a/python/cinn/compiler/__init__.py b/python/cinn/compiler/__init__.py
new file mode 100644
index 00000000000000..644bf2d949ca4e
--- /dev/null
+++ b/python/cinn/compiler/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .compiler import compile
+
+__all__ = ["compile"]
diff --git a/python/cinn/compiler/compiler.py b/python/cinn/compiler/compiler.py
new file mode 100644
index 00000000000000..064b97c31f243b
--- /dev/null
+++ b/python/cinn/compiler/compiler.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cinn
+
+from ..runtime import CinnLowerLevelIrJit
+from .compute_code_generator import ComputeCodeGenerator
+from .schedule_code_generator import ScheduleCodeGenerator
+
+
+def ast_to_llir(fn, inputs_signature):
+    function_name = fn.__name__
+    # 1. Parse CINN Compute
+    llir_compute_generator = ComputeCodeGenerator(
+        fn, function_name, inputs_signature
+    )
+    cinn_llir_func = llir_compute_generator.parse()
+
+    # 2. Parse CINN Schedule
+    llir_schedule_generator = ScheduleCodeGenerator(fn, cinn_llir_func)
+    return llir_schedule_generator.parse()
+
+
+def llir_to_runtime_module(llir_func, target, function_name, arg_names):
+    cinn_builder = cinn.lang.Module.Builder(function_name, target)
+    cinn_builder.add_function(llir_func)
+    llir_module = cinn_builder.build()
+    return cinn.runtime.Module(llir_module, target, function_name, arg_names)
+
+
+def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
+    if isinstance(fn, CinnLowerLevelIrJit):
+        llir_func = ast_to_llir(fn, jit_inputs_signature)
+    else:
+        raise Exception(
+            "Currently only support compiling from CinnLowerLevelIrJit"
+        )
+
+    if just_convert:
+        return llir_func
+
+    rt_module = llir_to_runtime_module(
+        llir_func, kwargs["target"], fn.__name__, kwargs["arg_names"]
+    )
+
+    return rt_module
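
`compile` is the single entry point that the JIT decorator added later in this patch funnels into: with `just_convert=True` it stops after the AST-to-LLIR conversion (handy for the golden-IR comparisons in the new tests), otherwise it lowers the function into a runnable `Module`. A hedged sketch of both paths, where `fn` is assumed to be a `CinnLowerLevelIrJit` object:

    import cinn
    from cinn.compiler import compile as cinn_compile

    def lower_only(fn):
        # Stops after ComputeCodeGenerator + ScheduleCodeGenerator.
        return cinn_compile(fn, just_convert=True)

    def build_module(fn):
        # Full path: lower, add to a module builder, JIT-compile for the host.
        return cinn_compile(
            fn, target=cinn.common.DefaultHostTarget(), arg_names=fn.arg_names
        )
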
diff --git a/python/cinn/compiler/compute_code_generator.py b/python/cinn/compiler/compute_code_generator.py
new file mode 100644
index 00000000000000..8f35b5c0c0a930
--- /dev/null
+++ b/python/cinn/compiler/compute_code_generator.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import contextlib
+
+from cinn import ir
+
+from .expr_executor import ExprExecutor, exec_assign
+from .utils import VariableTable, node_is_schedule
+
+
+class ComputeCodeGenerator(ast.NodeVisitor):
+    """
+    Convert Python AST to CINN Lower Level IR,
+    containing only the semantics of the compute part.
+    """
+
+    def __init__(self, fn, function_name, inputs_signature):
+        self.fn = fn
+        self.function_name = function_name
+        self.inputs_signature = inputs_signature
+        self.cinn_llir_func = None
+        self.variables_table = VariableTable()
+        self.extra_scope = {"range": ir.sequential}
+
+    def parse(self):
+        ast_node = self.fn.parse()
+        with ir.IRBuilder() as builder, self.variables_table:
+            for k, v in self.fn.scope.items():
+                self.variables_table.add(k, v)
+            for k, v in self.extra_scope.items():
+                self.variables_table.add(k, v)
+            self.visit(ast_node)
+        return builder.get()
+
+    def visit_FunctionDef(self, node) -> None:
+        """
+        Parse CINN Low Level IR FunctionDef.
+
+        Args:
+            node(ast.FunctionDef): The ast FunctionDef Node
+        """
+        with ir.LowerFuncContext(self.function_name) as func_ctx:
+            arg_names = self.visit(node.args)
+
+            assert (
+                len(node.args.defaults) == 0
+            ), "Default arguments are not supported"
+
+            # 1. Construct args of function
+            for i, arg_name in enumerate(arg_names):
+                # Obj of Argument is ir::Buffer
+                if hasattr(self.inputs_signature[i], "dtype"):
+                    tensor_shape = [
+                        ir.Expr(dim) for dim in self.inputs_signature[i].shape
+                    ]
+                    # TODO(6clc): unify Tensor and Buffer
+                    llir_value = ir._Buffer_.make(
+                        arg_name, self.inputs_signature[i].dtype
+                    )
+                    ir.Arg(arg_name, llir_value)
+                    llir_value = ir._Tensor_.make(
+                        arg_name,
+                        self.inputs_signature[i].dtype,
+                        tensor_shape,
+                        tensor_shape,
+                    )
+                    self.variables_table.add(arg_name, llir_value)
+                # Obj of Argument is ir::Var
+                else:
+                    llir_value = ir.Var(arg_name)
+                    ir.Arg(arg_name, llir_value)
+                    llir_value = ir.Expr(llir_value)
+                    self.variables_table.add(arg_name, llir_value)
+
+            # 2. Construct body of function
+            body = self.visit_compound_statement(node.body)
+
+    def visit_compound_statement(self, stmts):
+        for stmt in stmts:
+            self.visit(stmt)
+
+    def visit_arguments(self, node):
+        """
+        Parse CINN Low Level IR Argument.
+        If it is not in JIT mode, the information is taken from
+        arg.annotation.
+
+        Args:
+            node(ast.arguments): The ast argument Node
+
+        Returns:
+            list[string]: A list of parameter names
+        """
+        arg_names = [arg.arg for arg in node.args]
+
+        if len(self.inputs_signature) != len(arg_names):
+            self.inputs_signature = []
+            for arg in node.args:
+                arg_annotation = arg.annotation
+                if isinstance(arg_annotation, ast.Call):
+                    self.inputs_signature.append(
+                        ExprExecutor(self.variables_table.get()).exec(
+                            arg_annotation
+                        )
+                    )
+                elif isinstance(arg_annotation, int):
+                    if (
+                        -(2**31) <= arg_annotation
+                        and arg_annotation <= 2**31 - 1
+                    ):
+                        self.inputs_signature.append("i32")
+                    elif (
+                        2**63 <= arg_annotation
+                        and arg_annotation <= 2**64 - 1
+                    ):
+                        self.inputs_signature.append("u64")
+                    else:
+                        self.inputs_signature.append("i64")
+                elif isinstance(arg_annotation, float):
+                    self.inputs_signature.append("fp32")
+                else:
+                    raise TypeError(
+                        f'Unsupported type {type(arg_annotation)} for {arg_annotation}'
+                    )
+
+        return arg_names
+
+    def visit_For(self, node) -> ir.Expr:
+        """
+        Parse CINN Low Level IR For.
+
+        Args:
+            node(ast.For): The ast For node
+        """
+        for_ctx = ExprExecutor(self.variables_table.get()).exec(node.iter)
+        with self.variables_table:
+            with for_ctx as loop_var:
+                local_var_table = exec_assign(
+                    target=node.target, source=loop_var
+                )
+                for k, v in local_var_table.items():
+                    loop_var.rename(k)
+                    self.variables_table.add(k, ir.Expr(v))
+                self.visit_compound_statement(node.body)
+
+    def visit_Assign(self, node):
+        """
+        Parse CINN Low Level IR Store.
+
+        Args:
+            node(ast.Assign): The ast Assign node
+
+        Returns:
+            ir.Expr, Points to the Expr of ir::ExprNode
+        """
+
+        if isinstance(node.value, ast.Call) and node_is_schedule(node.value):
+            return "no compute"
+
+        assert (
+            len(node.targets) == 1
+        ), "Multiple assignment targets, like 'a = b = c', are not supported"
+        lhs = node.targets[0]
+
+        # 1 parse RHS
+        rhs_expr = ExprExecutor(self.variables_table.get()).exec(node.value)
+
+        # 2 parse LHS
+        # 2.1 Tensor
+        if isinstance(lhs, ast.Subscript):
+            expr_tensor = ExprExecutor(self.variables_table.get()).exec(
+                lhs.value
+            )
+            if isinstance(lhs.slice, ast.Tuple):
+                expr_indices = []
+                for idx in lhs.slice.elts:
+                    expr_indices.append(
+                        ExprExecutor(self.variables_table.get()).exec(idx)
+                    )
+            else:
+                expr_indices = [
+                    ExprExecutor(self.variables_table.get()).exec(lhs.slice)
+                ]
+            # TODO(6clc): Implement implicit type conversion (constant -> Expr)
+            if not isinstance(rhs_expr, ir.Expr):
+                rhs_expr = ir.Expr(rhs_expr)
+            ir.TensorStore(expr_tensor.Expr(), rhs_expr, expr_indices)
+        # 2.2 Attribute of Var
+        elif isinstance(lhs, ast.Attribute):
+            iter_var = self.visit(lhs.value)
+            setattr(iter_var.as_var_mutable(), lhs.attr, rhs_expr)
+            return "no compute"
+        # 2.3 Var
+        else:
+            # TODO(6clc): We need a better way to keep IterVar names and
+            # Python variable names consistent. Currently only the AxisMap
+            # function is supported.
+            local_var_table = exec_assign(target=lhs, source=rhs_expr)
+            if isinstance(lhs, ast.Tuple):
+                for k, v in local_var_table.items():
+                    v.as_var_ref().rename(k)
+                    self.variables_table.add(k, v)
+            else:
+                for k, v in local_var_table.items():
+                    v[0].as_var_ref().rename(k)
+                    self.variables_table.add(k, v[0])
+
+    def visit_Call(self, node):
+        if node_is_schedule(node):
+            return
+        self.generic_visit(node)
+
+    def visit_If(self, node):
+        with self.variables_table:
+            with ir.IfContext(
+                ExprExecutor(self.variables_table.get()).exec(node.test)
+            ):
+                with ir.ThenContext():
+                    with self.variables_table:
+                        self.visit_compound_statement(node.body)
+                if node.orelse:
+                    with ir.ElseContext():
+                        with self.variables_table:
+                            self.visit_compound_statement(node.orelse)
+
+    def visit_With(self, node):
+        with self.variables_table:
+            with contextlib.ExitStack() as context_stack:
+                for item in node.items:
+                    cur_ctx = ExprExecutor(self.variables_table.get()).exec(
+                        item.context_expr
+                    )
+                    cur_ctx = context_stack.enter_context(cur_ctx)
+                    if item.optional_vars is not None:
+                        local_var_table = exec_assign(
+                            target=item.optional_vars, source=cur_ctx
+                        )
+                        for k, v in local_var_table.items():
+                            self.variables_table.add(k, v)
+                body = self.visit_compound_statement(node.body)
diff --git a/python/cinn/compiler/expr_executor.py b/python/cinn/compiler/expr_executor.py
new file mode 100644
index 00000000000000..5e0b3f78e05434
--- /dev/null
+++ b/python/cinn/compiler/expr_executor.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+
+from cinn import ir
+
+AST2CINN = {
+    ast.Add: ir.Add,
+    ast.Sub: ir.Sub,
+    ast.Mult: ir.Mul,
+    ast.Div: ir.Div,
+    ast.Mod: ir.Mod,
+    ast.And: ir.And,
+    ast.Or: ir.Or,
+    ast.USub: ir.Minus,
+    ast.Not: ir.Not,
+    ast.Eq: ir.EQ,
+    ast.NotEq: ir.NE,
+    ast.Lt: ir.LT,
+    ast.LtE: ir.LE,
+    ast.Gt: ir.GT,
+    ast.GtE: ir.GE,
+}
+
+
+class ExprExecutor:
+    def __init__(self, var_table):
+        self.var_table = var_table
+        self.tmp_value_count = 1
+
+    def exec(self, node):
+        ret = self.visit(node)
+        if isinstance(ret, ast.Name):
+            return self.var_table[ret.id]
+        if isinstance(ret, ast.Constant):
+            return ret.value
+        raise Exception(f"Error result type: {type(ret)}")
+
+    def visit(self, node):
+        if isinstance(node, list):
+            return [self.visit(item) for item in node]
+        if isinstance(node, tuple):
+            return tuple(self.visit(item) for item in node)
+        assert isinstance(node, ast.AST)
+        if isinstance(node, ast.Name):
+            return node
+
+        if isinstance(node, ast.Constant):
+            return node
+
+        if not isinstance(node, (ast.expr, ast.slice)):
+            # some nodes don't need to be parsed, such as ast.Load
+            return node
+        if isinstance(node, (ast.Lambda, ast.Starred)):
+            raise Exception("Lambda and Starred are currently not supported")
+
+        cls_fields = {}
+        for field in node.__class__._fields:
+            attr = getattr(node, field)
+            if isinstance(attr, (ast.AST, tuple, list)):
+                cls_fields[field] = self.visit(attr)
+            else:
+                cls_fields[field] = attr
+
+        node_type_name = f'eval_{type(node).__name__}'
+        if hasattr(self, node_type_name):
+            exec_func = getattr(self, node_type_name)
+            value = exec_func(cls_fields)
+        else:
+            new_node = node.__class__(**cls_fields)
+            ast.copy_location(new_node, node)
+            new_node = ast.Expression(new_node)
+            value = self.exec_expr(new_node)
+        return self.save_temp_value(value)
+
+    def exec_expr(self, node):
+        if isinstance(node, ast.expr):
+            node = ast.Expression(body=node)
+        node = ast.fix_missing_locations(node)
+        exec = compile(node, filename="", mode="eval")
+        return eval(exec, self.var_table)
+
+    def eval_BinOp(self, fields):
+        args = [self.exec_expr(fields["left"]), self.exec_expr(fields["right"])]
+        args = [
+            ir.Expr(item) if not isinstance(item, ir.Expr) else item
+            for item in args
+        ]
+        return AST2CINN[type(fields["op"])].make(*args)
+
+    def eval_UnaryOp(self, fields):
+        args = [self.exec_expr(fields["operand"])]
+        args = [
+            ir.Expr(item) if not isinstance(item, ir.Expr) else item
+            for item in args
+        ]
+        return AST2CINN[type(fields["op"])].make(*args)
+
+    def eval_Compare(self, fields):
+        assert (
+            len(fields["ops"]) == 1
+        ), "Only binary comparison symbols are supported. Expressions such as '1 <= a < 10' are not supported."
+        args = [
+            self.exec_expr(fields["left"]),
+            self.exec_expr(fields["comparators"][0]),
+        ]
+        args = [
+            ir.Expr(item) if not isinstance(item, ir.Expr) else item
+            for item in args
+        ]
+        return AST2CINN[type(fields["ops"][0])].make(*args)
+
+    def save_temp_value(self, value):
+        name = f"__cinn_python_script_tmp_value_{self.tmp_value_count}"
+        self.tmp_value_count += 1
+        self.var_table[name] = value
+        return ast.Name(
+            id=name,
+            ctx=ast.Load(
+                lineno=0, col_offset=0, end_lineno=None, end_col_offset=None
+            ),
+            lineno=0,
+            col_offset=0,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+
+def exec_assign(target, source):
+    right_value_var_name = "__CINN_RIGHT_VALUE_VAR_NAME__"
+    local_var_table = {right_value_var_name: source}
+    mod = ast.fix_missing_locations(
+        ast.Module(
+            body=[
+                ast.Assign(
+                    targets=[target],
+                    value=ast.Name(id=right_value_var_name, ctx=ast.Load()),
+                )
+            ],
+            type_ignores=[],
+        )
+    )
+    exe = compile(mod, filename="", mode="exec")
+    exec(exe, {}, local_var_table)
+    del local_var_table[right_value_var_name]
+    return local_var_table
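
`ExprExecutor` reduces a Python expression AST bottom-up: leaf names and constants are resolved against the variable table with Python's own `compile`/`eval`, nodes with an `eval_*` handler are rebuilt as CINN IR through the `AST2CINN` table, and every intermediate result is parked under a `__cinn_python_script_tmp_value_*` name so parent nodes can keep using plain `eval`. For `a + b * 2`, the executor effectively performs the equivalent of this sketch (assuming `a` and `b` are bound to `ir.Expr` values in the table):

    from cinn import ir

    def a_plus_b_times_2(a: "ir.Expr", b: "ir.Expr") -> "ir.Expr":
        two = ir.Expr(2)             # constants get wrapped into ir.Expr
        prod = ir.Mul.make(b, two)   # inner ast.BinOp(ast.Mult) first
        return ir.Add.make(a, prod)  # then the outer ast.BinOp(ast.Add)
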
diff --git a/python/cinn/compiler/schedule_code_generator.py b/python/cinn/compiler/schedule_code_generator.py
new file mode 100644
index 00000000000000..0f7cc68fea1644
--- /dev/null
+++ b/python/cinn/compiler/schedule_code_generator.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+
+from cinn.schedule import IRSchedule
+
+from .expr_executor import ExprExecutor, exec_assign
+from .utils import (
+    VariableTable,
+    node_is_schedule,
+    node_is_schedule_block_context,
+)
+
+
+class ScheduleCodeGenerator(ast.NodeVisitor):
+    """
+    Convert Python AST to CINN Lower Level IR,
+    containing only the semantics of the schedule part.
+    """
+
+    def __init__(self, fn, cinn_llir_func):
+        self.fn = fn
+        self.cinn_llir_func = cinn_llir_func
+        self.scheduler = IRSchedule.make(self.cinn_llir_func)
+        self.variable_table = VariableTable()
+        self.global_variable_table = VariableTable()
+        self.extra_scope = {
+            "ScheduleBlockVariable": ScheduleBlockVariable,
+            "scheduler": self.scheduler,
+        }
+        self.loop_var_stack = []
+        self.block_stack = []
+        self.sch_block_tmp_var_name = "__CINN_SCHEDULE_BLOCK_VAR_NAME__"
+        self.tmp_var_count = 1
+
+    def parse(self):
+        with self.variable_table, self.global_variable_table:
+            ast_node = self.fn.parse()
+            for k, v in self.fn.scope.items():
+                self.variable_table.add(k, v)
+            for k, v in self.extra_scope.items():
+                self.variable_table.add(k, v)
+            self.visit(ast_node)
+        return self.cinn_llir_func
+
+    def visit_For(self, node):
+        assert isinstance(
+            node.target, ast.Name
+        ), "Currently only range() is supported to build a For loop"
+        with self.variable_table:
+            self.loop_var_stack.append(node.target)
+            self.generic_visit(node)
+            self.loop_var_stack.pop()
+
+    def visit_compound_statement(self, stmts):
+        for stmt in stmts:
+            self.visit(stmt)
+
+    def visit_With(self, node):
+        with self.variable_table:
+            for item in node.items:
+                if isinstance(
+                    item.context_expr, ast.Call
+                ) and not node_is_schedule_block_context(item.context_expr):
+                    continue
+                # 1. replace ScheduleBlockContext with ScheduleBlockVariable
+                sch_ctx_node = item.context_expr
+                sch_block_node = ast.copy_location(
+                    ast.Call(
+                        func=ast.Name(
+                            id="ScheduleBlockVariable", ctx=ast.Load()
+                        ),
+                        args=sch_ctx_node.args,
+                        keywords=[],
+                        starargs=None,
+                        kwargs=None,
+                    ),
+                    item.context_expr,
+                )
+                item.context_expr = sch_block_node
+
+                # 2. store the ScheduleBlockVariable node
+                sch_block = ExprExecutor(self.variable_table.get()).exec(
+                    item.context_expr
+                )
+                if item.optional_vars is None:
+                    tmp_var_name = self.sch_block_tmp_var_name + str(
+                        self.tmp_var_count
+                    )
+                    sch_block_var_node = ast.Name(
+                        id=tmp_var_name, ctx=ast.Store()
+                    )
+                    item.optional_vars = sch_block_var_node
+                local_var_table = exec_assign(
+                    target=item.optional_vars, source=sch_block
+                )
+                # 3. Set the block's loops as its attributes
+                sch_block.set_scheduler(self.scheduler)
+                self.block_stack.append(sch_block)
+                for k, v in local_var_table.items():
+                    self.variable_table.add(k, v)
+                    self.global_variable_table.add(k, v)
+                    for loop_var in self.loop_var_stack:
+                        loop_var_value = ast.Attribute(
+                            value=ast.Name(id=k, ctx=ast.Load()),
+                            attr=loop_var.id,
+                            ctx=ast.Load(),
+                        )
+                        loop_var_value = ExprExecutor(
+                            self.variable_table.get()
+                        ).exec(loop_var_value)
+                        for_loop_var_table = exec_assign(
+                            loop_var, loop_var_value
+                        )
+                        for (
+                            loop_var_k,
+                            loop_var_v,
+                        ) in for_loop_var_table.items():
+                            self.variable_table.add(loop_var_k, loop_var_v)
+
+            body = self.visit_compound_statement(node.body)
+
+    def visit_Assign(self, node):
+        if isinstance(node.value, ast.Call) and node_is_schedule(node.value):
+            sch_ret = self.exec_schedule_primitive(node.value)
+            local_var_table = exec_assign(
+                target=node.targets[0], source=sch_ret
+            )
+            for k, v in local_var_table.items():
+                self.variable_table.add(k, v)
+            return
+        self.generic_visit(node)
+
+    def visit_Call(self, node):
+        if isinstance(node, ast.Call) and node_is_schedule(node):
+            self.exec_schedule_primitive(node)
+            return
+
+    def exec_schedule_primitive(self, node):
+        # prepend the scheduler as the first argument of the primitive
+        sch_primitive = node
+        args = [ast.Name(id="scheduler", ctx=ast.Load()), *sch_primitive.args]
+        sch_primitive.args = args
+        all_variable_table = self.variable_table.get()
+        for k, v in self.global_variable_table.get().items():
+            all_variable_table[k] = v
+        sch_ret = ExprExecutor(all_variable_table).exec(node)
+
+        return sch_ret
+
+
+class ScheduleBlockVariable:
+    def __init__(self, name):
+        self.name = name
+        self.scheduler = None
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
+    def __getattr__(self, k):
+        # TODO(6clc): Improve schedule error messages: raise an exception to
+        # prompt the user when the block does not exist.
+        if k == "block":
+            return self.scheduler.get_block(self.name)
+        else:
+            name2loops = self.scheduler.get_name2loops_dict(self.name)
+            return name2loops[k]
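
The net effect of the rewriting above is that the name bound by `with ir.ScheduleBlockContext(...)` behaves like a handle: `.block` resolves through `get_block`, and any other attribute is looked up in the block's name-to-loop map. That is what lets the new tests name loops through the block, roughly like this sketch:

    from cinn import ir, to_cinn_llir
    from cinn.runtime.data_array import DataArray
    from cinn.schedule import IRSchedule as sch

    @to_cinn_llir
    def scale(X: DataArray((64, 64)), Y: DataArray((64, 64))):
        for i in range(64):
            for j in range(64):
                with ir.ScheduleBlockContext("Y") as Y_block:
                    i1, j1 = ir.AxisMap("SS", [i, j])
                    # Y_block.i / Y_block.j resolve via get_name2loops_dict.
                    sch.fuse([Y_block.i, Y_block.j])
                    Y[i1, j1] = X[i1, j1] * 2.0
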
diff --git a/python/cinn/compiler/utils.py b/python/cinn/compiler/utils.py
new file mode 100644
index 00000000000000..64399d6b651961
--- /dev/null
+++ b/python/cinn/compiler/utils.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+
+try:
+    from _collections import defaultdict
+except ImportError:
+    from collections import defaultdict
+
+
+from cinn.schedule import IRSchedule
+
+
+def node_is_schedule(node: ast.Call):
+    func_name = ""
+    if isinstance(node.func, ast.Name):
+        func_name = node.func.id
+    elif isinstance(node.func, ast.Attribute):
+        func_name = node.func.attr
+    if func_name == "print":
+        return True
+
+    return getattr(IRSchedule, func_name, None)
+
+
+def node_is_schedule_block_context(node: ast.Call):
+    if isinstance(node.func, ast.Name):
+        return node.func.id == "ScheduleBlockContext"
+    if isinstance(node.func, ast.Attribute):
+        return node.func.attr == "ScheduleBlockContext"
+    return False
+
+
+class VariableTable:
+    def __init__(self):
+        # var names added by the current context
+        self.var_name_list = []
+        # Maps a var name to its values; the type is {string: list}.
+        # The list records the value assigned in each context layer.
+        self.name2value = defaultdict(list)
+
+    def __enter__(self):
+        self.var_name_list.append([])
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:
+        # clear vars assigned in the current context
+        if ptype is None and value is None:
+            var_names = self.var_name_list.pop()
+            for var_name in var_names:
+                self.name2value[var_name].pop()
+                if len(self.name2value[var_name]) == 0:
+                    self.name2value.pop(var_name)
+
+    def add(self, name, value, cover=False):
+        # TODO(6clc): check whether the values are equal
+        if cover and name in self.var_name_list[-1]:
+            self.name2value[name][-1] = value
+        else:
+            self.var_name_list[-1].append(name)
+            self.name2value[name].append(value)
+
+    def get(self):
+        return {k: v[-1] for k, v in self.name2value.items()}
diff --git a/python/cinn/framework.py b/python/cinn/framework.py
index 5ea4ef7c4af0e4..2b58539cf5750c 100644
--- a/python/cinn/framework.py
+++ b/python/cinn/framework.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 from .core_api.framework import (  # noqa: F401
+    Instruction,
     NodeAttr,
     Operator,
     OpValueType,
diff --git a/python/cinn/ir/__init__.py b/python/cinn/ir/__init__.py
index b05b23230074d1..537eae9b6277b2 100644
--- a/python/cinn/ir/__init__.py
+++ b/python/cinn/ir/__init__.py
@@ -12,6 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .ir import sequential
+from .ir_context import (
+    IRBuilder,
+    IRContext,
+    ScheduleBlockContext,
+    LowerFuncContext,
+    ForContext,
+    IfContext,
+    ThenContext,
+    ElseContext,
+)
 from ..core_api.ir import (  # noqa: F401
     Add,
     And,
@@ -120,6 +131,15 @@
     _Module_,
     _Tensor_,
     _Var_,
+    _Buffer_,
+    Buffer,
+    ModuleExpr,
+    IrCompare,
+    IfThenElse,
+    Arg,
+    Sequential,
+    TensorStore,
+    AxisMap,
 )
diff --git a/python/cinn/ir/ir.py b/python/cinn/ir/ir.py
new file mode 100644
index 00000000000000..72e7a95b9a3793
--- /dev/null
+++ b/python/cinn/ir/ir.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from cinn import ir + +from .ir_context import ForContext + + +def sequential(min, extent=None): + if extent is None: + extent = min + min = ir.Expr(0) + if not isinstance(min, ir.Expr): + min = ir.Expr(min) + if not isinstance(extent, ir.Expr): + extent = ir.Expr(extent) + return ForContext(min, extent) diff --git a/python/cinn/ir/ir_context.py b/python/cinn/ir/ir_context.py new file mode 100644 index 00000000000000..fe637edc791e48 --- /dev/null +++ b/python/cinn/ir/ir_context.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cinn import ir + +from .. import core_api + + +class IRBuilder: + def __init__(self): + self.ir_builder = core_api.ir.IRBuilder() + + def __enter__(self): + self.ir_builder.EnterWithContext() + return self + + def __exit__( + self, ptype, value, trace + ) -> None: # pylint: disable=unused-argument + if ptype is None and value is None: + self.ir_builder.ExitWithContext() + + def get(self): + return self.ir_builder.get_result() + + +class IRContext: + def __init__(self, ir_ctx): + self.ir_ctx = ir_ctx + + def __enter__(self): + self.ir_ctx.EnterWithContext() + + def __exit__(self, ptype, value, trace) -> None: + if ptype is None and value is None: + self.ir_ctx.ExitWithContext() + + +class ScheduleBlockContext(IRContext): + def __init__(self, name): + self.ir_ctx = core_api.ir.IRContext.MakeScheduleBlockContext(name) + + +class LowerFuncContext(IRContext): + def __init__(self, name): + self.ir_ctx = core_api.ir.IRContext.MakeLowerFunctionContext(name) + + +class ForContext(IRContext): + def __init__(self, min, extent): + self.ir_ctx = ir.Sequential(min, extent) + + def __enter__(self): + super().__enter__() + return self.ir_ctx.get_for_loop_var() + + +class IfContext(IRContext): + def __init__(self, expr): + self.ir_ctx = core_api.ir.IRContext.MakeIfContext(expr) + + +class ThenContext(IRContext): + def __init__(self): + self.ir_ctx = core_api.ir.IRContext.MakeThenContext() + + +class ElseContext(IRContext): + def __init__(self): + self.ir_ctx = core_api.ir.IRContext.MakeElseContext() diff --git a/python/cinn/runtime.py b/python/cinn/runtime/__init__.py similarity index 88% rename from python/cinn/runtime.py rename to python/cinn/runtime/__init__.py index 47bd7247479b7a..244567bd855c22 100644 --- a/python/cinn/runtime.py +++ b/python/cinn/runtime/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 CINN Authors. All Rights Reserved. +# Copyright (c) 2023 CINN Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .core_api.runtime import (  # noqa: F401
+from cinn.core_api.runtime import (  # noqa: F401
     VoidPointer,
     cinn_arm_device,
     cinn_bool_t,
@@ -66,3 +66,8 @@
     seed,
     set_cinn_cudnn_deterministic,
 )
+
+from .cinn_jit import CinnLowerLevelIrJit
+from .module import Module
+
+__all__ = ["CinnLowerLevelIrJit", "Module"]
diff --git a/python/cinn/runtime/cinn_jit.py b/python/cinn/runtime/cinn_jit.py
new file mode 100644
index 00000000000000..7b85808593d625
--- /dev/null
+++ b/python/cinn/runtime/cinn_jit.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import ast
+import functools
+import inspect
+import textwrap
+from typing import Callable, Generic, Optional, TypeVar, Union, cast
+
+from .utils import inspect_function_scope
+
+T = TypeVar('T')
+
+
+class CinnLowerLevelIrJit(Generic[T]):
+    def __init__(self, fn):
+        self.fn = fn
+        # function prototype
+        signature = inspect.signature(fn)
+        self.arg_names = [v.name for v in signature.parameters.values()]
+
+        self.src = textwrap.dedent(inspect.getsource(fn))
+        self.src = self.src[self.src.find("def") :]
+        self.scope = inspect_function_scope(fn)
+
+        # docs of the wrapped function
+        self.__doc__ = fn.__doc__
+        self.__name__ = fn.__name__
+        self.__globals__ = fn.__globals__
+        self.__module__ = fn.__module__
+
+        # Encapsulates the compile and run processes
+        self.run = self._make_launcher()
+
+    def _make_launcher(self):
+        # Gather information about the runtime input parameters
+        jit_input_args = ', '.join(arg_name for arg_name in self.arg_names)
+        lazy_compile = f"""
+import cinn
+def {self.fn.__name__}({jit_input_args}, target=cinn.common.DefaultHostTarget()):
+    from cinn.compiler import compile
+    jit_inputs = {', '.join([f'{arg}' for arg in self.arg_names])}
+    jit_inputs_signature = {{ i: self._convert_arg_type(arg) \
+                             for i, arg in enumerate(jit_inputs)}}
+    module = compile(self, jit_inputs_signature=jit_inputs_signature, arg_names={
+        self.arg_names}, target=target)
+    module({jit_input_args})
+
+    return module
+"""
+        scope = {
+            "self": self,
+        }
+        exec(lazy_compile, scope)
+        return scope[self.fn.__name__]
+
+    def convert_to_llir(self):
+        from cinn.compiler import compile
+
+        return compile(self, just_convert=True)
+
+    def parse(self):
+        tree = ast.parse(self.src)
+        assert isinstance(tree, ast.Module)
+        return tree
+
+    def __getitem__(self, target):
+        return cast(
+            T, functools.partial(cast(Callable, self.run), target=target)
+        )
+
+    def _convert_arg_type(self, arg):
+        # arg is a Tensor
+        if hasattr(arg, "dtype"):
+            return arg
+        # arg is a Var
+        else:
+            if isinstance(arg, int):
+                if -(2**31) <= arg and arg <= 2**31 - 1:
+                    return "i32"
+                elif 2**63 <= arg and arg <= 2**64 - 1:
+                    return "u64"
+                else:
+                    return "i64"
+            elif isinstance(arg, float):
+                return "fp32"
+            else:
+                raise TypeError(f'Unsupported type {type(arg)} for {arg}')
+
+    def __str__(self):
+        return str(self.convert_to_llir())
+
+
+def to_cinn_llir(
+    fn: Optional[T] = None,
+) -> Union[CinnLowerLevelIrJit[T]]:
+    def decorator(fn: T) -> CinnLowerLevelIrJit[T]:
+        return CinnLowerLevelIrJit(fn)

+    if fn is not None:
+        return decorator(fn)
+    else:
+        return decorator
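
`CinnLowerLevelIrJit` captures source and scope eagerly but defers compilation to the generated launcher, so the first call compiles and runs in one step; indexing with a target returns a launcher pinned to that target. A small sketch of the entry points (`DefaultNVGPUTarget` assumes a CUDA build; the kernel body is elided and would follow the test style above):

    import numpy as np
    import cinn
    from cinn import to_cinn_llir
    from cinn.runtime.data_array import DataArray

    @to_cinn_llir
    def kernel(X: DataArray((8,)), Y: DataArray((8,))):
        ...  # compute + schedule body, as in the tests below

    x = DataArray.from_numpy(np.zeros(8, dtype="float32"))
    y = DataArray.from_numpy(np.zeros(8, dtype="float32"))

    print(kernel)    # __str__ -> convert_to_llir(): dumps the CINN LLIR
    kernel(x, y)     # compiles for the default host target, then runs
    kernel[cinn.common.DefaultNVGPUTarget()](x, y)  # target chosen via []
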
diff --git a/python/cinn/runtime/data_array.py b/python/cinn/runtime/data_array.py
new file mode 100644
index 00000000000000..374bd2a88e7db7
--- /dev/null
+++ b/python/cinn/runtime/data_array.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from cinn import common, runtime
+
+
+class DataArray:
+    """
+    Provides a Python wrapper around the cinn_buffer_t
+    data interface of the CINN runtime module.
+    """
+
+    def __init__(
+        self,
+        shape: list,
+        dtype: common.Type = common.Float(32),
+        data: runtime.cinn_buffer_t = None,
+    ) -> None:
+        self.shape = shape
+        self.dtype = dtype
+        self.data = data
+
+    def to_numpy(self):
+        """
+        Convert DataArray to numpy array
+        """
+        cinn_dtype_to_np_dtype = {common.Float(32): "float32"}
+        if self.dtype.is_float(32, common.Type.specific_type_t.UNK):
+            np_dtype = np.float32
+        else:
+            raise TypeError(f"Unsupported DataArray dtype: {self.dtype}")
+        np_arr = np.empty(self.shape, np_dtype)
+        assert np_arr.flags["C_CONTIGUOUS"]
+        self.data.copy_to(np_arr)
+        return np_arr
+
+    @staticmethod
+    def from_numpy(np_array, target=common.DefaultHostTarget()):
+        """
+        Create a DataArray from a numpy array
+        """
+        assert isinstance(np_array, np.ndarray)
+        data = runtime.cinn_buffer_t(np_array, target)
+        dtype_np_to_common = {
+            "float": common.Float(32),
+            "float32": common.Float(32),
+        }
+        # TODO(6clc): Support float16
+        dtype_np = str(np_array.dtype).split(".")[-1]
+        assert dtype_np in dtype_np_to_common.keys()
+
+        return DataArray(np_array.shape, dtype_np_to_common[dtype_np], data)
diff --git a/python/cinn/runtime/module.py b/python/cinn/runtime/module.py
new file mode 100644
index 00000000000000..24a31691015944
--- /dev/null
+++ b/python/cinn/runtime/module.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import cinn +from cinn import framework +from cinn.backends import Compiler + + +class Module: + def __init__(self, llir_module, target, fn_name, arg_names): + self.arg_names = arg_names + self.fn_name = fn_name + self.compiler = Compiler.create(target) + self.compiler.build(llir_module) + self._instruction = framework.Instruction( + target, None, [], arg_names, fn_name + ) + + def __call__(self, *args): + name2pod = {} + for i, name in enumerate(self.arg_names): + if isinstance(args[i], cinn.runtime.data_array.DataArray): + name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i].data) + else: + name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i]) + + self._instruction.run(self.compiler, self.fn_name, name2pod) diff --git a/python/cinn/runtime/utils.py b/python/cinn/runtime/utils.py new file mode 100644 index 00000000000000..8df8cccc772d1c --- /dev/null +++ b/python/cinn/runtime/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + + +def get_func_global_vars(func): + if inspect.ismethod(func): + func = func.__func__ + + code = func.__code__ + global_vars = {} + if func.__closure__ is not None: + for k, v in zip(code.co_freevars, func.__closure__): + global_vars[k] = v.cell_contents + return global_vars + + +def inspect_function_scope(func): + scope = { + **func.__globals__, + **get_func_global_vars(func), + } + return scope diff --git a/python/cinn/schedule.py b/python/cinn/schedule.py new file mode 100644 index 00000000000000..c9cb004f1fb66f --- /dev/null +++ b/python/cinn/schedule.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .core_api.schedule import IRSchedule # noqa: F401 diff --git a/python/setup_cinn.py.in b/python/setup_cinn.py.in index beeeab8752cff2..753a1d30cd7ad3 100644 --- a/python/setup_cinn.py.in +++ b/python/setup_cinn.py.in @@ -171,7 +171,8 @@ packages = ["cinn", "cinn.auto_schedule.cost_model", "cinn.ir", "cinn.libs", - "cinn.version" + "cinn.version", + "cinn.runtime" ] with redirect_stdout(): diff --git a/test/cinn/CMakeLists.txt b/test/cinn/CMakeLists.txt index ca9989b745826d..3158c4372d8fdb 100644 --- a/test/cinn/CMakeLists.txt +++ b/test/cinn/CMakeLists.txt @@ -274,4 +274,40 @@ if(WITH_GPU) WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) endforeach() + file( + GLOB CINN_RUNTIME_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "runtime/test_*.py") + + foreach(runtime_test_name ${EXCLUDE_RUNTIME}) + list(REMOVE_ITEM CINN_RUNTIME_TEST runtime/${runtime_test_name}.py) + endforeach() + + foreach(runtime_test_name ${CINN_RUNTIME_TEST}) + string(REGEX REPLACE ".py" "" runtime_test_name ${runtime_test_name}) + add_test( + NAME ${runtime_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH} + python3 ${CMAKE_CURRENT_SOURCE_DIR}/${runtime_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + + file( + GLOB CINN_IR_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "ir/test_*.py") + + foreach(ir_test_name ${CINN_IR_TEST}) + string(REGEX REPLACE ".py" "" ir_test_name ${ir_test_name}) + add_test( + NAME ${ir_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH} + python3 ${CMAKE_CURRENT_SOURCE_DIR}/${ir_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + endif() diff --git a/test/cinn/ir/test_llir_schedule_bind.py b/test/cinn/ir/test_llir_schedule_bind.py new file mode 100644 index 00000000000000..fd4373d338a94f --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_bind.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_bind_reduce():
+    @to_cinn_llir
+    def reduce_sum(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))):
+        for i1 in range(1):
+            for j1 in range(4):
+                for k1 in range(256):
+                    with ir.ScheduleBlockContext("init") as init:
+                        vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
+                        B[vi, vj, vk] = 0.0
+                    for l1 in range(512):
+                        with ir.ScheduleBlockContext("B"):
+                            sch.bind(i1, "blockIdx.x")
+                            sch.bind(j1, "threadIdx.z")
+                            sch.bind(k1, "threadIdx.x")
+                            vi1, vj1, vk1, vl1 = ir.AxisMap(
+                                "SSSR", [i1, j1, k1, l1]
+                            )
+                            B[vi1, vj1, vk1] = (
+                                B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1]
+                            )
+
+    print(reduce_sum)
+
+
+if __name__ == "__main__":
+    test_bind_reduce()
diff --git a/test/cinn/ir/test_llir_schedule_cache_read_write.py b/test/cinn/ir/test_llir_schedule_cache_read_write.py
new file mode 100644
index 00000000000000..85badc819f8f55
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_cache_read_write.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_cache_read_elementwise():
+    @to_cinn_llir
+    def elementwise_add_cache_read(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i3 in range(128):
+            for j3 in range(128):
+                with ir.ScheduleBlockContext("B") as B_block:
+                    i1, j1 = ir.AxisMap("SS", [i3, j3])
+                    Y[i1, j1] = -A[i1, j1] + 3.0
+
+        cached_a = sch.cache_read(A_block.block, 0, "global")
+        cached_b = sch.cache_read(B_block.block, 0, "local")
+
+    assert_llir_equal(elementwise_add_cache_read, elementwise_add_cache_read)
+
+
+def test_cache_write_elementwise():
+    @to_cinn_llir
+    def elementwise_add_cache_write(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i3 in range(128):
+            for j3 in range(128):
+                with ir.ScheduleBlockContext("B") as B_block:
+                    i1, j1 = ir.AxisMap("SS", [i3, j3])
+                    Y[i1, j1] = -A[i1, j1] + 3.0
+
+        cached_a = sch.cache_write(A_block.block, 0, "global")
+        cached_b = sch.cache_write(B_block.block, 0, "local")
+
+    # TODO(6clc): core dump
+    # assert_llir_equal(elementwise_add_cache_write, elementwise_add_cache_write)
+
+
+if __name__ == "__main__":
+    test_cache_read_elementwise()
+    test_cache_write_elementwise()
diff --git a/test/cinn/ir/test_llir_schedule_compute_at.py b/test/cinn/ir/test_llir_schedule_compute_at.py
new file mode 100644
index 00000000000000..0f82786935b411
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_compute_at.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_compute_at_elementwise():
+    @to_cinn_llir
+    def elementwise_add(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("Y"):
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    sch.compute_at(A_block.block, i, False)
+                    Y[i1, j1] = A[i1, j1] + 2.0
+
+    @to_cinn_llir
+    def elementwise_add_gt(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A"):
+                    i1, j1 = ir.AxisMap("SS", [i, 0 + j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+            for k in range(128):
+                with ir.ScheduleBlockContext("Y"):
+                    i2, k1 = ir.AxisMap("SS", [i, k])
+                    Y[i2, k1] = A[i2, k1] + 2.0
+
+    assert_llir_equal(elementwise_add, elementwise_add_gt)
+
+
+def test_reverse_compute_at():
+    @to_cinn_llir
+    def reverse_compute_at_tiled(
+        A: DataArray((128, 128)),
+        B: DataArray((128, 128)),
+        C: DataArray((128, 128)),
+    ):
+        for i0 in range(8):
+            for j0 in range(8):
+                for i1 in range(16):
+                    for j1 in range(16):
+                        with ir.ScheduleBlockContext("B") as B_block:
+                            vi, vj = ir.AxisMap(
+                                "SS", [i0 * 16 + i1, j0 * 16 + j1]
+                            )
+                            B[vi, vj] = A[vi, vj] * 2.0
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("C") as C_block:
+                    vi, vj = ir.AxisMap("SS", [i, j])
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        sch.reverse_compute_at(C_block.block, B_block.i1)
+
+    @to_cinn_llir
+    def reverse_compute_at_tiled_gt(
+        A: DataArray((128, 128)),
+        B: DataArray((128, 128)),
+        C: DataArray((128, 128)),
+    ):
+        for i0 in range(8):
+            for j0 in range(8):
+                for i1 in range(16):
+                    for j1 in range(16):
+                        with ir.ScheduleBlockContext("B") as B_block:
+                            vi, vj = ir.AxisMap(
+                                "SS", [i0 * 16 + i1, j0 * 16 + j1]
+                            )
+                            B[vi, vj] = A[vi, vj] * 2.0
+                    for j2 in range(16):
+                        with ir.ScheduleBlockContext("C") as C_block:
+                            vi, vj = ir.AxisMap(
+                                "SS", [16 * i0 + i1, 16 * j0 + j2]
+                            )
+                            C[vi, vj] = B[vi, vj] + 1.0
+
+    assert_llir_equal(reverse_compute_at_tiled, reverse_compute_at_tiled_gt)
+
+
+if __name__ == '__main__':
+    test_compute_at_elementwise()
+    test_reverse_compute_at()
diff --git a/test/cinn/ir/test_llir_schedule_compute_inline.py b/test/cinn/ir/test_llir_schedule_compute_inline.py
new file mode 100644
index 00000000000000..a95d1dd8174495
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_compute_inline.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_compute_inline_elementwise():
+    @to_cinn_llir
+    def elementwise_add_inline(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i3 in range(128):
+            for j3 in range(128):
+                with ir.ScheduleBlockContext("Y"):
+                    i1, j1 = ir.AxisMap("SS", [i3, j3])
+                    Y[i1, j1] = -A[i1, j1] + 3.0
+
+        block_a = sch.get_block("A")
+        sch.compute_inline(block_a)
+
+    @to_cinn_llir
+    def elementwise_add_inline_gt(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("Y"):
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    Y[i1, j1] = -(X[i1, j1] * 2.0) + 3.0
+
+    assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt)
+
+
+def test_reverse_compute_inline_elementwise():
+    @to_cinn_llir
+    def elementwise_add_inline(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i3 in range(128):
+            for j3 in range(128):
+                with ir.ScheduleBlockContext("Y") as Y_block:
+                    i1, j1 = ir.AxisMap("SS", [i3, j3])
+                    Y[i1, j1] = -A[i1, j1] + 3.0
+
+        sch.reverse_compute_inline(Y_block.block)
+
+    @to_cinn_llir
+    def elementwise_add_inline_gt(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A"):
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    Y[i1, j1] = -(X[i1, j1] * 2.0) + 3.0
+
+    assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt)
+
+
+if __name__ == "__main__":
+    test_compute_inline_elementwise()
+    test_reverse_compute_inline_elementwise()
diff --git a/test/cinn/ir/test_llir_schedule_for_kind.py b/test/cinn/ir/test_llir_schedule_for_kind.py
new file mode 100644
index 00000000000000..86ae663c31aac0
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_for_kind.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_elementwise_parallel():
+    @to_cinn_llir
+    def elementwise_add(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("Y"):
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    Y[i1, j1] = A[i1, j1] + 2.0
+        sch.parallel(A_block.i)
+
+    assert_llir_equal(elementwise_add, elementwise_add)
+
+
+def test_elementwise_vectorize():
+    @to_cinn_llir
+    def elementwise_add(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i in range(128):
+            for j0 in range(32):
+                for j1 in range(4):
+                    with ir.ScheduleBlockContext("Y") as Y_block:
+                        i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1])
+                        Y[i1, j1] = A[i1, j1] + 2.0
+        sch.vectorize(Y_block.j1, 1)
+
+    assert_llir_equal(elementwise_add, elementwise_add)
+
+
+def test_elementwise_unroll():
+    @to_cinn_llir
+    def elementwise_add(
+        X: DataArray((128, 128)),
+        Y: DataArray((128, 128)),
+        A: DataArray((128, 128)),
+    ):
+        for i in range(128):
+            for j in range(128):
+                with ir.ScheduleBlockContext("A") as A_block:
+                    i1, j1 = ir.AxisMap("SS", [i, j])
+                    A[i1, j1] = X[i1, j1] * 2.0
+        for i in range(128):
+            for j0 in range(32):
+                for j1 in range(4):
+                    with ir.ScheduleBlockContext("Y") as Y_block:
+                        i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1])
+                        Y[i1, j1] = A[i1, j1] + 2.0
+        sch.unroll(Y_block.j1)
+
+    print(elementwise_add)
+    assert_llir_equal(elementwise_add, elementwise_add)
+
+
+if __name__ == "__main__":
+    test_elementwise_parallel()
+    test_elementwise_vectorize()
+    test_elementwise_unroll()
diff --git a/test/cinn/ir/test_llir_schedule_fuse_split.py b/test/cinn/ir/test_llir_schedule_fuse_split.py
new file mode 100644
index 00000000000000..8e8ca77c524570
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_fuse_split.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
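A convention worth distilling from the schedule tests above (and those that follow): `ir.AxisMap` kind strings mark each axis as spatial (`"S"`) or reduction (`"R"`), loop variables captured by a `ScheduleBlockContext` become attributes of the context object (e.g. `Y_block.j1`), and schedule primitives may be invoked either inside the block or after the loop nest. A compressed sketch of the pattern using the same APIs as the tests; the kernel itself is illustrative, not part of this patch:

    from cinn import ir, to_cinn_llir
    from cinn.runtime.data_array import DataArray
    from cinn.schedule import IRSchedule as sch


    @to_cinn_llir
    def scale(X: DataArray((64,)), Y: DataArray((64,))):
        for i in range(64):
            with ir.ScheduleBlockContext("Y") as Y_block:
                vi = ir.AxisMap("S", [i])  # "S": spatial; "R" would mark a reduction axis
                sch.unroll(Y_block.i)  # the i loop is exposed as an attribute of the block
                Y[vi] = X[vi] * 2.0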
+
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_fuse():
+    @to_cinn_llir
+    def elementwise_fuse_assign_loop(
+        X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
+    ):
+        for i in range(128):
+            for j in range(128):
+                for k in range(128):
+                    with ir.ScheduleBlockContext("Y") as block_y:
+                        sch.fuse([i, j, k])
+                        i1, j1, k1 = ir.AxisMap("SSS", [i, j, k])
+                        Y[i1, j1, k1] = X[i1, j1, k1] * 2.0
+
+    @to_cinn_llir
+    def elementwise_fuse_assign_loop_gt(
+        X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
+    ):
+        for i in range(2097152):
+            with ir.ScheduleBlockContext("Y") as block_y:
+                i1_1, j1_1, k1_1 = ir.AxisMap(
+                    "SSS", [(i / 128) / 128, (i / 128) % 128, i % 128]
+                )
+                Y[i1_1, j1_1, k1_1] = X[i1_1, j1_1, k1_1] * 2.0
+
+    assert_llir_equal(
+        elementwise_fuse_assign_loop, elementwise_fuse_assign_loop_gt
+    )
+
+
+def test_split():
+    @to_cinn_llir
+    def elementwise_split(
+        X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
+    ):
+        for i in range(128):
+            for j in range(128):
+                for k in range(128):
+                    with ir.ScheduleBlockContext("Y") as Y_block:
+                        i1, j1, k1 = ir.AxisMap("SSS", [i, j, k])
+                        sch.split(Y_block.i, factors=[2, 1, 64])
+                        sch.split(Y_block.j, factors=[4, 32])
+                        sch.split(Y_block.k, factors=[16, 8])
+                        Y[i1, j1, k1] = X[i1, j1, k1] * 2.0
+
+    @to_cinn_llir
+    def elementwise_split_inferred_factor(
+        X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
+    ):
+        for i in range(128):
+            for j in range(128):
+                for k in range(128):
+                    with ir.ScheduleBlockContext("Y") as Y_block:
+                        i1, j1, k1 = ir.AxisMap("SSS", [i, j, k])
+                        sch.split(Y_block.i, factors=[-1, 1, 64])
+                        sch.split(Y_block.j, factors=[4, -1])
+                        sch.split(Y_block.k, factors=[-1, 8])
+                        Y[i1, j1, k1] = X[i1, j1, k1] * 2.0
+
+    assert_llir_equal(elementwise_split, elementwise_split_inferred_factor)
+
+
+def test_split_predicate():
+    @to_cinn_llir
+    def elementwise_split_predicate(
+        X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
+    ):
+        for i in range(128):
+            for j in range(128):
+                for k in range(128):
+                    with ir.ScheduleBlockContext("Y") as Y_block:
+                        i1, j1, k1 = ir.AxisMap("SSS", [i, j, k])
+                        sch.split(Y_block.i, factors=[1000, 1, 64])
+                        sch.split(Y_block.j, factors=[4, 32])
+                        sch.split(Y_block.k, factors=[16, 8])
+                        Y[i1, j1, k1] = X[i1, j1, k1] * 2.0
+
+    @to_cinn_llir
+    def elementwise_split_predicate_gt(
+        X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
+    ):
+        for i in range(1000):
+            for i_0 in range(1):
+                for i_1 in range(64):
+                    if ((64 * i) + ((64 * i_0) + i_1)) < 128:
+                        for j in range(4):
+                            for j_0 in range(32):
+                                for k in range(16):
+                                    for k_0 in range(8):
+                                        with ir.ScheduleBlockContext("Y"):
+                                            i1, j1, k1 = ir.AxisMap(
+                                                "SSS",
+                                                [
+                                                    (64 * i)
+                                                    + ((64 * i_0) + i_1),
+                                                    (32 * j) + j_0,
+                                                    (8 * k) + k_0,
+                                                ],
+                                            )
+                                            Y[i1, j1, k1] = X[i1, j1, k1] * 2.0
+
+    assert_llir_equal(
+        elementwise_split_predicate, elementwise_split_predicate_gt
+    )
+    print(elementwise_split_predicate)
+
+
+if __name__ == "__main__":
+    test_fuse()
+    test_split()
+    test_split_predicate()
diff --git a/test/cinn/ir/test_llir_schedule_reorder.py b/test/cinn/ir/test_llir_schedule_reorder.py
new file mode 100644
index 00000000000000..00ca99388ba941
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_reorder.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_reorder_elementwise():
+    @to_cinn_llir
+    def reorder_elementwise(
+        X: DataArray((64, 64, 64, 64)), Y: DataArray((64, 64, 64, 64))
+    ):
+        for i in range(64):
+            for j in range(64):
+                for k in range(64):
+                    for l in range(8):
+                        with ir.ScheduleBlockContext("Y") as Y_block:
+                            vi, vj, vk, vl = ir.AxisMap(
+                                "SSSS", [i, j, k, 8 * l]
+                            )
+                            Y[vi, vj, vk, vl] = X[vi, vj, vk, vl] * 2.0
+        sch.reorder([Y_block.k, Y_block.l, Y_block.i])
+
+    @to_cinn_llir
+    def reorder_elementwise_gt(
+        X: DataArray((64, 64, 64, 64)), Y: DataArray((64, 64, 64, 64))
+    ):
+        for k in range(64):
+            for j in range(64):
+                for l in range(8):
+                    for i in range(64):
+                        with ir.ScheduleBlockContext("Y"):
+                            vi, vj, vk, vl = ir.AxisMap(
+                                "SSSS", [i, j, k, 8 * l]
+                            )
+                            Y[vi, vj, vk, vl] = X[vi, vj, vk, vl] * 2.0
+
+    assert_llir_equal(reorder_elementwise, reorder_elementwise_gt)
+
+
+def test_reorder_overlapped():
+    @to_cinn_llir
+    def reorder_overlapped(X: DataArray((28, 8)), Y: DataArray((28, 8))):
+        for i in range(12):
+            for j in range(4):
+                for k in range(4):
+                    with ir.ScheduleBlockContext("Y"):
+                        vi, vj = ir.AxisMap("SS", [i, j])
+                        sch.reorder([i, k, j])
+                        Y[vi, vj] = X[vi, vj] + 1.0
+
+    @to_cinn_llir
+    def reorder_overlapped_gt(X: DataArray((28, 8)), Y: DataArray((28, 8))):
+        for i in range(12):
+            for k in range(4):
+                for j in range(4):
+                    with ir.ScheduleBlockContext("Y"):
+                        vi, vj = ir.AxisMap("SS", [i, j])
+                        Y[vi, vj] = X[vi, vj] + 1.0
+
+    assert_llir_equal(reorder_overlapped, reorder_overlapped_gt)
+
+
+if __name__ == '__main__':
+    test_reorder_elementwise()
+    test_reorder_overlapped()
diff --git a/test/cinn/ir/test_llir_schedule_rfactor.py b/test/cinn/ir/test_llir_schedule_rfactor.py
new file mode 100644
index 00000000000000..eca2bf384efe4d
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_rfactor.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
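Taken together, the ground-truth kernels above spell out the index arithmetic the schedule primitives generate: fusing the 128x128x128 nest yields a single loop of extent 128^3 = 2,097,152 whose original indices are recovered by division and modulo; `split` with a -1 factor infers the missing extent from the loop bound, and an `if` guard is emitted when the factors overshoot it (as in the 1000*1*64 case against a bound of 128); `reorder` permutes the loop nesting while leaving the AxisMap bindings untouched.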
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_matmul():
+    @to_cinn_llir
+    def matmul(
+        A: DataArray((128, 128)),
+        B: DataArray((128, 128)),
+        C: DataArray((128, 128)),
+    ):
+        for i0 in range(128):
+            for i1 in range(128):
+                with ir.ScheduleBlockContext("init"):
+                    vi, vj = ir.AxisMap("SS", [i0, i1])
+                    C[vi, vj] = 0.0
+                for i2_outer in range(4):
+                    for i2_inner_outer in range(8):
+                        for i2_inner_inner in range(4):
+                            with ir.ScheduleBlockContext(
+                                "compute"
+                            ) as Compute_block:
+                                vi, vj, vk = ir.AxisMap(
+                                    "SSR",
+                                    [
+                                        i0,
+                                        i1,
+                                        i2_outer * 32
+                                        + i2_inner_outer * 4
+                                        + i2_inner_inner,
+                                    ],
+                                )
+                                C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
+        sch.rfactor(Compute_block.i2_inner_inner, 0)
+
+    # TODO(6clc): iter_value not support complex reduce bindings
+    assert_llir_equal(matmul, matmul)
+
+
+if __name__ == "__main__":
+    test_matmul()
diff --git a/test/cinn/ir/test_llir_schedule_sequence.py b/test/cinn/ir/test_llir_schedule_sequence.py
new file mode 100644
index 00000000000000..2cff0c650fd632
--- /dev/null
+++ b/test/cinn/ir/test_llir_schedule_sequence.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test.cinn.utils.testing import assert_llir_equal
+
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+from cinn.schedule import IRSchedule as sch
+
+
+def test_split_reorder_elementwise():
+    @to_cinn_llir
+    def split_reorder_elementwise(
+        X: DataArray((1024, 1024)),
+        Y: DataArray((1024, 1024)),
+        Z: DataArray((1024, 1024)),
+    ):
+        for i in range(1024):
+            for j in range(1024):
+                for k in range(1024):
+                    with ir.ScheduleBlockContext("Z"):
+                        i_split_0, i_split_1, i_split_2, i_split_3 = sch.split(
+                            i, factors=[2, 4, 64, 2]
+                        )
+                        sch.reorder([i_split_2, i_split_0])
+                        i1, j1, k1 = ir.AxisMap("SSS", [i, j, k])
+                        Z[i1, j1] = Z[i1, j1] + X[i1, k] * Y[k, j1]
+
+    @to_cinn_llir
+    def split_reorder_elementwise_gt(
+        X: DataArray((1024, 1024)),
+        Y: DataArray((1024, 1024)),
+        Z: DataArray((1024, 1024)),
+    ):
+        for i_1 in range(64):
+            for i_0 in range(4):
+                for i in range(2):
+                    for i_2 in range(2):
+                        for j in range(1024):
+                            for k in range(1024):
+                                with ir.ScheduleBlockContext("Z"):
+                                    i1, j1, k1 = ir.AxisMap(
+                                        "SSS",
+                                        [
+                                            (512 * i)
+                                            + ((128 * i_0) + ((2 * i_1) + i_2)),
+                                            j,
+                                            k,
+                                        ],
+                                    )
+                                    Z[i1, j1] = Z[i1, j1] + (
+                                        X[i1, k] * Y[k, j1]
+                                    )
+
+    assert_llir_equal(split_reorder_elementwise, split_reorder_elementwise_gt)
+
+
+if __name__ == "__main__":
+    test_split_reorder_elementwise()
diff --git a/test/cinn/runtime/test_launch.py b/test/cinn/runtime/test_launch.py
new file mode 100644
index 00000000000000..af48d5a01f428e
--- /dev/null
+++ b/test/cinn/runtime/test_launch.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import cinn
+import numpy as np
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+
+
+@to_cinn_llir
+def bin_op_kernel(X, Y, Z):
+    for idx in range(10):
+        with ir.ScheduleBlockContext("Z"):
+            idx1 = ir.AxisMap("S", [idx])
+            Z[idx1] = X[idx1] + Y[idx1]
+
+
+def test_launch():
+    N = 10
+    X_np = np.random.random(N).astype(np.float32)
+    Y_np = np.random.random(N).astype(np.float32)
+    Z_np = np.zeros((N), dtype=np.float32)
+    target = cinn.common.DefaultNVGPUTarget()
+    X = DataArray.from_numpy(X_np, target)
+    Y = DataArray.from_numpy(Y_np, target)
+    Z = DataArray.from_numpy(Z_np, target)
+
+    # compile and run
+    bin_op_kernel[target](X, Y, Z)
+    pred = Z.to_numpy()
+    gt = np.add(X_np, Y_np)
+    np.testing.assert_allclose(pred, gt)
+
+
+if __name__ == "__main__":
+    test_launch()
diff --git a/test/cinn/runtime/test_reduce_cuda.py b/test/cinn/runtime/test_reduce_cuda.py
new file mode 100644
index 00000000000000..e723a7720e5a02
--- /dev/null
+++ b/test/cinn/runtime/test_reduce_cuda.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
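The launch test above and the reduce test below share one dispatch convention: subscripting a `@to_cinn_llir` kernel with a target, as in `bin_op_kernel[target](X, Y, Z)`, compiles the kernel for that target and immediately invokes it on the given `DataArray` arguments, while `DataArray.from_numpy(..., target)` stages host data onto the device and `.to_numpy()` copies results back for comparison.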
+
+
+import cinn
+import cinn.schedule as sch
+import numpy as np
+from cinn import ir, to_cinn_llir
+from cinn.runtime.data_array import DataArray
+
+
+@to_cinn_llir
+def reduce_sum(A, B):
+    for i1 in range(1):
+        for j1 in range(2):
+            for k1 in range(4):
+                with ir.ScheduleBlockContext("init") as init:
+                    vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
+                    B[vi, vj, vk] = 0.0
+                for l1 in range(8):
+                    sch.bind(i1, "blockIdx.x")
+                    sch.bind(j1, "threadIdx.z")
+                    sch.bind(k1, "threadIdx.x")
+                    with ir.ScheduleBlockContext("B"):
+                        vi1, vj1, vk1, vl1 = ir.AxisMap(
+                            "SSSR", [i1, j1, k1, l1]
+                        )
+                        B[vi1, vj1, vk1] = (
+                            B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1]
+                        )
+
+
+def test_reduce_cuda():
+    # prepare input and output array
+    d1 = 2
+    d2 = 4
+    d3 = 8
+    a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.sum(axis=-1).astype("float32")
+    target = cinn.common.DefaultNVGPUTarget()
+    a = DataArray.from_numpy(a_np, target)
+    b = DataArray.from_numpy(np.zeros_like(b_np), target)
+    reduce_sum[target](a, b)
+    np.testing.assert_allclose(b.to_numpy(), b_np, rtol=1e-5, atol=1e-6)
+
+
+if __name__ == "__main__":
+    test_reduce_cuda()
diff --git a/test/cinn/utils/testing.py b/test/cinn/utils/testing.py
new file mode 100644
index 00000000000000..b67432a17c189a
--- /dev/null
+++ b/test/cinn/utils/testing.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from cinn.ir import IrCompare
+from cinn.runtime import CinnLowerLevelIrJit
+
+
+def assert_llir_equal(
+    llir1, llir2, allow_name_suffix_diff=True, only_compare_structure=True
+):
+    comparer = IrCompare(allow_name_suffix_diff, only_compare_structure)
+
+    if isinstance(llir1, CinnLowerLevelIrJit):
+        llir1_expr = llir1.convert_to_llir().body()
+        llir2_expr = llir2.convert_to_llir().body()
+    else:
+        # Assume callers passed already-lowered IR expressions; compare them
+        # directly so the assertion below never silently passes.
+        llir1_expr, llir2_expr = llir1, llir2
+    assert comparer.compare(
+        llir1_expr, llir2_expr
+    ), f'llir1: {llir1} \n llir2: {llir2}'