diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h
index 8c8aee7f0a3c2..9424f6dc30f29 100644
--- a/include/tvm/node/structural_equal.h
+++ b/include/tvm/node/structural_equal.h
@@ -89,8 +89,6 @@ class StructuralEqual : public BaseValueEqual {
    * \return The comparison result.
    */
   TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
-
-  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) const;
 };
 
 /*!
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 1440701dbeecc..9fcf510b70a82 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -231,9 +231,4 @@ bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) con
   return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
 }
 
-bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs,
-                                 bool map_free_vars) const {
-  return RemapVarSEqualHandler(false).Equal(lhs, rhs, map_free_vars);
-}
-
 }  // namespace tvm
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 566bd860202ab..0275a893e1f4f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -182,9 +182,6 @@ bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
   const auto* param = attrs.as<ExpandDimsAttrs>();
-  if (param == nullptr) {
-    return false;
-  }
   const int ndim = static_cast<int>(data->shape.size());
   const int axis = param->axis;
   const int num_newaxis = param->num_newaxis;
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index cdd7555dcfeed..1f30b6800fce0 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -59,9 +59,6 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   }
   const auto* param = attrs.as<AttrType>();
-  if (param == nullptr) {
-    return false;
-  }
   if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
     return false;
   }
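Note on the `structural_equal` hunks above: only the three-argument C++ convenience overload is dropped. Free-variable mapping itself survives through the handler's `Equal(lhs, rhs, map_free_vars)` entry point and remains exposed to users through `tvm.ir.structural_equal`. A minimal sketch of that surviving behavior (illustrative only, not part of the patch):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1,))
y = relay.var("y", shape=(1,))

# Two distinct free variables are not structurally equal by default ...
assert not tvm.ir.structural_equal(x, y)
# ... but compare equal once free variables are allowed to map to each other.
assert tvm.ir.structural_equal(x, y, map_free_vars=True)
```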
diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc
index 54752e387989f..f06246667a8ba 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -36,8 +36,8 @@
  * It also overloads + and * operation which can increase performance when doing
  * operations involving tensors with values of only 0 or 1.
  *
- * Note: this pass can only be used with functions where the input/output types are a
- * combination of TupleTypes, TensorTypes, ADTs, and non-nested FuncTypes
+ * Note: this pass can only be used with functions where the input/output types are
+ * a combination of TupleTypes and TensorTypes
  *
  * This pass optimizes 6 ops:
  * - add
@@ -47,652 +47,142 @@
  * - zeros
  * - zeros_like
  *
- * This module level pass adds a new "GradCell" version datatype for each existing datatype.
- * This is the case to propogate the new GradCell datatype through ADTs such as Lists.
- * For each function, a new function is created that accepts the "GradCell" type of the arguments
- * of the original function. That is, inputs to the function are converted to their
- * GradCell-version, passed to the newly created "GradCell_Function". The output is then necessarily
- * converted from the GradCell version to the original return type.
- *
- * To support ADTs, we use functions that convert between an instance of an ADT to its
- * respective GradCell version
- * by matching constructors to the constructor of the "GradCell" datatype.
- *
- * A transformation function is required for different type arguments.
- * For example the ADT List may be List[int32] or List[List[int32]], which should be handled
- * separately.
- *
- * This pass uses 4 primary mutators:
- * - LazyGradientInitializer to create the "GradCell_Function" of a given function.
- * - GradCellWrapper mutates expr into its respective GradCell expr
- * - GradCellWrapper mutates expr into its respective non-GradCell expr
- * - ADTTransform creates a ADT for each unique ADT
+ * This pass makes use of three visitors. The most important one visits the entire function;
+ * one is used to wrap inputs and one to unwrap outputs.
+ *
+ * For example:
+ * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
+ *
+ * After this pass
+ * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
+ *
+ * Thus, it is necessary to wrap this outer function so that the input/output types remain the same
  */
 
 #include
 #include
 #include
 #include
-#include
 #include
-#include
-
 
 #include "let_list.h"
 
 namespace tvm {
 namespace relay {
 
-// prefix of name of GradCell version ADT
-const char GradCell_Header[] = "_GradCell_";
-// prefix of transformation function for converting ADT to GradCell version
-const char GradCell_TransFunc[] = "_GradCell_TransFunc_";
-// prefix of transformation function for converting GradCell version ADT to normal
-const char GradCell_ReverseTransFunc[] = "_GradCell_ReverseTransFunc_";
-// prefix of copy of function that operates on GradCell types
-const char GradCell_Func[] = "_GradCell_Func_";
-
-struct TypeCallHash {
-  size_t operator()(const TypeCall& typecall) const { return ObjectHash()(typecall->func); }
-};
-
-/*!
- * \brief Check if two ADT instances are equal,
- * check for dataflow equivalence allow for mapping between TypeVars
- * i.e GradCell[TypeVar(A)] = GradCell[TypeVar(B)]
- */
-struct TypeCallEqual {
-  bool operator()(const TypeCall& l, const TypeCall& r) const {
-    if (!(l->func.same_as(r->func))) {
-      return false;
-    }
-
-    if (l->args.size() != r->args.size()) {
-      return false;
-    }
-
-    for (size_t i = 0; i < l->args.size(); i++) {
-      if (!tvm::StructuralEqual()(l->args[i], r->args[i], true)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
-};
-
 /*!
- * \brief ADTTransform creates a new ADT named
- * GradCell_Header + name_hint for each unique ADT.
+ * \brief Visitor appropriately wraps tensors with Raw constructor
+ *
+ * Recursively looks at the type of the expression (only TensorType and TupleType are supported
+ * for now) and either calls the GradCell constructor if TensorType, or unfolds and recursively
+ * visits if TupleType
  */
-class ADTTransform : public TypeMutator, public PatternMutator {
+class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
  public:
-  explicit ADTTransform(IRModule module) : module_(module) {}
-
-  Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }
-
-  Type VisitType_(const TensorTypeNode* op) final {
-    GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
-    tvm::Array<Type> args;
-    args.push_back(GetRef<TensorType>(op));
-    return TypeCall(gradCell, args);
-  }
-
-  Type VisitType_(const GlobalTypeVarNode* op) final {
-    GlobalTypeVar t = GetRef<GlobalTypeVar>(op);
-    if (op->kind == kAdtHandle) {
-      if (adt_mapping_.count(t) != 0) {
-        return adt_mapping_.at(t);
-      }
-
-      TypeData adt = module_->LookupTypeDef(t);
-      this->VisitType(adt);
-
-      return adt_mapping_.at(t);
-    }
-
-    return GetRef<GlobalTypeVar>(op);
-  }
-
-  Type VisitType_(const TypeDataNode* op) final {
-    auto type_data = GetRef<TypeData>(op);
-    std::string transformed_adt_name = GradCell_Header + op->header->name_hint;
-
-    // add new ADT to map to handle recursive definitions
-    GlobalTypeVar new_adt = GlobalTypeVar(transformed_adt_name, op->header->kind);
-    adt_mapping_[type_data->header] = new_adt;
-    reverse_adt_mapping_[new_adt] = type_data->header;
-
-    // define transformed ADT
-    Array<Constructor> constructors;
-    for (Constructor con : op->constructors) {
-      Array<Type> inputs;
-      for (Type t : con->inputs) {
-        inputs.push_back(this->VisitType(t));
-      }
-      Constructor transformed_cons = Constructor(GradCell_Header + con->name_hint, inputs, new_adt);
-      constructors.push_back(transformed_cons);
-    }
-
-    TypeData new_datatype = TypeData(new_adt, op->type_vars, constructors);
-    module_->AddTypeDef(new_adt, new_datatype);
-    return std::move(new_datatype);
-  }
-
-  Pattern VisitPattern(const Pattern& c) final { return PatternMutator::VisitPattern(c); }
-
-  Constructor VisitConstructor(const Constructor& c) final {
-    this->VisitType(c->belong_to);
-    return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint,
-                                   GradCell_Header + c->name_hint);
-  }
-
-  /*!
-   * \brief Given a transformed ADT, returned the original ADT.
-   * Useful for GradCellUnWrapper which needs to map transformed ADT constructors
-   * to the original ADT constructors.
-   *
-   * \param transformed_adt_handle GlobalTypeVar of "GradCell-version" of ADT
-   * \return ADT
-   */
-  GlobalTypeVar GetReverseADT(GlobalTypeVar transformed_adt_handle) {
-    auto it = reverse_adt_mapping_.find(transformed_adt_handle);
-
-    // reverse mapping should always be found
-    CHECK(it != reverse_adt_mapping_.end()) << "Reverse mapping of ADT transformation not found";
-    return it->second;
-  }
-
- private:
-  // Module
-  IRModule module_;
-  // ADT -> transformed ADT
-  std::unordered_map<GlobalTypeVar, GlobalTypeVar, ObjectHash, ObjectEqual> adt_mapping_;
-  // transformed ADT -> ADT
-  std::unordered_map<GlobalTypeVar, GlobalTypeVar, ObjectHash, ObjectEqual> reverse_adt_mapping_;
-};
-
-/*!
- * \brief Helper for TypeCallMutator.
- * Replace TypeVar with type arguments
- */
-class TypeVarSolver : public TypeMutator {
- public:
-  explicit TypeVarSolver(
-      const std::unordered_map<TypeVar, TypeVar, ObjectHash, ObjectEqual>& type_var_map,
-      const std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual>& type_call_map)
-      : type_var_map_(type_var_map), type_call_map_(type_call_map) {}
-  Type VisitType_(const TypeVarNode* op) final {
-    TypeVar type = GetRef<TypeVar>(op);
-
-    if (type_call_map_.count(type) != 0) {
-      // recursively visit Type argument to replace possible nested TypeVar
-      return VisitType(type_call_map_.at(type));
-    }
-
-    if (type_var_map_.count(type) != 0) {
-      return type_var_map_.at(type);
-    }
-
-    return std::move(type);
-  }
-
- private:
-  // type vars to unique type vars
-  std::unordered_map<TypeVar, TypeVar, ObjectHash, ObjectEqual> type_var_map_;
-  // TypeCall arguments to ADT
-  std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_call_map_;
-};
+  explicit InputVisitor(IRModule module) : module_(module) {}
+
+  Expr VisitExpr_(const VarNode* op, const Type& t) final {
+    return WrapExpr(GetRef<Var>(op), op->type_annotation);
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
+    return WrapExpr(GetRef<TupleGetItem>(op), t);
+  }
+
+ private:
+  IRModule module_;
+
+  Expr WrapExpr(const Expr expr, const Type& type) {
+    if (type.as<TensorTypeNode>()) {
+      return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
+    } else if (auto* type_anno = type.as<TupleTypeNode>()) {
+      tvm::Array<Expr> fields;
+      for (size_t i = 0; i < type_anno->fields.size(); i++) {
+        const Type& t = type_anno->fields[i];
+        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
+      }
+      Expr tuple = Tuple(fields);
+      return tuple;
+    }
+    return expr;
+  }
+};
 
 /*!
- * \brief Find all TypeVars within the arguments of a TypeCallNode and create a mapping
- * of the TypeVars to new TypeVars
+ * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
+ *
+ * Recursively looks at the type of the expression
+ * and either uses the FromGradCell function if TypeCall to GradCell,
+ * or unfolds and recursively visits if TupleType
  */
-class TypeCallMutator : public TypeVisitor {
+class OutputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
  public:
-  // TypeVars within TypeCallNode
-  Array<Type> args;
-  // unique TypeVars
-  Array<TypeVar> params;
-  explicit TypeCallMutator(IRModule module, const TypeCallNode* op) : module_(module) {
-    for (Type t : op->args) {
-      // visit each type argument
-      VisitType(t);
-    }
-    for (auto const& x : type_var_map) {
-      args.push_back(x.first);
-      params.push_back(x.second);
-    }
-  }
-
-  /*!
-   * \brief Replace ADT type vars with TypeCall arguments
-   * and replace type vars with unique typevars
-   *
-   * \param t TypeCall
-   * \param map TypeVar of ADT -> type argument
-   *
-   * \return type after replacing ADT TypeVar with arguments and replacing any
-   * free type vars with uniquely generated typevars
-   */
-  Type InputType(Type t, const std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual>& map) {
-    return TypeVarSolver(type_var_map, map).VisitType(t);
-  }
-
-  void VisitType_(const TypeVarNode* op) final {
-    TypeVar tv = GetRef<TypeVar>(op);
-    if (type_var_map.count(tv) == 0) {
-      TypeVar replacement = TypeVar(tv->name_hint + "_", tv->kind);
-      type_var_map.insert({tv, replacement});
-    }
-  }
-
- private:
-  IRModule module_;
-  // TypeVar in argument -> TypeVar of polymorphic function
-  std::unordered_map<TypeVar, TypeVar, ObjectHash, ObjectEqual> type_var_map;
-};
-
-typedef class GradCellUnWrapper GradCellUnWrapper;
-
-/*!
- * \brief Mutate a given expression into its "GradCell-version".
- * TensorTypes are wrapped with the Raw constructor of GradCell.
- * TupleTypes are recursively visited.
- * ADTTypes are converted to its appropriate transformed ADT
- * FuncTypes are wrapped with a function that appropriately wraps/unwraps input and output
- */
-class GradCellWrapper : public ExprFunctor<Expr(const Expr&, const Type&, GradCellUnWrapper*)>,
-                        public TypeMutator {
- public:
-  explicit GradCellWrapper(IRModule module, ADTTransform* adt_transformer)
-      : module_(module), adt_transformer_(adt_transformer), unique(0) {}
-  Expr VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) final;
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t, GradCellUnWrapper* unwrapper) final;
-  Expr VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) final;
-
- private:
-  // Module
-  IRModule module_;
-  // ADTTransform
-  ADTTransform* adt_transformer_;
-  // TypeCall -> Function to transform an ADT Instance into GradCell version
-  std::unordered_map<TypeCall, GlobalVar, TypeCallHash, TypeCallEqual> adt_wrapper_map_;
-  // TypeVar of ADT call -> Type argument
-  std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_var_map;
-  // append to prefix to create unique function names for ADT wrapper functions
-  int64_t unique;
-
-  Expr WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper);
-  // Return function to wrap ADT
-  Expr GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args,
-                      GradCellUnWrapper* unwrapper);
-  Type VisitType_(const GlobalTypeVarNode* op) final;
-  Type VisitType_(const TensorTypeNode* op) final;
-};
-
-/*!
- * \brief Mutate a given "GradCell-version" expression into its nonGradCell-version.
- * TypeCalls to GradCell are wrapped with FromGradCell function
- * TupleTypes are recursively visited.
- * Transformed ADTs are converted to its appropriate normal ADT
- */
-class GradCellUnWrapper : public ExprFunctor<Expr(const Expr&, const Type&)>, public TypeMutator {
- public:
-  explicit GradCellUnWrapper(IRModule module, ADTTransform* adt_transformer)
-      : module_(module), adt_transformer_(adt_transformer), unique(0) {}
-  Expr VisitExpr_(const VarNode* op, const Type& t) final;
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final;
-  Expr VisitExpr_(const CallNode* op, const Type& t) final;
-  Expr VisitExpr_(const TupleNode* op, const Type& t) final;
-  Expr VisitExpr_(const ConstantNode* op, const Type& t) final;
-
- private:
-  // Module
-  IRModule module_;
-  // ADTTransform
-  ADTTransform* adt_transformer_;
-  // TypeCall -> Function an GradCell_ADT into ADT
-  std::unordered_map<TypeCall, GlobalVar, TypeCallHash, TypeCallEqual> adt_unwrapper_map_;
-  // TypeVar of GradCell_ADT call -> Type argument
-  std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_var_map;
-  // create unique strings for ADT unwrapper functions
-  int64_t unique;
-
-  Expr UnwrapExpr(const Expr expr, const Type& type);
-  // Return function to unwrap ADT
-  Expr GetReverseADTFunction(const TypeCallNode* op, TypeCallMutator& type_args);
-  Type VisitType_(const TypeCallNode* op) final;
-  Type VisitType_(const GlobalTypeVarNode* op) final;
-};
-
-/* GradCellWrapper */
-Expr GradCellWrapper::VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) {
-  return WrapExpr(GetRef<Var>(op), op->type_annotation, unwrapper);
-}
-
-Expr GradCellWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t,
-                                 GradCellUnWrapper* unwrapper) {
-  return WrapExpr(GetRef<TupleGetItem>(op), t, unwrapper);
-}
-
-Expr GradCellWrapper::VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) {
-  return WrapExpr(GetRef<Call>(op), t, unwrapper);
-}
-
-Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper) {
-  if (type.as<TensorTypeNode>()) {
-    return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
-  }
-
-  if (auto* type_anno = type.as<TupleTypeNode>()) {
-    tvm::Array<Expr> fields;
-    for (size_t i = 0; i < type_anno->fields.size(); i++) {
-      const Type& t = type_anno->fields[i];
-      // recursively visit each item of tuple
-      fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t, unwrapper));
-    }
-    Expr tuple = Tuple(fields);
-    return tuple;
-  }
-
-  if (auto* type_anno = type.as<TypeCallNode>()) {
-    // create GradCell_ADT if not already created
-    adt_transformer_->VisitType(type_anno->func);
-    // find all type vars within type_anno
-    // to handle polymorphic functions
-    auto tvs = TypeCallMutator(module_, type_anno);
-
-    return Call(GetADTFunction(type_anno, tvs, unwrapper), {expr}, Attrs(), tvs.args);
-  }
-
-  if (auto* type_anno = type.as<FuncTypeNode>()) {
-    // to handle functions, we need to create a new function
-    // that handles GradCell version input and outputs GradCell version types
-    Array<Var> funcVars;
-    Array<Expr> args;
-    for (Type t : type_anno->arg_types) {
-      Type visited = this->VisitType(t);
-      Var v = Var("v", visited);
-      funcVars.push_back(v);
-      // unwrap arguments
-      args.push_back(unwrapper->VisitExpr(v, visited));
-    }
-    // call original expr with unwrapped arguments
-    Call call = Call(expr, args);
-    // wrap results of the call
-    Expr result = this->WrapExpr(call, type_anno->ret_type, unwrapper);
-    // return new function with GradCell-version types, wrapping original function
-    return Function(funcVars, result, this->VisitType(type_anno->ret_type), type_anno->type_params);
-  }
-
-  return expr;
-}
-
-Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args,
-                                     GradCellUnWrapper* unwrapper) {
-  auto type = GetRef<TypeCall>(op);
-  GlobalTypeVar adt_handle = Downcast<GlobalTypeVar>(op->func);
-  if (adt_wrapper_map_.count(type) != 0) {
-    // ADT already wrapped previously
-    return adt_wrapper_map_.at(type);
-  }
-
-  // handle recursive ADT which require recursive calls to transform
-  GlobalVar func_var = GlobalVar(GradCell_Header + std::string(GradCell_TransFunc) +
-                                 adt_handle->name_hint + std::to_string(unique++));
-  adt_wrapper_map_[type] = func_var;
-
-  TypeData adt_data = module_->LookupTypeDef(adt_handle);
-  TypeData new_adt_data = module_->LookupTypeDef(GradCell_Header + adt_handle->name_hint);
-  // solve for input type to wrap ADT function
-  for (size_t i = 0; i < adt_data->type_vars.size(); i++) {
-    type_var_map[adt_data->type_vars[i]] = op->args[i];
-  }
-  auto input_type = type_args.InputType(type, type_var_map);
-
-  CHECK(adt_data->constructors.size() == new_adt_data->constructors.size())
-      << "ADT and transformed ADT have different number of constructors";
-
-  /*
-   * Pattern match each constructor of the ADT to the respective constructor
-   * in the transformed ADT. PatternVars then need to be recursively wrapped,
-   * and passed as argument to the constructor of the transformed ADT
-   */
-  Array<Clause> clauses;
-  for (size_t i = 0; i < adt_data->constructors.size(); i++) {
-    // get Constructor to pattern match against
-    Array<Pattern> patternVars;
-    Array<Expr> c_args;
-    Constructor c = adt_data->constructors[i];
-    for (Type t : c->inputs) {
-      // solve for type of PatternVar
-      Type pattern_var_type = type_args.InputType(t, type_var_map);
-      Var v = Var("var", pattern_var_type);
-      patternVars.push_back(PatternVar(v));
-      // recursively wrap
-      c_args.push_back(this->VisitExpr(v, pattern_var_type, unwrapper));
-    }
-    Pattern p = PatternConstructor(c, patternVars);
-    // return Constructor of new ADT with wrapped arguments
-    Expr e = Call(new_adt_data->constructors[i], c_args);
-
-    clauses.push_back(Clause(p, e));
-  }
-
-  Var v = Var("v", input_type);
-  Expr match = Match(v, clauses);
-
-  Function func = Function({v}, match, this->VisitType(input_type), type_args.params);
-  module_->AddUnchecked(func_var, func);
-  return std::move(func);
-}
-
-Type GradCellWrapper::VisitType_(const GlobalTypeVarNode* op) {
-  GlobalTypeVar t = GetRef<GlobalTypeVar>(op);
-  if (op->kind == kAdtHandle) {
-    return adt_transformer_->VisitType(t);
-  }
-
-  return GetRef<GlobalTypeVar>(op);
-}
-
-Type GradCellWrapper::VisitType_(const TensorTypeNode* op) {
-  GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
-  tvm::Array<Type> args;
-  args.push_back(GetRef<TensorType>(op));
-  return TypeCall(gradCell, args);
-}
-
-/* GradCellUnWrapper */
-Expr GradCellUnWrapper::VisitExpr_(const CallNode* op, const Type& t) {
-  return UnwrapExpr(GetRef<Call>(op), t);
-}
-
-Expr GradCellUnWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t) {
-  return UnwrapExpr(GetRef<TupleGetItem>(op), t);
-}
-
-Expr GradCellUnWrapper::VisitExpr_(const VarNode* op, const Type& t) {
-  return UnwrapExpr(GetRef<Var>(op), op->type_annotation);
-}
-
-Expr GradCellUnWrapper::VisitExpr_(const TupleNode* op, const Type& t) {
-  return UnwrapExpr(GetRef<Tuple>(op), t);
-}
-
-Expr GradCellUnWrapper::VisitExpr_(const ConstantNode* op, const Type& t) {
-  return UnwrapExpr(GetRef<Constant>(op), t);
-}
-
-Expr GradCellUnWrapper::UnwrapExpr(const Expr expr, const Type& type) {
-  if (auto* type_call = type.as<TypeCallNode>()) {
-    if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
-      // if TypeCall to GradCell, simply wrap with FromGradCell function
-      return Call(module_->GetGlobalVar("FromGradCell"), {expr}, Attrs(), type_call->args);
-    }
-
-    // convert transformed ADT to ADT
-    auto tvs = TypeCallMutator(module_, type_call);
-    return Call(GetReverseADTFunction(type_call, tvs), {expr}, Attrs(), tvs.args);
-  }
-
-  if (auto* type_anno = type.as<TupleTypeNode>()) {
-    tvm::Array<Expr> fields;
-    for (size_t i = 0; i < type_anno->fields.size(); i++) {
-      // recursively unwrap items of tuple
-      const Type& t = type_anno->fields[i];
-      fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
-    }
-    Expr tuple = Tuple(fields);
-    return tuple;
-  }
-  return expr;
-}
-
-Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, TypeCallMutator& type_args) {
-  TypeCall type = GetRef<TypeCall>(op);
-  GlobalTypeVar transformed_adt_handle = Downcast<GlobalTypeVar>(op->func);
-  GlobalTypeVar adt_handle = adt_transformer_->GetReverseADT(transformed_adt_handle);
-
-  // sanity check
-  CHECK(std::string(transformed_adt_handle->name_hint).rfind(GradCell_Header, 0) == 0)
-      << "Output ADT is not a transformed ADT";
-
-  if (adt_unwrapper_map_.count(type)) {
-    // transformed ADT unwrapped previously
-    return adt_unwrapper_map_.at(type);
-  }
-
-  // handle recursive ADTs
-  GlobalVar func_var = GlobalVar(GradCell_Header + std::string(GradCell_ReverseTransFunc) +
-                                 adt_handle->name_hint + std::to_string(unique++));
-  adt_unwrapper_map_[type] = func_var;
-
-  TypeData adt_data = module_->LookupTypeDef(adt_handle);
-  TypeData transformed_adt_data = module_->LookupTypeDef(transformed_adt_handle);
-
-  CHECK(adt_data->type_vars.size() == transformed_adt_data->type_vars.size())
-      << "ADT and transformed ADT have different # of type args";
-
-  // solve for TypeVars of ADT to solve for input type of function
-  for (size_t i = 0; i < transformed_adt_data->type_vars.size(); i++) {
-    type_var_map[adt_data->type_vars[i]] = op->args[i];
-  }
-  auto input_type = type_args.InputType(type, type_var_map);
-
-  CHECK(adt_data->constructors.size() == transformed_adt_data->constructors.size())
-      << "ADT and transformed ADT have different number of constructors";
-
-  // use same logic as wrapping expression
-  // Pattern match with each Constructor of the transformed ADT,
-  // return respective Constructor with arguments of unwrapped PatternVars
-  Array<Clause> clauses;
-  for (size_t i = 0; i < transformed_adt_data->constructors.size(); i++) {
-    // Get Constructor of transformed ADT
-    Array<Pattern> patternVars;
-    Array<Expr> c_args;
-    Constructor c = transformed_adt_data->constructors[i];
-    for (Type t : c->inputs) {
-      // solve for type of pattern var
-      Type pattern_var_type = type_args.InputType(t, type_var_map);
-      Var v = Var("var", pattern_var_type);
-      // bind PatternVar to Var passed to constructor
-      patternVars.push_back(PatternVar(v));
-      // recursively unwrap
-      c_args.push_back(this->VisitExpr(v, pattern_var_type));
-    }
-    Pattern p = PatternConstructor(c, patternVars);
-    // Call appropriate Constructor
-    Expr e = Call(adt_data->constructors[i], c_args);
-    clauses.push_back(Clause(p, e));
-  }
-
-  Var v = Var("v", input_type);
-  Expr match = Match(v, clauses);
-
-  Function func = Function({v}, match, this->VisitType(input_type), type_args.params);
-  module_->AddUnchecked(func_var, func);
-  return std::move(func);
-}
-
-Type GradCellUnWrapper::VisitType_(const TypeCallNode* op) {
-  if (op->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
-    return op->args[0];
-  }
-  return TypeMutator::VisitType_(op);
-}
-
-Type GradCellUnWrapper::VisitType_(const GlobalTypeVarNode* op) {
-  GlobalTypeVar t = GetRef<GlobalTypeVar>(op);
-  if (op->kind == kAdtHandle) {
-    return adt_transformer_->GetReverseADT(t);
-  }
-
-  return GetRef<GlobalTypeVar>(op);
-}
+  explicit OutputVisitor(IRModule module) : module_(module) {}
+
+  Expr VisitExpr_(const CallNode* op, const Type& t) final {
+    return UnwrapExpr(GetRef<Call>(op), t);
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
+    return UnwrapExpr(GetRef<TupleGetItem>(op), t);
+  }
+
+ private:
+  IRModule module_;
+
+  Expr UnwrapExpr(const Expr expr, const Type& type) {
+    if (auto* type_call = type.as<TypeCallNode>()) {
+      if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
+        return Call(module_->GetGlobalVar("FromGradCell"), {expr});
+      }
+      return expr;
+    } else if (auto* type_anno = type.as<TupleTypeNode>()) {
+      tvm::Array<Expr> fields;
+      for (size_t i = 0; i < type_anno->fields.size(); i++) {
+        const Type& t = type_anno->fields[i];
+        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
+      }
+      Expr tuple = Tuple(fields);
+      return tuple;
+    }
+    return expr;
+  }
+};
 
-class LazyGradientInitializer : public ExprMutator, public TypeMutator, public PatternMutator {
+class LazyGradientInitializer : public ExprMutator, public TypeMutator {
  public:
   explicit LazyGradientInitializer(IRModule module) : module_(module) {
-    // setup
-    adt_transformer_ = new ADTTransform(module_);
-    grad_cell_wrapper_ = new GradCellWrapper(module_, adt_transformer_);
-    grad_cell_unwrapper_ = new GradCellUnWrapper(module_, adt_transformer_);
-
-    // import GradCell and GradCell functions
     module_->ImportFromStd("gradient.rly");
-
-    // ignore these functions when transforming
-    GlobalVar from_grad_cell = module_->GetGlobalVar("FromGradCell");
-    GlobalVar mul_grad_cell = module_->GetGlobalVar("MultiplyGradCell");
-    GlobalVar add_grad_cell = module_->GetGlobalVar("AddGradCell");
-
-    func_map_[from_grad_cell] = from_grad_cell;
-    func_map_[mul_grad_cell] = mul_grad_cell;
-    func_map_[add_grad_cell] = add_grad_cell;
   }
 
   /*!
-   * \brief Given a global function, create new global function
-   * that mirrors the functionality however using GradCell type.
-   * Original function will wrap inputs, call the mirrored function, unwrap the ouput,
-   * and return.
+   * \brief Apply the LazyGradientInit transformation and wrap the function
+   * so that the function type stays the same
+   *
+   * input/output types should only be a combination of TupleTypes and TensorTypes
    */
-  BaseFunc VisitGlobalVar(const GlobalVar& gv) {
-    auto base_func = module_->Lookup(gv);
-    if (auto* e = base_func.as<FunctionNode>()) {
-      auto f = GetRef<Function>(e);
-      if (func_map_.count(gv) == 0) {
-        // create GlobalVar handle for function
-        func_map_[gv] = GlobalVar(GradCell_Func + gv->name_hint);
-      }
-      GlobalVar func_var = func_map_.at(gv);
-      if (module_->ContainGlobalVar(func_var->name_hint)) {
-        // transformed function already contained in IRModule, return
-        return module_->Lookup(func_var);
-      }
-      // create transformed function and add definition to IRModule
-      auto* transformed = ExprMutator::Mutate(f).as<FunctionNode>();
-      module_->AddUnchecked(func_var, GetRef<Function>(transformed));
-
-      // wrap inputs of Tensor type using GradCellWrapper class
-      tvm::Array<Expr> args;
-      for (Var var : f->params) {
-        Expr wrappedInput =
-            grad_cell_wrapper_->VisitExpr(var, var->checked_type(), grad_cell_unwrapper_);
-        args.push_back(wrappedInput);
-      }
-      Expr transformedExpr = Call(func_var, args);
-
-      // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
-      Expr tensorOutput = grad_cell_unwrapper_->VisitExpr(transformedExpr, transformed->ret_type);
-      return Function(f->params, tensorOutput, f->ret_type, f->type_params);
-    }
-    throw std::runtime_error("GlobalVar does not map to a function");
+  Expr Transform(const Expr& e) {
+    auto* f = (e).as<FunctionNode>();
+    auto* transformed = this->Mutate(e).as<FunctionNode>();
+
+    if (e.same_as(GetRef<Function>(transformed))) {
+      return GetRef<Function>(transformed);
+    }
+
+    // wrap inputs of Tensor type using InputVisitor class
+    tvm::Array<Expr> args;
+    for (Var var : f->params) {
+      Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type());
+      args.push_back(wrappedInput);
+    }
+    Expr transformedExpr = Call(GetRef<Function>(transformed), args);
+
+    // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
+    Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type);
+    return Function(f->params, tensorOutput, f->ret_type, Array<TypeVar>());
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
@@ -736,124 +226,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P
       // handle all other ops
       Expr result = CallPrimitiveOp(call_node);
       // wrap result with Raw constructor
-      return grad_cell_wrapper_->VisitExpr(result, call_node->checked_type(), grad_cell_unwrapper_);
-    }
-
-    if (auto* op = (call_node->op).as<ConstructorNode>()) {
-      // create "GradCell-version" of ADT if not already created
-      adt_transformer_->VisitType(op->belong_to);
-      // call Constructor of transformed ADT
-      Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint,
-                                              GradCell_Header + op->name_hint);
-      Array<Expr> args;
-      for (Expr e : call_node->args) {
-        args.push_back(this->VisitExpr(e));
-      }
-
-      Array<Type> type_args;
-      for (Type t : call_node->type_args) {
-        type_args.push_back(this->VisitType(t));
-      }
-      return Call(c, args, Attrs(), type_args);
+      return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
+                  {call_node->checked_type()});
     }
-
+    // not an op
     return ExprMutator::VisitExpr_(call_node);
   }
 
-  Expr VisitExpr_(const ConstructorNode* op) final {
-    return module_->GetConstructor(GradCell_Header + op->belong_to->name_hint,
-                                   GradCell_Header + op->name_hint);
-  }
-
-  Expr VisitExpr_(const IfNode* op) final {
-    auto true_b = VisitExpr(op->true_branch);
-    auto false_b = VisitExpr(op->false_branch);
-
-    // guard is bool type which will become GradCell[bool], so necessary to unwrap
-    auto guard =
-        grad_cell_unwrapper_->VisitExpr(VisitExpr(op->cond), VisitType(op->cond->checked_type()));
-    return If(guard, true_b, false_b);
-  }
-
-  Expr VisitExpr_(const VarNode* op) final {
-    auto var = GetRef<Var>(op);
-    if (var_map_.count(var) != 0) {
-      return var_map_.at(var);
-    }
-
-    return ExprMutator::VisitExpr_(op);
-  }
-
-  Expr VisitExpr_(const GlobalVarNode* op) final {
-    // GlobalVar is a handle to a global function
-    GlobalVar gv = GetRef<GlobalVar>(op);
-    if (func_map_.count(gv) == 0) {
-      // create handle to transformed function
-      func_map_[gv] = GlobalVar(GradCell_Func + op->name_hint);
-    }
-    return func_map_.at(gv);
-  }
-
   Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }
 
-  Type VisitType_(const GlobalTypeVarNode* op) final {
-    GlobalTypeVar t = GetRef<GlobalTypeVar>(op);
-    if (module_->GetGlobalTypeVar("GradCell").same_as(t)) {
-      // if GradCell type, do nothing
-      return std::move(t);
-    }
-    if (op->kind == kAdtHandle) {
-      // handle to ADT, define GradCell version of ADT is not already created
-      return adt_transformer_->VisitType(t);
-    }
-
-    return std::move(t);
-  }
-
-  Var VisitVar(const Var& v) final {
-    // used for PatternMutator
-    if (var_map_.count(v) == 0) {
-      var_map_.insert(std::pair<Var, Var>(v, Var(v->name_hint(), VisitType(v->type_annotation))));
-    }
-    return var_map_.at(v);
-  }
-
-  Type VisitType_(const TensorTypeNode* op) final {
+  Type VisitType_(const TensorTypeNode* op) {
     GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
     tvm::Array<Type> args;
     args.push_back(GetRef<TensorType>(op));
     return TypeCall(gradCell, args);
   }
 
-  Pattern VisitPattern(const Pattern& c) final { return PatternMutator::VisitPattern(c); }
-
-  Constructor VisitConstructor(const Constructor& c) final {
-    adt_transformer_->VisitType(c->belong_to);
-    return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint,
-                                   GradCell_Header + c->name_hint);
-  }
-
-  ~LazyGradientInitializer() {
-    // destructors
-    delete grad_cell_wrapper_;
-    delete grad_cell_unwrapper_;
-    delete adt_transformer_;
-  }
-
  private:
   // Module
   IRModule module_;
-  // pass single instance of ADTTransform to save state of ADTs transformed
-  ADTTransform* adt_transformer_;
-  // pass single instance of ADTTransform to save state of ADTs wrapped
-  GradCellWrapper* grad_cell_wrapper_;
-  // pass single instance of ADTTransform to save state of ADTs unwrapped
-  GradCellUnWrapper* grad_cell_unwrapper_;
-  // var map used for transforming a Clause
-  std::unordered_map<Var, Var, ObjectHash, ObjectEqual> var_map_;
-  // handle of function -> handle of transformed function
-  std::unordered_map<GlobalVar, GlobalVar, ObjectHash, ObjectEqual> func_map_;
 
   /*!
    * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
    */
@@ -891,35 +283,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P
   Expr CallPrimitiveOp(const CallNode* call_node) {
     const auto fromFunc = module_->GetGlobalVar("FromGradCell");
     tvm::Array<Expr> args;
-
-    // unwrap arguments
+    // use FromGradCell to convert args to Tensor
     for (Expr expr : call_node->args) {
-      args.push_back(
-          grad_cell_unwrapper_->VisitExpr(VisitExpr(expr), VisitType(expr->checked_type())));
+      args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
     }
     // result of operation
     return Call(call_node->op, args, call_node->attrs);
   }
 };
 
-IRModule LazyGradientInit(const IRModule& m) {
-  LazyGradientInitializer lgi = LazyGradientInitializer(m);
-  std::vector<GlobalVar> gvs;
-  for (const auto& p : m->functions) {
-    gvs.push_back(p.first);
-  }
-  for (const auto& gv : gvs) {
-    m->AddUnchecked(gv, lgi.VisitGlobalVar(gv));
-  }
-  m->Check();
-  return m;
+Expr LazyGradientInit(const Expr& e, IRModule mod) {
+  return LazyGradientInitializer(mod).Transform(e);
 }
 
 namespace transform {
 Pass LazyGradientInit() {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule m, PassContext pc) { return relay::LazyGradientInit(m); };
-  return CreateModulePass(pass_func, 1, "LazyGradientInit", {});
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(LazyGradientInit(f, m));
+      };
+  return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
}
 
 TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit);
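With the pass rewritten as a function pass driven by `Transform`, the wrap/unwrap behavior described in the header comment can be checked from Python roughly as follows (a sketch in the style of the tests below; the shape and function are illustrative, not taken from the patch):

```python
import tvm
from tvm import relay
from tvm.relay import transform

shape = (10, 10)
t = relay.TensorType(shape, "float32")
x = relay.var("x", t)

mod = tvm.IRModule()
mod["main"] = relay.Function([x], x + x)
mod = transform.LazyGradientInit()(mod)

# The outer signature is unchanged; internally the body now computes over
# GradCell[TensorType[(10, 10), float32]] values and unwraps the result.
assert mod["main"].checked_type == relay.FuncType([t], t)
```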
diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py
index 7f1ca3ca6719b..414926802870a 100644
--- a/tests/python/relay/test_pass_lazy_gradient_init.py
+++ b/tests/python/relay/test_pass_lazy_gradient_init.py
@@ -21,7 +21,6 @@
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import rand, run_infer_type
 from tvm.testing import assert_allclose
-from tvm.relay.prelude import Prelude
 import pytest
 
 def test_tc():
@@ -391,41 +391,5 @@ def test_ones_like():
     y = ex.evaluate(y)(x)
     assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
 
-def test_list_adt():
-    """test prelude functions on list ADT. which is a recursive ADT"""
-    mod = tvm.IRModule()
-    p = Prelude(mod)
-
-    cons = p.cons
-    nil = p.nil
-
-    mod = transform.LazyGradientInit()(mod)
-
-    ex = create_executor(mod=mod)
-
-    def to_list_adt(list):
-        l = nil()
-        for x in list:
-            l = cons(relay.const(x), l)
-        return ex.evaluate(l)
-
-    def from_list_adt(list):
-        l = []
-        def rec(x):
-            if x.constructor.tag == cons.tag:
-                l.insert(0, x.fields[0].asnumpy().tolist())
-                rec(x.fields[1])
-        rec(list)
-        return l
-
-    # test sum
-    x = np.random.randint(1,101,10)
-    assert sum(x) == ex.evaluate(mod['sum'])(to_list_adt(x)).asnumpy()
-
-    # test reverse
-    x = np.random.rand(10)
-    actual = from_list_adt(ex.evaluate(mod['rev'])(to_list_adt(x)))
-    assert_allclose(x[::-1], actual)
-
 if __name__ == "__main__":
     pytest.main([__file__])
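For reference, the lazy-initialization payoff that the remaining tests exercise can also be reproduced end to end. The sketch below is illustrative (assumed shape and composition, not a test from the patch): `zeros` is one of the six ops the pass intercepts, so the zero tensor stays lazy instead of being materialized, and the addition should reduce to the input itself:

```python
import tvm
from tvm import relay
from tvm.relay import create_executor, transform
from tvm.relay.testing import rand
from tvm.testing import assert_allclose

shape = (10, 10)
t = relay.TensorType(shape, "float32")
x = relay.var("x", t)

mod = tvm.IRModule()
# x + 0: addition with a lazy zero should evaluate to x unchanged.
mod["main"] = relay.Function([x], x + relay.zeros(shape, "float32"))
mod = transform.LazyGradientInit()(mod)
assert mod["main"].checked_type == relay.FuncType([t], t)

ex = create_executor(mod=mod)
a = rand("float32", *shape)
assert_allclose(ex.evaluate(mod["main"])(a).asnumpy(), a.asnumpy())
```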