From da6b6372c09d680dfbd471c5900e49a74c0689e8 Mon Sep 17 00:00:00 2001
From: 6clc
Date: Mon, 28 Aug 2023 09:13:43 +0800
Subject: [PATCH] fix(CINN-LLIR): Parse ast.Assign in a more complete way

---
 .../cinn/compiler/compute_code_generator.py  | 122 +++++++++++++-----
 1 file changed, 89 insertions(+), 33 deletions(-)

diff --git a/python/cinn/compiler/compute_code_generator.py b/python/cinn/compiler/compute_code_generator.py
index 885fe68d4d073..60e499435d156 100644
--- a/python/cinn/compiler/compute_code_generator.py
+++ b/python/cinn/compiler/compute_code_generator.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import ast
+from typing import Union
 
 from cinn import ir
 
@@ -27,8 +28,7 @@ def __init__(self, function_name, inputs_signature):
         self.function_name = function_name
         self.inputs_signature = inputs_signature
         self.cinn_llir_func = None
-        self.left_value_scope = {}
-        self.local_variables = {}
+        self.variables_table = {}
 
     def visit_FunctionDef(self, node) -> None:
         """
@@ -44,6 +44,7 @@ def visit_FunctionDef(self, node) -> None:
         # 1. Construct args of function
         llir_args = []
         for i, arg_name in enumerate(arg_names):
+            # The Argument object is an ir::Buffer
             if hasattr(self.inputs_signature[i], "dtype"):
                 llir_value = ir._Buffer_.make(
                     "_" + arg_name, self.inputs_signature[i].dtype
@@ -54,15 +55,21 @@ def visit_FunctionDef(self, node) -> None:
                 tensor_shape = [
                     ir.Expr(dim) for dim in self.inputs_signature[i].shape
                 ]
+
+                # The computational logic of CINN is implemented through Tensors,
+                # so an ir::_Tensor_ is stored in the variables table
                 llir_value = ir._Tensor_.make(
                     arg_name,
                     self.inputs_signature[i].dtype,
                     tensor_shape,
                     tensor_shape,
                 )
+            # The Argument object is an ir::Var
             else:
                 llir_value = ir.Var(arg_name)
             llir_args.append(ir.Argument(llir_value))
+            # The computational logic of CINN is implemented through Expr,
+            # so an ir::Expr is stored in the variables table
             llir_value = ir.Expr(llir_value)
             self.set_value(arg_name, llir_value)
 
@@ -87,14 +94,12 @@ def visit_compound_statement(self, stmts):
 
     def visit_arguments(self, node):
         arg_names = []
+        # Just collect the name of each arg; the properties of the arg
+        # are already stored in the JIT Function.
         for arg in node.args:
-            arg_names += [self.visit(arg)]
+            arg_names.append(arg.arg)
         return arg_names
 
-    def visit_arg(self, node):
-        ast.NodeVisitor.generic_visit(self, node)
-        return node.arg
-
     def visit_For(self, node) -> ir.Expr:
         """
         parse CINN Low Level IR For.
@@ -105,6 +110,7 @@ def visit_For(self, node) -> ir.Expr:
         Returns:
             ir.Expr, Points to the Expr of ir::ExprNode
         """
+        # 1. Parse the iter of the For loop
         iter_args = [self.visit(arg) for arg in node.iter.args]
         assert (
             len(iter_args) <= 2
@@ -113,12 +119,15 @@ def visit_For(self, node) -> ir.Expr:
         ast_min = iter_args[0] if len(iter_args) > 1 else ir.Expr(0)
         ast_extent = iter_args[1] if len(iter_args) > 1 else iter_args[0]
         # TODO(6clc): support sub region's local variable
+        # As in `visit_FunctionDef`, store an ir::Expr in the variables table
         llir_var = ir.Var(node.target.id)
         llir_var_expr = ir.Expr(llir_var)
         self.set_value(node.target.id, llir_var_expr)
 
         llir_for_min = ir.Expr(ast_min)
         llir_for_extent = ir.Expr(ast_extent)
+
+        # 2. Parse the body of the For loop
         llir_for_body = self.visit_compound_statement(node.body)
         llir_for_body = ir.Block.make(llir_for_body)
         for_expr = ir.For.make(
@@ -127,30 +136,25 @@ def visit_For(self, node) -> ir.Expr:
         return for_expr
 
     def visit_Name(self, node):
+        # Store Node
         if type(node.ctx) == ast.Store:
-            if node.id in self.local_variables:
-                return self.local_variables[node.id]
+            if node.id in self.variables_table:
+                return self.variables_table[node.id]
             return node.id
         # Load Node
         assert (
-            node.id in self.local_variables
+            node.id in self.variables_table
         ), f"{node.id} is not defined in context"
-        return self.local_variables[node.id]
-
-    def visit_BinOp(self, node):
-        cinn_tensor_l, indexs_l = self.visit(node.left)
-        lhs = ir.Load.make(cinn_tensor_l, indexs_l)
-        cinn_tensor_r, indexs_r = self.visit(node.right)
-        rhs = ir.Load.make(cinn_tensor_r, indexs_r)
-        ast2cinn = {ast.Add: ir.Add}
-        return ast2cinn[ast.Add].make(lhs, rhs)
+        return self.variables_table[node.id]
 
     def visit_Subscript(self, node):
-        lhs_tensor = self.visit(node.value)
-        idxs = [
+        expr_tensor = self.visit(node.value)
+        indices = [
             self.visit(node.slice),
         ]
-        return lhs_tensor.Expr(), idxs
+        if type(node.ctx) == ast.Load:
+            return ir.Load.make(expr_tensor, indices)
+        return expr_tensor, indices
 
     def visit_Tuple(self, node):
         args = [self.visit(x) for x in node.elts]
@@ -170,21 +174,73 @@ def visit_Assign(self, node):
         Returns:
             ir.Expr, Points to the Expr of ir::ExprNode
         """
-        _names = []
-        for target in node.targets:
-            _names += [self.visit(target)]
         assert (
-            len(_names) == 1
+            len(node.targets) == 1
         ), "Unsupport targets is a \
-            list of nodes, like 'a, b = c'"
-        names = _names[0]
-        value = self.visit(node.value)
+            list of nodes, like 'a = b = c'"
+        lhs = node.targets[0]
 
-        return ir.Store.make(names[0], value, names[1])
+        # 1. Parse RHS
+        rhs_expr = self.eval_expression(node.value)
 
-    def set_value(self, name, value):
-        self.left_value_scope[name] = value
-        self.local_variables[name] = value
+        # 2. Parse LHS
+        assert isinstance(
+            lhs, ast.Subscript
+        ), f'Currently only tensor assignment expressions are supported. {lhs.value} is not a Tensor'
+        expr_tensor, expr_indices = self.visit(lhs)
+        return ir.Store.make(expr_tensor, rhs_expr, expr_indices)
 
     def visit_Constant(self, node):
         return ir.Expr(node.value)
+
+    def eval_expression(self, node):
+        """
+        Parse an ir.Expr from an expression composed of AST nodes
+        """
+        args = []
+        if isinstance(node, ast.BinOp):
+            args = [node.left, node.right]
+        elif isinstance(node, ast.UnaryOp):
+            args = [node.operand]
+        elif isinstance(node, ast.Compare):
+            assert (
+                len(node.ops) == 1
+            ), "Only binary comparison symbols are supported. Expressions such as '1 <= a < 10' are not supported."
+            args = [node.left, *node.comparators]
+        elif isinstance(node, ast.BoolOp):
+            args = node.values
+        elif isinstance(node, ast.Call):
+            args = node.args
+        else:
+            raise NotImplementedError(
+                f'Parsing node {node} is not currently supported'
+            )
+        for i, arg in enumerate(args):
+            args[i] = self.visit(arg)
+
+        ast2cinn = {
+            # Binary Op
+            ast.Add: ir.Add,
+            ast.Sub: ir.Sub,
+            ast.Mult: ir.Mul,
+            ast.Div: ir.Div,
+            ast.Mod: ir.Mod,
+            ast.And: ir.And,
+            ast.Or: ir.Or,
+            # Comparator Op
+            ast.Eq: ir.EQ,
+            ast.NotEq: ir.NE,
+            ast.Lt: ir.LT,
+            ast.LtE: ir.LE,
+            ast.Gt: ir.GT,
+            ast.GtE: ir.GE,
+            # Unary Op
+            ast.USub: ir.Minus,
+            ast.Not: ir.Not,
+        }
+        return ast2cinn[type(node.op)].make(*args)
+
+    def set_value(self, name, value: Union[ir.Tensor, ir.Var]):
+        if isinstance(value, ir.Tensor):
+            value = value.Expr()
+        self.variables_table[name] = value
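
Illustrative sketch (not part of the patch): the kind of kernel body the rewritten
visit_Assign / eval_expression path is expected to handle, namely a subscripted
assignment whose right-hand side is a compound expression. The function and tensor
names (elementwise_madd, A, B, C) and the loop bound are hypothetical; only the
AST-to-IR mapping noted in the comments comes from the diff above.

    def elementwise_madd(A, B, C):  # hypothetical kernel body, plain Python
        for i in range(128):
            # ast.Assign with an ast.Subscript target -> ir.Store.make(expr_tensor, rhs_expr, expr_indices)
            # A[i] and B[i] read in Load context      -> ir.Load.make(expr_tensor, indices)
            # ast.Mult / ast.Add on the RHS           -> ir.Mul / ir.Add via eval_expression
            # the constant 1                          -> ir.Expr(1) via visit_Constant
            C[i] = A[i] * B[i] + 1

A chained assignment such as 'a = b = c', or an assignment whose left-hand side is
not a tensor subscript, is rejected by the assertions added in visit_Assign.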