Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Python DSL of CINN IR #56393

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/cinn/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ set(srcs
pe.cc
frontend.cc
framework.cc
utils.cc)
utils.cc
schedule.cc)

if(WITH_CUDA)
message(STATUS "Compile core_api with CUDA support")
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/pybind/bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ PYBIND11_MODULE(core_api, m) {
"framework", "namespace cinn::hlir::framework, CINN framework");
py::module utils =
m.def_submodule("utils", "namespace cinn::utils, CINN framework");
py::module schedule = m.def_submodule(
"schedule", "namespace cinn::ir::schedule, CINN Schedule");

BindRuntime(&runtime);
BindCommon(&common);
Expand All @@ -53,6 +55,7 @@ PYBIND11_MODULE(core_api, m) {
BindFrontend(&frontend);
BindFramework(&framework);
BindUtils(&utils);
BindSchedule(&schedule);
}

} // namespace cinn::pybind
1 change: 1 addition & 0 deletions paddle/cinn/pybind/bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,6 @@ void BindPE(pybind11::module *m);
void BindFrontend(pybind11::module *m);
void BindFramework(pybind11::module *m);
void BindUtils(pybind11::module *m);
void BindSchedule(pybind11::module *m);

} // namespace cinn::pybind
19 changes: 19 additions & 0 deletions paddle/cinn/pybind/framework.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/cinn/common/cinn_value.h"
#include "paddle/cinn/frontend/interpreter.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
Expand Down Expand Up @@ -211,5 +212,23 @@ void BindFramework(pybind11::module *m) {
CINN_NOT_IMPLEMENTED
}
});

py::class_<Instruction> instruction(*m, "Instruction");
instruction
.def(py::init<const Target &,
Scope *,
const std::vector<std::string> &,
const std::vector<std::string> &,
const std::string &>())
.def("run",
[](Instruction &self,
backends::Compiler &compiler,
const std::string fn_name,
std::map<std::string, cinn_pod_value_t> &name_to_pod) {
auto fn_ptr = compiler.Lookup(fn_name);
self.Finalize();
self.SetLoweredFunc(fn_ptr);
self.Run(&name_to_pod);
});
}
} // namespace cinn::pybind
4 changes: 2 additions & 2 deletions paddle/cinn/pybind/ir/ir_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class IRContextNode : public common::Object {
};

