Skip to content

Commit

Permalink
cinn(py-dsl): all test can pass
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Sep 21, 2023
1 parent f293535 commit 41e6637
Show file tree
Hide file tree
Showing 10 changed files with 267 additions and 176 deletions.
8 changes: 3 additions & 5 deletions python/cinn/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def ast_to_llir(fn, inputs_signature):
cinn_llir_func = llir_compute_generator.parse()

# 2. Parse CINN Schedule
llir_schedule_generator = ScheduleCodeGenerator(cinn_llir_func)
llir_schedule_generator.visit(fn.parse())
return llir_schedule_generator.cinn_llir_func
llir_schedule_generator = ScheduleCodeGenerator(fn, cinn_llir_func)
return llir_schedule_generator.parse()


def llir_to_runtime_module(llir_func, target, function_name, arg_names):
Expand All @@ -44,8 +43,7 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
if isinstance(fn, CinnLowerLevelIrJit):
llir_func = ast_to_llir(fn, jit_inputs_signature)
else:
raise Exception(
"Current Only support compile from CinnLowerLevelIrJit")
raise Exception("Current Only support compile from CinnLowerLevelIrJit")

if just_convert:
return llir_func
Expand Down
48 changes: 28 additions & 20 deletions python/cinn/compiler/compute_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
# limitations under the License.

import ast
from typing import List, Union
import contextlib


from cinn import ir
from cinn.runtime.data_array import DataArray

from .utils import node_is_schedule, VariableTable
from .expr_executor import ExprExecutor, exec_assign
from .utils import VariableTable, node_is_schedule


class ComputeCodeGenerator(ast.NodeVisitor):
Expand Down Expand Up @@ -111,8 +108,11 @@ def visit_arguments(self, node):
for arg in node.args:
arg_annotation = arg.annotation
if isinstance(arg_annotation, ast.Call):
self.inputs_signature.append(ExprExecutor(
self.variables_table.get()).exec(arg_annotation))
self.inputs_signature.append(
ExprExecutor(self.variables_table.get()).exec(
arg_annotation
)
)
elif isinstance(arg_annotation, int):
if (
-(2**21) <= arg_annotation
Expand Down Expand Up @@ -146,7 +146,8 @@ def visit_For(self, node) -> ir.Expr:
with self.variables_table:
with for_ctx as loop_var:
local_var_table = exec_assign(
target=node.target, source=loop_var)
target=node.target, source=loop_var
)
for k, v in local_var_table.items():
loop_var.rename(k)
self.variables_table.add(k, ir.Expr(v))
Expand Down Expand Up @@ -178,16 +179,19 @@ def visit_Assign(self, node):
# 2 parse LHS
# 2.1 Tensor
if isinstance(lhs, ast.Subscript):
expr_tensor = ExprExecutor(
self.variables_table.get()).exec(lhs.value)
expr_tensor = ExprExecutor(self.variables_table.get()).exec(
lhs.value
)
if isinstance(lhs.slice, ast.Tuple):
expr_indices = []
for idx in lhs.slice.elts:
expr_indices.append(ExprExecutor(
self.variables_table.get()).exec(idx))
expr_indices.append(
ExprExecutor(self.variables_table.get()).exec(idx)
)
else:
expr_indices = [ExprExecutor(
self.variables_table.get()).exec(lhs.slice)]
expr_indices = [
ExprExecutor(self.variables_table.get()).exec(lhs.slice)
]
# TODO(6clc): Implement implicit type conversion (constant ->Expr)
if not isinstance(rhs_expr, ir.Expr):
rhs_expr = ir.Expr(rhs_expr)
Expand All @@ -213,13 +217,15 @@ def visit_Assign(self, node):
self.variables_table.add(k, v[0])

def visit_Call(self, node):
func_name = node.func.attr
if node_is_schedule(node) is not None:
return "no compute"
if node_is_schedule(node):
return
self.generic_visit(node)