/**
* The lifecycle of RAII resource management for IRContextNode
* The life cycle of RAII resource management for IRContextNode
* is determined at the Python.
*/
class IRContext {
Expand Down Expand Up @@ -215,7 +215,7 @@ class IRBuilderNode : public common::Object {
};

/**
* The lifecycle of RAII resource management for IRBuilderNode
* The life cycle of RAII resource management for IRBuilderNode
* is determined at the Python.
*/
class IRBuilder {
Expand Down
16 changes: 15 additions & 1 deletion paddle/cinn/pybind/lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
#include "paddle/cinn/backends/codegen_c.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/buffer.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/pybind/bind_utils.h"

Expand Down Expand Up @@ -148,7 +151,18 @@ void BindModule(py::module *m) {

py::class_<ir::Module::Builder> builder(module, "Builder");
builder.def(py::init<const std::string &, const common::Target &>())
.def("add_function", &ir::Module::Builder::AddFunction)
.def("add_function",
[](ir::Module::Builder &self, ir::LoweredFunc func) {
// TODO(6clc): optimize by register of backend passs
if (self.GetTargetArch() == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
auto func_expr = Expr(func);
ir::SetCudaAxisInfo(&func_expr);
optim::OptimizeExprGPU(&(func->body));
#endif
}
self.AddFunction(func);
})
.def("add_buffer", &ir::Module::Builder::AddBuffer)
.def("build", &ir::Module::Builder::Build);
}
Expand Down
63 changes: 61 additions & 2 deletions paddle/cinn/pybind/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@
#include <cstring>
#include <memory>

#include "paddle/cinn/common/common.h"
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
#include "paddle/cinn/runtime/flags.h"

#ifdef CINN_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>

#include "paddle/cinn/backends/cuda_util.h"
#endif

namespace py = pybind11;
namespace cinn::pybind {
namespace {
Expand Down Expand Up @@ -66,6 +74,49 @@ cinn_buffer_t *CreateBufferFromNumpy(py::array data,
return buffer;
}

// TODO(6clc): Implement device of gpu like cinn_x86_device_impl.cc in runtime.
cinn_buffer_t *CreateBufferFromNumpy(
py::array data,
common::Target target = common::DefaultHostTarget(),
int align = 0) {
if (target == common::DefaultHostTarget()) {
return CreateBufferFromNumpy(data, cinn_x86_device);
} else if (target.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
std::vector<int> shape;
std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
auto *buffer = new cinn_buffer_t();
buffer->device = cinn_nvgpu_device;
buffer->memory_size = data.nbytes();
CUDA_CALL(cudaMalloc(&buffer->memory, data.nbytes()));
CUDA_CALL(cudaMemcpy(
buffer->memory, data.data(), data.nbytes(), cudaMemcpyHostToDevice));
return buffer;
#else
LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif
} else {
CINN_NOT_IMPLEMENTED
}
}

void BufferCopyTo(const cinn_buffer_t &buffer, py::array array) {
void *array_data = array.mutable_data();
if (buffer.device == cinn_x86_device) {
std::memcpy(array_data, buffer.memory, array.nbytes());
} else if (buffer.device == cinn_nvgpu_device) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaMemcpy(
array_data, buffer.memory, array.nbytes(), cudaMemcpyDeviceToHost));
#else
LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif

} else {
CINN_NOT_IMPLEMENTED
}
}

py::array BufferHostMemoryToNumpy(cinn_buffer_t &buffer) { // NOLINT
py::dtype dt;
if (buffer.type == cinn_int32_t()) {
Expand Down Expand Up @@ -162,6 +213,7 @@ void BindCinnRuntime(py::module *m) {
.value("cinn_x86_device", cinn_x86_device)
.value("cinn_opencl_device", cinn_opencl_device)
.value("cinn_arm_device", cinn_arm_device)
.value("cinn_nvgpu_device", cinn_nvgpu_device)
.export_values();

py::enum_<cinn_buffer_kind_t> cinn_buffer_kind(*m, "cinn_buffer_kind_t");
Expand Down Expand Up @@ -220,10 +272,17 @@ void BindCinnRuntime(py::module *m) {
.def("set_flag", &cinn_buffer_t::set_flag)
// Python methods
.def("numpy", &BufferHostMemoryToNumpy)
.def(py::init(&CreateBufferFromNumpy),
.def(py::init(py::overload_cast<py::array, cinn_device_kind_t, int>(
&CreateBufferFromNumpy)),
arg("data"),
arg("device"),
arg("align") = 0);
arg("align") = 0)
.def(py::init(py::overload_cast<py::array, common::Target, int>(
&CreateBufferFromNumpy)),
arg("data"),
arg("target"),
arg("align") = 0)
.def("copy_to", &BufferCopyTo);

m->def("cinn_x86_device_interface", &cinn_x86_device_interface)
.def("cinn_buffer_load_float32", &cinn_buffer_load_float32)
Expand Down
151 changes: 151 additions & 0 deletions paddle/cinn/pybind/schedule.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <string>

#include "paddle/cinn/ir/schedule/ir_schedule.h"

namespace py = pybind11;

namespace cinn::pybind {

void BindSchedule(py::module *m) {
py::class_<ir::IRSchedule> ir_schedule(*m, "IRSchedule");
ir_schedule
.def(py::init<const ir::ModuleExpr &,
utils::LinearRandomEngine::StateType,
bool,
utils::ErrorMessageLevel>(),
py::arg("modexpr"),
py::arg("rand_seed") = -1,
py::arg("debug_flag") = false,
py::arg("err_msg_level") = utils::ErrorMessageLevel::kGeneral)
.def_static(
"make",
[](ir::LoweredFunc &ir_func) {
ir::ModuleExpr *module_expr = new ir::ModuleExpr({ir_func->body});
auto scheduler = std::make_unique<ir::IRSchedule>(*module_expr);
return scheduler;
})
.def("fuse",
py::overload_cast<const std::vector<Expr> &>(&ir::IRSchedule::Fuse))
.def("split",
py::overload_cast<const Expr &, const std::vector<int> &>(
&ir::IRSchedule::Split),
py::arg("loop"),
py::arg("factors"))
.def("compute_at",
py::overload_cast<const Expr &, const Expr &, bool>(
&ir::IRSchedule::ComputeAt),
py::arg("block"),
py::arg("loop"),
py::arg("keep_unit_loops") = false)
.def("simple_compute_at",
py::overload_cast<const Expr &, const Expr &>(
&ir::IRSchedule::SimpleComputeAt),
py::arg("block"),
py::arg("loop"))
.def("reverse_compute_at",
py::overload_cast<const Expr &, const Expr &, bool>(
&ir::IRSchedule::ReverseComputeAt),
py::arg("block"),
py::arg("loop"),
py::arg("keep_unit_loops") = false)
.def("cache_read",
py::overload_cast<const Expr &, int, const std::string &>(
&ir::IRSchedule::CacheRead))
.def("cache_write",
py::overload_cast<const Expr &, int, const std::string &>(
&ir::IRSchedule::CacheWrite))
.def("sync_threads",
py::overload_cast<const Expr &, bool>(&ir::IRSchedule::SyncThreads),
py::arg("ir_node"),
py::arg("after_node") = true)
.def("set_buffer",
py::overload_cast<Expr &, const std::string &, bool>(
&ir::IRSchedule::SetBuffer),
py::arg("block"),
py::arg("memory_type"),
py::arg("fixed") = false)
.def("reorder",
py::overload_cast<const std::vector<Expr> &>(
&ir::IRSchedule::Reorder))
.def("parallel",
py::overload_cast<const Expr &>(&ir::IRSchedule::Parallel))
.def("vectorize",
py::overload_cast<const Expr &, int>(&ir::IRSchedule::Vectorize))
.def("unroll", py::overload_cast<const Expr &>(&ir::IRSchedule::Unroll))

.def("compute_inline",
py::overload_cast<const Expr &>(&ir::IRSchedule::ComputeInline))
.def("reverse_compute_inline",
py::overload_cast<const Expr &>(
&ir::IRSchedule::ReverseComputeInline))
.def("bind", &ir::IRSchedule::Bind)
.def("copy_transform_and_loop_info",
py::overload_cast<const Expr &, const Expr &>(
&ir::IRSchedule::CopyTransformAndLoopInfo))
.def("rfactor",
py::overload_cast<const Expr &, int>(&ir::IRSchedule::Rfactor))
.def("annotate",
py::overload_cast<const Expr &,
const std::string &,
const ir::attr_t &>(&ir::IRSchedule::Annotate))
.def("unannotate",
py::overload_cast<Expr &, const std::string &>(
&ir::IRSchedule::Unannotate))
.def("flatten_loops",
py::overload_cast<const std::vector<Expr> &, const bool>(
&ir::IRSchedule::FlattenLoops),
py::arg("loops"),
py::arg("force_flat") = false)
.def("sample_perfect_tile",
py::overload_cast<const Expr &, int, int, const std::vector<int> &>(
&ir::IRSchedule::SamplePerfectTile),
py::arg("loop"),
py::arg("n"),
py::arg("max_innermost_factor"),
py::arg("decision") = std::vector<int>())
.def("sample_categorical",
py::overload_cast<const std::vector<int> &,
const std::vector<float> &,
const std::vector<int> &>(
&ir::IRSchedule::SampleCategorical),
py::arg("candidates"),
py::arg("probs"),
py::arg("decision") = std::vector<int>())
.def("get_module",
py::overload_cast<>(&ir::IRSchedule::GetModule, py::const_))
.def("get_root_block", &ir::IRSchedule::GetRootBlock)
.def("get_block",
py::overload_cast<const std::string &>(&ir::IRSchedule::GetBlock,
py::const_))
.def("get_all_blocks",
py::overload_cast<>(&ir::IRSchedule::GetAllBlocks, py::const_))
.def("get_loops",
py::overload_cast<const std::string &>(&ir::IRSchedule::GetLoops,
py::const_))
.def("get_name2loops_dict",
[](const ir::IRSchedule &self, const std::string &block_name) {
std::vector<ir::Expr> loops = self.GetLoops(block_name);
std::map<std::string, ir::Expr> name2loops;
for (const ir::Expr &loop : loops) {
name2loops[loop.As<ir::For>()->loop_var->name] = loop;
}
return name2loops;
});
}
} // namespace cinn::pybind
5 changes: 5 additions & 0 deletions paddle/cinn/pybind/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
// limitations under the License.

#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/utils/random_engine.h"

namespace py = pybind11;

Expand Down Expand Up @@ -69,6 +71,9 @@ void BindUtils(py::module *m) {
"type",
[](HostEvent &self) -> const EventType & { return self.type_; },
[](HostEvent &self, const EventType &v) { self.type_ = v; });

py::class_<utils::LinearRandomEngine>(*m, "LinearRandomEngine");
py::class_<utils::ErrorMessageLevel>(*m, "ErrorMessageLevel");
}

} // namespace pybind
Expand Down
Loading