def visit_If(self, node):
with self.variables_table:
with ir.IfContext(ExprExecutor(self.variables_table.get()).exec(node.test)):
with ir.IfContext(
ExprExecutor(self.variables_table.get()).exec(node.test)
):
with ir.ThenContext():
with self.variables_table:
self.visit_compound_statement(node.body)
Expand All @@ -232,12 +238,14 @@ def visit_With(self, node):
with self.variables_table:
with contextlib.ExitStack() as context_stack:
for item in node.items:
cur_ctx = ExprExecutor(
self.variables_table.get()).exec(item.context_expr)
cur_ctx = ExprExecutor(self.variables_table.get()).exec(
item.context_expr
)
cur_ctx = context_stack.enter_context(cur_ctx)
if item.optional_vars is not None:
local_var_table = exec_assign(
target=item.optional_vars, source=cur_ctx)
target=item.optional_vars, source=cur_ctx
)
for k, v in local_var_table.items():
self.variables_table.add(k, v)
body = self.visit_compound_statement(node.body)
232 changes: 142 additions & 90 deletions python/cinn/compiler/schedule_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

from cinn.schedule import IRSchedule

from .utils import node_is_schedule
from .expr_executor import ExprExecutor, exec_assign
from .utils import (
VariableTable,
node_is_schedule,
node_is_schedule_block_context,
)


class ScheduleCodeGenerator(ast.NodeVisitor):
Expand All @@ -25,102 +30,149 @@ class ScheduleCodeGenerator(ast.NodeVisitor):
containing only the semantics of the schedule part
"""

def __init__(self, cinn_llir_func):
def __init__(self, fn, cinn_llir_func):
self.fn = fn
self.cinn_llir_func = cinn_llir_func
self.scheduler = IRSchedule.make(self.cinn_llir_func)
self.sch_seq = []
self.name2loops = {}

def visit_Subscript(self, node):
"""
save block information
"""
if type(node.ctx) != ast.Store:
return

for sch_node in self.sch_seq:
block_name2loops = self.scheduler.get_name2loops_dict(node.value.id)
for k, v in block_name2loops.items():
self.name2loops[k] = v

# schedule node is ast.Call or ast.Assign
sch_call_node = (
sch_node if isinstance(sch_node, ast.Call) else sch_node.value
)

sch_name = (
sch_call_node.func.id
if isinstance(sch_call_node.func, ast.Name)
else sch_call_node.func.attr
)
sch_args = [self.eval(item) for item in sch_call_node.args]

sch_keywords = {
kw.arg: self.eval(kw.value) for kw in sch_call_node.keywords
}

ret = getattr(self.scheduler, sch_name)(*sch_args, **sch_keywords)

if isinstance(sch_node, ast.Assign):
assert (
len(sch_node.targets) == 1
), "Unsupport targets is a \
list of nodes, like 'a = b = c'"
var_name = self.visit(sch_node.targets[0])
if not isinstance(var_name, list):
var_name = [var_name]
for k, v in zip(var_name, ret):
self.name2loops[k] = v

self.sch_seq = []
self.name2loops = {}
self.variable_table = VariableTable()
self.global_variable_table = VariableTable()
self.extra_scope = {
"ScheduleBlockVariable": ScheduleBlockVariable,
"scheduler": self.scheduler,
}
self.loop_var_stack = []
self.block_stack = []
self.sch_block_tmp_var_name = "__CINN_SCHEDULE_BLOCK_VAR_NAME__"
self.tmp_var_count = 1

def parse(self):
with self.variable_table, self.global_variable_table:
ast_node = self.fn.parse()
for k, v in self.fn.scope.items():
self.variable_table.add(k, v)
for k, v in self.extra_scope.items():
self.variable_table.add(k, v)
self.visit(ast_node)
return self.cinn_llir_func

def visit_For(self, node):
assert isinstance(
node.target, ast.Name
), "Current only support range() to make ForLoop"
with self.variable_table:
self.loop_var_stack.append(node.target)
self.generic_visit(node)
self.loop_var_stack.pop()

def visit_compound_statement(self, stmts):
for stmt in stmts:
self.visit(stmt)

def visit_With(self, node):
with self.variable_table:
for item in node.items:
if isinstance(
item.context_expr, ast.Call
) and not node_is_schedule_block_context(item.context_expr):
continue
# 1. replace ScheduleBlockContext to ScheduleBlockVariable
sch_ctx_node = item.context_expr
sch_block_node = ast.copy_location(
ast.Call(
func=ast.Name(
id="ScheduleBlockVariable", ctx=ast.Load()
),
args=sch_ctx_node.args,
keywords=[],
starargs=None,
kwargs=None,
),
item.context_expr,
)
item.context_expr = sch_block_node

# 2. store ScheduleBlockVariable node
sch_block = ExprExecutor(self.variable_table.get()).exec(
item.context_expr
)
if item.optional_vars is None:
tmp_var_name = self.sch_block_tmp_var_name + str(
self.tmp_var_count
)
sch_block_var_node = ast.Name(
id=tmp_var_name, ctx=ast.Store()
)
item.optional_vars = sch_block_var_node
local_var_table = exec_assign(
target=item.optional_vars, source=sch_block
)
# 3. Set the block's loop to its attritbute
sch_block.set_scheduler(self.scheduler)
self.block_stack.append(sch_block)
for k, v in local_var_table.items():
self.variable_table.add(k, v)
self.global_variable_table.add(k, v)
for loop_var in self.loop_var_stack:
loop_var_value = ast.Attribute(
value=ast.Name(id=k, ctx=ast.Load()),
attr=loop_var.id,
ctx=ast.Load(),
)
loop_var_value = ExprExecutor(
self.variable_table.get()
).exec(loop_var_value)
for_loop_var_table = exec_assign(
loop_var, loop_var_value
)
for (
loop_var_k,
loop_var_v,
) in for_loop_var_table.items():
self.variable_table.add(loop_var_k, loop_var_v)

body = self.visit_compound_statement(node.body)

def visit_Assign(self, node):
if isinstance(node.value, ast.Call) and node_is_schedule(node.value):
self.sch_seq.append(node)
sch_ret = self.exec_schedule_primitive(node.value)
local_var_table = exec_assign(
target=node.targets[0], source=sch_ret
)
for k, v in local_var_table.items():
self.variable_table.add(k, v)
return
self.generic_visit(node)

def visit_Call(self, node):
if not node_is_schedule(node):
if isinstance(node, ast.Call) and node_is_schedule(node):
self.exec_schedule_primitive(node)
return
self.sch_seq.append(node)

def visit_Tuple(self, node):
elts = [self.visit(x) for x in node.elts]
return elts

def visit_Name(self, node):
return node.id

def eval(self, node):
return getattr(self, f'eval_{type(node).__name__}')(node)

def eval_List(self, node):
return [self.eval(item) for item in node.elts]

def eval_Tuple(self, node):
return [self.eval(item) for item in node.elts]

def eval_Constant(self, node):
if node.value == "init":
return self.scheduler.get_block(node.value)
return node.value

def eval_UnaryOp(self, node):
return eval(
compile(ast.Expression(body=node), filename='', mode='eval')
)

def eval_Name(self, node):
try:
if node.id in self.name2loops:
return self.name2loops[node.id]
else:
return self.scheduler.get_block(node.id)
except:
raise Exception(
f'No matching block and loop was found for {node.id}. \
Current loops are {self.name2loops.keys()}. \
Current lower ir is {self.cinn_llir_func}.'
)

def exec_schedule_primitive(self, node):
# replace ScheduleBlockContext to ScheduleBlockVariable
sch_primitive = node
args = [ast.Name(id="scheduler", ctx=ast.Load()), *sch_primitive.args]
sch_primitive.args = args
all_variable_table = self.variable_table.get()
for k, v in self.global_variable_table.get().items():
all_variable_table[k] = v
sch_ret = ExprExecutor(all_variable_table).exec(node)

return sch_ret


class ScheduleBlockVariable:
def __init__(self, name):
self.name = name
self.scheduler = None

def set_scheduler(self, scheduler):
self.scheduler = scheduler

def __getattr__(self, k):
# TODO(6clc): Improve the error message of schedule, throw an exception to prompt the user when there is no block
if k == "block":
return self.scheduler.get_block(self.name)
else:
name2loops = self.scheduler.get_name2loops_dict(self.name)
return name2loops[k]
Loading

0 comments on commit 41e6637

Please sign in to comment.