diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td index a1afc3b8042c28..9a9c09d3c20d61 100644 --- a/llvm/include/llvm/Target/GenericOpcodes.td +++ b/llvm/include/llvm/Target/GenericOpcodes.td @@ -17,6 +17,10 @@ class GenericInstruction : StandardPseudoInstruction { let isPreISelOpcode = true; + + // When all variadic ops share a type with another operand, + // this is the type they share. Used by MIR patterns type inference. + TypedOperand variadicOpsType = ?; } // Provide a variant of an instruction with the same operands, but @@ -1228,6 +1232,7 @@ def G_UNMERGE_VALUES : GenericInstruction { let OutOperandList = (outs type0:$dst0, variable_ops); let InOperandList = (ins type1:$src); let hasSideEffects = false; + let variadicOpsType = type0; } // Insert a smaller register into a larger one at the specified bit-index. @@ -1245,6 +1250,7 @@ def G_MERGE_VALUES : GenericInstruction { let OutOperandList = (outs type0:$dst); let InOperandList = (ins type1:$src0, variable_ops); let hasSideEffects = false; + let variadicOpsType = type1; } /// Create a vector from multiple scalar registers. No implicit @@ -1254,6 +1260,7 @@ def G_BUILD_VECTOR : GenericInstruction { let OutOperandList = (outs type0:$dst); let InOperandList = (ins type1:$src0, variable_ops); let hasSideEffects = false; + let variadicOpsType = type1; } /// Like G_BUILD_VECTOR, but truncates the larger operand types to fit the diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td index 63c485a5a6c607..ee0209eb9e5593 100644 --- a/llvm/include/llvm/Target/GlobalISel/Combine.td +++ b/llvm/include/llvm/Target/GlobalISel/Combine.td @@ -796,7 +796,7 @@ def trunc_shift: GICombineRule < def mul_by_neg_one: GICombineRule < (defs root:$dst), (match (G_MUL $dst, $x, -1)), - (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x)) + (apply (G_SUB $dst, 0, $x)) >; // Fold (xor (and x, y), y) -> (and (not x), y) diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td index 48a06474da78a1..4f705589b92e90 100644 --- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td +++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td @@ -151,7 +151,7 @@ def bad_imm_too_many_args : GICombineRule< (match (COPY $x, (i32 0, 0)):$d), (apply (COPY $x, $b):$d)>; -// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)', 'COPY' is not a ValueType +// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)': unknown type 'COPY' // CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(COPY ?:$x, (COPY 0)) def bad_imm_not_a_valuetype : GICombineRule< (defs root:$a), @@ -186,7 +186,7 @@ def expected_op_name : GICombineRule< (match (G_FNEG $x, i32)), (apply (COPY $x, (i32 0)))>; -// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: invalid operand type: 'not_a_type' is not a ValueType +// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: cannot parse operand type: unknown type 'not_a_type' // CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_FNEG ?:$x, not_a_type:$y)' def not_a_type; def bad_mo_type_not_a_valuetype : GICombineRule< diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td new file mode 100644 index 00000000000000..c9ffe4e7adb3db --- /dev/null +++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td @@ -0,0 +1,68 @@ +// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \ +// RUN: -gicombiner-debug-typeinfer -combiners=MyCombiner %s 2>&1 | \ +// RUN: FileCheck %s + +// Checks reasoning of the inference rules. + +include "llvm/Target/Target.td" +include "llvm/Target/GlobalISel/Combine.td" + +def MyTargetISA : InstrInfo; +def MyTarget : Target { let InstructionSet = MyTargetISA; } + +// CHECK: Rule Operand Type Equivalence Classes for inference_mul_by_neg_one: +// CHECK-NEXT: Groups for __inference_mul_by_neg_one_match_0: [dst, x] +// CHECK-NEXT: Groups for __inference_mul_by_neg_one_apply_0: [dst, x] +// CHECK-NEXT: Final Type Equivalence Classes: [dst, x] +// CHECK-NEXT: INFER: imm 0 -> GITypeOf<$x> +// CHECK-NEXT: Apply patterns for rule inference_mul_by_neg_one after inference: +// CHECK-NEXT: (CodeGenInstructionPattern name:__inference_mul_by_neg_one_apply_0 G_SUB operands:[$dst, (GITypeOf<$x> 0), $x]) +def inference_mul_by_neg_one: GICombineRule < + (defs root:$dst), + (match (G_MUL $dst, $x, -1)), + (apply (G_SUB $dst, 0, $x)) +>; + +// CHECK: Rule Operand Type Equivalence Classes for infer_complex_tempreg: +// CHECK-NEXT: Groups for __infer_complex_tempreg_match_0: [dst] [x, y, z] +// CHECK-NEXT: Groups for __infer_complex_tempreg_apply_0: [tmp2] [x, y] +// CHECK-NEXT: Groups for __infer_complex_tempreg_apply_1: [tmp, tmp2] +// CHECK-NEXT: Groups for __infer_complex_tempreg_apply_2: [dst, tmp] +// CHECK-NEXT: Final Type Equivalence Classes: [dst, tmp, tmp2] [x, y, z] +// CHECK-NEXT: INFER: MachineOperand $tmp2 -> GITypeOf<$dst> +// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst> +// CHECK-NEXT: Apply patterns for rule infer_complex_tempreg after inference: +// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_0 G_BUILD_VECTOR operands:[GITypeOf<$dst>:$tmp2, $x, $y]) +// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_1 G_FNEG operands:[GITypeOf<$dst>:$tmp, GITypeOf<$dst>:$tmp2]) +// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_2 G_FNEG operands:[$dst, GITypeOf<$dst>:$tmp]) +def infer_complex_tempreg: GICombineRule < + (defs root:$dst), + (match (G_MERGE_VALUES $dst, $x, $y, $z)), + (apply (G_BUILD_VECTOR $tmp2, $x, $y), + (G_FNEG $tmp, $tmp2), + (G_FNEG $dst, $tmp)) +>; + +// CHECK: Rule Operand Type Equivalence Classes for infer_variadic_outs: +// CHECK-NEXT: Groups for __infer_variadic_outs_match_0: [x, y] [vec] +// CHECK-NEXT: Groups for __infer_variadic_outs_match_1: [dst, x] +// CHECK-NEXT: Groups for __infer_variadic_outs_apply_0: [tmp, y] +// CHECK-NEXT: Groups for __infer_variadic_outs_apply_1: +// CHECK-NEXT: Final Type Equivalence Classes: [tmp, dst, x, y] [vec] +// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst> +// CHECK-NEXT: Apply patterns for rule infer_variadic_outs after inference: +// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_0 G_FNEG operands:[GITypeOf<$dst>:$tmp, $y]) +// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_1 COPY operands:[$dst, GITypeOf<$dst>:$tmp]) +def infer_variadic_outs: GICombineRule < + (defs root:$dst), + (match (G_UNMERGE_VALUES $x, $y, $vec), + (G_FNEG $dst, $x)), + (apply (G_FNEG $tmp, $y), + (COPY $dst, $tmp)) +>; + +def MyCombiner: GICombiner<"GenMyCombiner", [ + inference_mul_by_neg_one, + infer_complex_tempreg, + infer_variadic_outs +]>; diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td index 6040d6def44976..4e182d555db336 100644 --- a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td +++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td @@ -8,7 +8,8 @@ include "llvm/Target/GlobalISel/Combine.td" def MyTargetISA : InstrInfo; def MyTarget : Target { let InstructionSet = MyTargetISA; } -// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name +// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(anonymous_7029 0)': invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name +// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ANYEXT ?:$dst, (anonymous_ def NoDollarSign : GICombineRule< (defs root:$dst), (match (G_ZEXT $dst, $src)), @@ -47,7 +48,9 @@ def InferredUseInMatch : GICombineRule< (match (G_ZEXT $dst, $src)), (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>; -// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0' +// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: conflicting types for operand 'src': 'i32' vs 'GITypeOf<$dst>' +// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: 'src' seen with type 'GITypeOf<$dst>' in '__InferenceConflict_apply_0' +// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: 'src' seen with type 'i32' in '__InferenceConflict_match_0' def InferenceConflict : GICombineRule< (defs root:$dst), (match (G_ZEXT $dst, i32:$src)), diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp index 0c7b33a7b9d889..b4b3db70c076a1 100644 --- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp @@ -35,6 +35,7 @@ #include "GlobalISelMatchTableExecutorEmitter.h" #include "SubtargetFeatureInfo.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Statistic.h" @@ -68,6 +69,9 @@ cl::opt DebugCXXPreds( "gicombiner-debug-cxxpreds", cl::desc("Add Contextual/Debug comments to all C++ predicates"), cl::cat(GICombinerEmitterCat)); +cl::opt DebugTypeInfer("gicombiner-debug-typeinfer", + cl::desc("Print type inference debug logs"), + cl::cat(GICombinerEmitterCat)); constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply"; constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_"; @@ -298,23 +302,30 @@ CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode; /// - Special types, e.g. GITypeOf class PatternType { public: - PatternType() = default; - PatternType(const Record *R) : R(R) {} + enum PTKind : uint8_t { + PT_None, - bool isValidType() const { return !R || isLLT() || isSpecial(); } + PT_ValueType, + PT_TypeOf, + }; + + PatternType() : Kind(PT_None), Data() {} + + static std::optional get(ArrayRef DiagLoc, + const Record *R, Twine DiagCtx); + static PatternType getTypeOf(StringRef OpName); - bool isLLT() const { return R && R->isSubClassOf("ValueType"); } - bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); } - bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); } + bool isNone() const { return Kind == PT_None; } + bool isLLT() const { return Kind == PT_ValueType; } + bool isSpecial() const { return isTypeOf(); } + bool isTypeOf() const { return Kind == PT_TypeOf; } StringRef getTypeOfOpName() const; LLTCodeGen getLLTCodeGen() const; - bool checkSemantics(ArrayRef DiagLoc) const; - LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const; - explicit operator bool() const { return R != nullptr; } + explicit operator bool() const { return !isNone(); } bool operator==(const PatternType &Other) const; bool operator!=(const PatternType &Other) const { return !operator==(Other); } @@ -322,26 +333,66 @@ class PatternType { std::string str() const; private: - StringRef getRawOpName() const { return R->getValueAsString("OpName"); } + PatternType(PTKind Kind) : Kind(Kind), Data() {} + + PTKind Kind; + union DataT { + DataT() : Str() {} + + /// PT_ValueType -> ValueType Def. + const Record *Def; - const Record *R = nullptr; + /// PT_TypeOf -> Operand name (without the '$') + StringRef Str; + } Data; }; +std::optional PatternType::get(ArrayRef DiagLoc, + const Record *R, Twine DiagCtx) { + assert(R); + if (R->isSubClassOf("ValueType")) { + PatternType PT(PT_ValueType); + PT.Data.Def = R; + return PT; + } + + if (R->isSubClassOf(TypeOfClassName)) { + auto RawOpName = R->getValueAsString("OpName"); + if (!RawOpName.starts_with("$")) { + PrintError(DiagLoc, DiagCtx + ": invalid operand name format '" + + RawOpName + "' in " + TypeOfClassName + + ": expected '$' followed by an operand name"); + return std::nullopt; + } + + PatternType PT(PT_TypeOf); + PT.Data.Str = RawOpName.drop_front(1); + return PT; + } + + PrintError(DiagLoc, DiagCtx + ": unknown type '" + R->getName() + "'"); + return std::nullopt; +} + +PatternType PatternType::getTypeOf(StringRef OpName) { + PatternType PT(PT_TypeOf); + PT.Data.Str = OpName; + return PT; +} + StringRef PatternType::getTypeOfOpName() const { assert(isTypeOf()); - StringRef Name = getRawOpName(); - Name.consume_front("$"); - return Name; + return Data.Str; } LLTCodeGen PatternType::getLLTCodeGen() const { assert(isLLT()); - return *MVTToLLT(getValueType(R)); + return *MVTToLLT(getValueType(Data.Def)); } LLTCodeGenOrTempType PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const { - assert(isValidType()); + assert(!isNone()); if (isLLT()) return getLLTCodeGen(); @@ -351,50 +402,31 @@ PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const { return OM.getTempTypeIdx(RM); } -bool PatternType::checkSemantics(ArrayRef DiagLoc) const { - if (!isTypeOf()) - return true; - - auto RawOpName = getRawOpName(); - if (RawOpName.starts_with("$")) - return true; - - PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " + - TypeOfClassName + - ": expected '$' followed by an operand name"); - return false; -} - bool PatternType::operator==(const PatternType &Other) const { - if (R == Other.R) { - if (R && R->getName() != Other.R->getName()) { - dbgs() << "Same ptr but: " << R->getName() << " and " - << Other.R->getName() << "?\n"; - assert(false); - } + if (Kind != Other.Kind) + return false; + + switch (Kind) { + case PT_None: return true; + case PT_ValueType: + return Data.Def == Other.Data.Def; + case PT_TypeOf: + return Data.Str == Other.Data.Str; } - if (isTypeOf() && Other.isTypeOf()) - return getTypeOfOpName() == Other.getTypeOfOpName(); - - return false; + llvm_unreachable("Unknown Type Kind"); } std::string PatternType::str() const { - if (!R) + switch (Kind) { + case PT_None: return ""; - - if (!isValidType()) - return ""; - - if (isLLT()) - return R->getName().str(); - - assert(isSpecial()); - - if (isTypeOf()) + case PT_ValueType: + return Data.Def->getName().str(); + case PT_TypeOf: return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str(); + } llvm_unreachable("Unknown type!"); } @@ -607,14 +639,10 @@ class InstructionOperand { using IntImmTy = int64_t; InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type) - : Value(Imm), Name(insertStrRef(Name)), Type(Type) { - assert(Type.isValidType()); - } + : Value(Imm), Name(insertStrRef(Name)), Type(Type) {} InstructionOperand(StringRef Name, PatternType Type) - : Name(insertStrRef(Name)), Type(Type) { - assert(Type.isValidType()); - } + : Name(insertStrRef(Name)), Type(Type) {} bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); } @@ -638,7 +666,6 @@ class InstructionOperand { void setType(PatternType NewType) { assert((!Type || (Type == NewType)) && "Overwriting type!"); - assert(NewType.isValidType()); Type = NewType; } PatternType getType() const { return Type; } @@ -809,12 +836,10 @@ void InstructionPattern::print(raw_ostream &OS, bool PrintName) const { /// Maps InstructionPattern operands to their definitions. This allows us to tie /// different patterns of a (apply), (match) or (patterns) set of patterns /// together. -template class OperandTable { +class OperandTable { public: - static_assert(std::is_base_of_v, - "DefTy should be a derived class from InstructionPattern"); - - bool addPattern(DefTy *P, function_ref DiagnoseRedef) { + bool addPattern(InstructionPattern *P, + function_ref DiagnoseRedef) { for (const auto &Op : P->named_operands()) { StringRef OpName = Op.getOperandName(); @@ -843,10 +868,10 @@ template class OperandTable { struct LookupResult { LookupResult() = default; - LookupResult(DefTy *Def) : Found(true), Def(Def) {} + LookupResult(InstructionPattern *Def) : Found(true), Def(Def) {} bool Found = false; - DefTy *Def = nullptr; + InstructionPattern *Def = nullptr; bool isLiveIn() const { return Found && !Def; } }; @@ -857,7 +882,9 @@ template class OperandTable { return LookupResult(); } - DefTy *getDef(StringRef OpName) const { return lookup(OpName).Def; } + InstructionPattern *getDef(StringRef OpName) const { + return lookup(OpName).Def; + } void print(raw_ostream &OS, StringRef Name = "", StringRef Indent = "") const { @@ -872,11 +899,11 @@ template class OperandTable { SmallVector Keys(Table.keys()); sort(Keys); - OS << "\n"; + OS << '\n'; for (const auto &Key : Keys) { const auto *Def = Table.at(Key); OS << Indent << " " << Key << " -> " - << (Def ? Def->getName() : "") << "\n"; + << (Def ? Def->getName() : "") << '\n'; } OS << Indent << ")\n"; } @@ -887,7 +914,7 @@ template class OperandTable { void dump() const { print(dbgs()); } private: - StringMap Table; + StringMap Table; }; //===- CodeGenInstructionPattern ------------------------------------------===// @@ -956,62 +983,76 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const { /// /// It infers the type of each operand, check it's consistent with the known /// type of the operand, and then sets all of the types in all operands in -/// setAllOperandTypes. +/// propagateTypes. /// /// It also handles verifying correctness of special types. class OperandTypeChecker { public: OperandTypeChecker(ArrayRef DiagLoc) : DiagLoc(DiagLoc) {} - bool check(InstructionPattern *P, + /// Step 1: Check each pattern one by one. All patterns that pass through here + /// are added to a common worklist so propagateTypes can access them. + bool check(InstructionPattern &P, std::function VerifyTypeOfOperand); - void setAllOperandTypes(); + /// Step 2: Propagate all types. e.g. if one use of "$a" has type i32, make + /// all uses of "$a" have type i32. + void propagateTypes(); + +protected: + ArrayRef DiagLoc; private: + using InconsistentTypeDiagFn = std::function; + + void PrintSeenWithTypeIn(InstructionPattern &P, StringRef OpName, + PatternType Ty) const { + PrintNote(DiagLoc, "'" + OpName + "' seen with type '" + Ty.str() + + "' in '" + P.getName() + "'"); + } + struct OpTypeInfo { PatternType Type; - InstructionPattern *TypeSrc = nullptr; + InconsistentTypeDiagFn PrintTypeSrcNote = []() {}; }; - ArrayRef DiagLoc; StringMap Types; SmallVector Pats; }; bool OperandTypeChecker::check( - InstructionPattern *P, + InstructionPattern &P, std::function VerifyTypeOfOperand) { - Pats.push_back(P); + Pats.push_back(&P); - for (auto &Op : P->operands()) { + for (auto &Op : P.operands()) { const auto Ty = Op.getType(); if (!Ty) continue; - if (!Ty.checkSemantics(DiagLoc)) - return false; - if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty)) return false; if (!Op.isNamedOperand()) continue; - auto &Info = Types[Op.getOperandName()]; + StringRef OpName = Op.getOperandName(); + auto &Info = Types[OpName]; if (!Info.Type) { Info.Type = Ty; - Info.TypeSrc = P; + Info.PrintTypeSrcNote = [this, OpName, Ty, &P]() { + PrintSeenWithTypeIn(P, OpName, Ty); + }; continue; } if (Info.Type != Ty) { PrintError(DiagLoc, "conflicting types for operand '" + - Op.getOperandName() + "': first seen with '" + - Info.Type.str() + "' in '" + - Info.TypeSrc->getName() + ", now seen with '" + - Ty.str() + "' in '" + P->getName() + "'"); + Op.getOperandName() + "': '" + Info.Type.str() + + "' vs '" + Ty.str() + "'"); + PrintSeenWithTypeIn(P, OpName, Ty); + Info.PrintTypeSrcNote(); return false; } } @@ -1019,7 +1060,7 @@ bool OperandTypeChecker::check( return true; } -void OperandTypeChecker::setAllOperandTypes() { +void OperandTypeChecker::propagateTypes() { for (auto *Pat : Pats) { for (auto &Op : Pat->named_operands()) { if (auto &Info = Types[Op.getOperandName()]; Info.Type) @@ -1073,7 +1114,7 @@ class PatFrag { /// Each argument to the `pattern` DAG operator is parsed into a Pattern /// instance. struct Alternative { - OperandTable<> OpTable; + OperandTable OpTable; SmallVector, 4> Pats; }; @@ -1297,11 +1338,11 @@ bool PatFrag::checkSemantics() { OperandTypeChecker OTC(Def.getLoc()); for (auto &Pat : Alt.Pats) { if (auto *IP = dyn_cast(Pat.get())) { - if (!OTC.check(IP, CheckTypeOf)) + if (!OTC.check(*IP, CheckTypeOf)) return false; } } - OTC.setAllOperandTypes(); + OTC.propagateTypes(); } return true; @@ -1369,7 +1410,7 @@ bool PatFrag::buildOperandsTables() { } void PatFrag::print(raw_ostream &OS, StringRef Indent) const { - OS << Indent << "(PatFrag name:" << getName() << "\n"; + OS << Indent << "(PatFrag name:" << getName() << '\n'; if (!in_params().empty()) { OS << Indent << " (ins "; printParamsList(OS, in_params()); @@ -1613,7 +1654,7 @@ class PrettyStackTraceParse : public PrettyStackTraceEntry { OS << "Parsing " << PatFragClassName << " '" << Def.getName() << "'"; else OS << "Parsing '" << Def.getName() << "'"; - OS << "\n"; + OS << '\n'; } }; @@ -1635,10 +1676,429 @@ class PrettyStackTraceEmit : public PrettyStackTraceEntry { if (Pat) OS << " [" << Pat->getKindName() << " '" << Pat->getName() << "']"; - OS << "\n"; + OS << '\n'; } }; +//===- CombineRuleOperandTypeChecker --------------------------------------===// + +/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules. +/// On top of doing the same things as OperandTypeChecker, this also attempts to +/// infer as many types as possible for temporary register defs & immediates in +/// apply patterns. +/// +/// The inference is trivial and leverages the MCOI OperandTypes encoded in +/// CodeGenInstructions to infer types across patterns in a CombineRule. It's +/// thus very limited and only supports CodeGenInstructions (but that's the main +/// use case so it's fine). +/// +/// We only try to infer untyped operands in apply patterns when they're temp +/// reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is +/// a named operand from a match pattern. +class CombineRuleOperandTypeChecker : private OperandTypeChecker { +public: + CombineRuleOperandTypeChecker(const Record &RuleDef, + const OperandTable &MatchOpTable) + : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef), + MatchOpTable(MatchOpTable) {} + + /// Records and checks a 'match' pattern. + bool processMatchPattern(InstructionPattern &P); + + /// Records and checks an 'apply' pattern. + bool processApplyPattern(InstructionPattern &P); + + /// Propagates types, then perform type inference and do a second round of + /// propagation in the apply patterns only if any types were inferred. + void propagateAndInferTypes(); + +private: + /// TypeEquivalenceClasses are groups of operands of an instruction that share + /// a common type. + /// + /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and + /// d have the same type too. b/c and a/d don't have to have the same type, + /// though. + using TypeEquivalenceClasses = EquivalenceClasses; + + /// \returns true for `OPERAND_GENERIC_` 0 through 5. + /// These are the MCOI types that can be registers. The other MCOI types are + /// either immediates, or fancier operands used only post-ISel, so we don't + /// care about them for combiners. + static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) { + // Assume OPERAND_GENERIC_0 through 5 can be registers. The other MCOI + // OperandTypes are either never used in gMIR, or not relevant (e.g. + // OPERAND_GENERIC_IMM, which is definitely never a register). + return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_"); + } + + /// Finds the "MCOI::"" operand types for each operand of \p CGP. + /// + /// This is a bit trickier than it looks because we need to handle variadic + /// in/outs. + /// + /// e.g. for + /// (G_BUILD_VECTOR $vec, $x, $y) -> + /// [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1, + /// MCOI::OPERAND_GENERIC_1] + /// + /// For unknown types (which can happen in variadics where varargs types are + /// inconsistent), a unique name is given, e.g. "unknown_type_0". + static std::vector + getMCOIOperandTypes(const CodeGenInstructionPattern &CGP); + + /// Adds the TypeEquivalenceClasses for \p P in \p OutTECs. + void getInstEqClasses(const InstructionPattern &P, + TypeEquivalenceClasses &OutTECs) const; + + /// Calls `getInstEqClasses` on all patterns of the rule to produce the whole + /// rule's TypeEquivalenceClasses. + TypeEquivalenceClasses getRuleEqClasses() const; + + /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p + /// TECs. + /// + /// This is achieved by trying to find a named operand in \p IP that shares + /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that + /// operand instead. + /// + /// \returns the inferred type or an empty PatternType if inference didn't + /// succeed. + PatternType inferImmediateType(const InstructionPattern &IP, + unsigned ImmOpIdx, + const TypeEquivalenceClasses &TECs) const; + + /// Looks inside \p TECs to infer \p OpName's type. + /// + /// \returns the inferred type or an empty PatternType if inference didn't + /// succeed. + PatternType inferNamedOperandType(const InstructionPattern &IP, + StringRef OpName, + const TypeEquivalenceClasses &TECs) const; + + const Record &RuleDef; + SmallVector MatchPats; + SmallVector ApplyPats; + + const OperandTable &MatchOpTable; +}; + +bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) { + MatchPats.push_back(&P); + return check(P, /*CheckTypeOf*/ [](const auto &) { + // GITypeOf in 'match' is currently always rejected by the + // CombineRuleBuilder after inference is done. + return true; + }); +} + +bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) { + ApplyPats.push_back(&P); + return check(P, /*CheckTypeOf*/ [&](const PatternType &Ty) { + // GITypeOf<"$x"> can only be used if "$x" is a matched operand. + const auto OpName = Ty.getTypeOfOpName(); + if (MatchOpTable.lookup(OpName).Found) + return true; + + PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() + + "') does not refer to a matched operand!"); + return false; + }); +} + +void CombineRuleOperandTypeChecker::propagateAndInferTypes() { + /// First step here is to propagate types using the OperandTypeChecker. That + /// way we ensure all uses of a given register have consistent types. + propagateTypes(); + + /// Build the TypeEquivalenceClasses for the whole rule. + const TypeEquivalenceClasses TECs = getRuleEqClasses(); + + /// Look at the apply patterns and find operands that need to be + /// inferred. We then try to find an equivalence class that they're a part of + /// and select the best operand to use for the `GITypeOf` type. We prioritize + /// defs of matched instructions because those are guaranteed to be registers. + bool InferredAny = false; + for (auto *Pat : ApplyPats) { + for (unsigned K = 0; K < Pat->operands_size(); ++K) { + auto &Op = Pat->getOperand(K); + + // We only want to take a look at untyped defs or immediates. + if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType()) + continue; + + // Infer defs & named immediates. + if (Op.isDef() || Op.isNamedImmediate()) { + // Check it's not a redefinition of a matched operand. + // In such cases, inference is not necessary because we just copy + // operands and don't create temporary registers. + if (MatchOpTable.lookup(Op.getOperandName()).Found) + continue; + + // Inference is needed here, so try to do it. + if (PatternType Ty = + inferNamedOperandType(*Pat, Op.getOperandName(), TECs)) { + if (DebugTypeInfer) + errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n'; + Op.setType(Ty); + InferredAny = true; + } + + continue; + } + + // Infer immediates + if (Op.hasImmValue()) { + if (PatternType Ty = inferImmediateType(*Pat, K, TECs)) { + if (DebugTypeInfer) + errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n'; + Op.setType(Ty); + InferredAny = true; + } + continue; + } + } + } + + // If we've inferred any types, we want to propagate them across the apply + // patterns. Type inference only adds GITypeOf types that point to Matched + // operands, so we definitely don't want to propagate types into the match + // patterns as well, otherwise bad things happen. + if (InferredAny) { + OperandTypeChecker OTC(RuleDef.getLoc()); + for (auto *Pat : ApplyPats) { + if (!OTC.check(*Pat, [&](const auto &) { return true; })) + PrintFatalError(RuleDef.getLoc(), + "OperandTypeChecker unexpectedly failed on '" + + Pat->getName() + "' during Type Inference"); + } + OTC.propagateTypes(); + + if (DebugTypeInfer) { + errs() << "Apply patterns for rule " << RuleDef.getName() + << " after inference:\n"; + for (auto *Pat : ApplyPats) { + errs() << " "; + Pat->print(errs(), /*PrintName*/ true); + errs() << '\n'; + } + errs() << '\n'; + } + } +} + +PatternType CombineRuleOperandTypeChecker::inferImmediateType( + const InstructionPattern &IP, unsigned ImmOpIdx, + const TypeEquivalenceClasses &TECs) const { + // We can only infer CGPs. + const auto *CGP = dyn_cast(&IP); + if (!CGP) + return {}; + + // For CGPs, we try to infer immediates by trying to infer another named + // operand that shares its type. + // + // e.g. + // Pattern: G_BUILD_VECTOR $x, $y, 0 + // MCOIs: [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1, + // MCOI::OPERAND_GENERIC_1] + // $y has the same type as 0, so we can infer $y and get the type 0 should + // have. + + // We infer immediates by looking for a named operand that shares the same + // MCOI type. + const auto MCOITypes = getMCOIOperandTypes(*CGP); + StringRef ImmOpTy = MCOITypes[ImmOpIdx]; + + for (const auto &[Idx, Ty] : enumerate(MCOITypes)) { + if (Idx != ImmOpIdx && Ty == ImmOpTy) { + const auto &Op = IP.getOperand(Idx); + if (!Op.isNamedOperand()) + continue; + + // Named operand with the same name, try to infer that. + if (PatternType InferTy = + inferNamedOperandType(IP, Op.getOperandName(), TECs)) + return InferTy; + } + } + + return {}; +} + +PatternType CombineRuleOperandTypeChecker::inferNamedOperandType( + const InstructionPattern &IP, StringRef OpName, + const TypeEquivalenceClasses &TECs) const { + // This is the simplest possible case, we just need to find a TEC that + // contains OpName. Look at all other operands in equivalence class and try to + // find a suitable one. + + // Check for a def of a matched pattern. This is guaranteed to always + // be a register so we can blindly use that. + StringRef GoodOpName; + for (auto It = TECs.findLeader(OpName); It != TECs.member_end(); ++It) { + if (*It == OpName) + continue; + + const auto LookupRes = MatchOpTable.lookup(*It); + if (LookupRes.Def) // Favor defs + return PatternType::getTypeOf(*It); + + // Otherwise just save this in case we don't find any def. + if (GoodOpName.empty() && LookupRes.Found) + GoodOpName = *It; + } + + if (!GoodOpName.empty()) + return PatternType::getTypeOf(GoodOpName); + + // No good operand found, give up. + return {}; +} + +std::vector CombineRuleOperandTypeChecker::getMCOIOperandTypes( + const CodeGenInstructionPattern &CGP) { + // FIXME?: Should we cache this? We call it twice when inferring immediates. + + static unsigned UnknownTypeIdx = 0; + + std::vector OpTypes; + auto &CGI = CGP.getInst(); + Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction") + ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType") + : nullptr; + std::string VarArgsTyName = + VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str() + : ("unknown_type_" + Twine(UnknownTypeIdx++)).str(); + + // First, handle defs. + for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K) + OpTypes.push_back(CGI.Operands[K].OperandType); + + // Then, handle variadic defs if there are any. + if (CGP.hasVariadicDefs()) { + for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K) + OpTypes.push_back(VarArgsTyName); + } + + // If we had variadic defs, the op idx in the pattern won't match the op idx + // in the CGI anymore. + int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs(); + assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0)); + + // Handle all remaining use operands, including variadic ones. + for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) { + unsigned CGIOpIdx = K + CGIOpOffset; + if (CGIOpIdx >= CGI.Operands.size()) { + assert(CGP.isVariadic()); + OpTypes.push_back(VarArgsTyName); + } else { + OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType); + } + } + + assert(OpTypes.size() == CGP.operands_size()); + return OpTypes; +} + +void CombineRuleOperandTypeChecker::getInstEqClasses( + const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const { + // Determine the TypeEquivalenceClasses by: + // - Getting the MCOI Operand Types. + // - Creating a Map of MCOI Type -> [Operand Indexes] + // - Iterating over the map, filtering types we don't like, and just adding + // the array of Operand Indexes to \p OutTECs. + + // We can only do this on CodeGenInstructions. Other InstructionPatterns have + // no type inference information associated with them. + // TODO: Could we add some inference information to builtins at least? e.g. + // ReplaceReg should always replace with a reg of the same type, for instance. + // Though, those patterns are often used alone so it might not be worth the + // trouble to infer their types. + auto *CGP = dyn_cast(&P); + if (!CGP) + return; + + const auto MCOITypes = getMCOIOperandTypes(*CGP); + assert(MCOITypes.size() == P.operands_size()); + + DenseMap> TyToOpIdx; + for (const auto &[Idx, Ty] : enumerate(MCOITypes)) + TyToOpIdx[Ty].push_back(Idx); + + if (DebugTypeInfer) + errs() << "\tGroups for " << P.getName() << ":\t"; + + for (const auto &[Ty, Idxs] : TyToOpIdx) { + if (!canMCOIOperandTypeBeARegister(Ty)) + continue; + + if (DebugTypeInfer) + errs() << '['; + StringRef Sep = ""; + + // We only collect named operands. + StringRef Leader; + for (unsigned Idx : Idxs) { + const auto &Op = P.getOperand(Idx); + if (!Op.isNamedOperand()) + continue; + + const auto OpName = Op.getOperandName(); + if (DebugTypeInfer) { + errs() << Sep << OpName; + Sep = ", "; + } + + if (Leader.empty()) + OutTECs.insert((Leader = OpName)); + else + OutTECs.unionSets(Leader, OpName); + } + + if (DebugTypeInfer) + errs() << "] "; + } + + if (DebugTypeInfer) + errs() << '\n'; +} + +CombineRuleOperandTypeChecker::TypeEquivalenceClasses +CombineRuleOperandTypeChecker::getRuleEqClasses() const { + StringMap OpNameToEqClassIdx; + TypeEquivalenceClasses TECs; + + if (DebugTypeInfer) + errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName() + << ":\n"; + + for (const auto *Pat : MatchPats) + getInstEqClasses(*Pat, TECs); + for (const auto *Pat : ApplyPats) + getInstEqClasses(*Pat, TECs); + + if (DebugTypeInfer) { + errs() << "Final Type Equivalence Classes: "; + for (auto ClassIt = TECs.begin(); ClassIt != TECs.end(); ++ClassIt) { + // only print non-empty classes. + if (auto MembIt = TECs.member_begin(ClassIt); + MembIt != TECs.member_end()) { + errs() << '['; + StringRef Sep = ""; + for (; MembIt != TECs.member_end(); ++MembIt) { + errs() << Sep << *MembIt; + Sep = ", "; + } + errs() << "] "; + } + } + errs() << '\n'; + } + + return TECs; +} + //===- CombineRuleBuilder -------------------------------------------------===// /// Parses combine rule and builds a small intermediate representation to tie @@ -1820,8 +2280,8 @@ class CombineRuleBuilder { PatternMap ApplyPats; /// Operand tables to tie match/apply patterns together. - OperandTable<> MatchOpTable; - OperandTable<> ApplyOpTable; + OperandTable MatchOpTable; + OperandTable ApplyOpTable; /// Set by findRoots. Pattern *MatchRoot = nullptr; @@ -1893,14 +2353,14 @@ bool CombineRuleBuilder::emitRuleMatchers() { void CombineRuleBuilder::print(raw_ostream &OS) const { OS << "(CombineRule name:" << RuleDef.getName() << " id:" << RuleID - << " root:" << RootName << "\n"; + << " root:" << RootName << '\n'; if (!MatchDatas.empty()) { OS << " (MatchDatas\n"; for (const auto &MD : MatchDatas) { OS << " "; MD.print(OS); - OS << "\n"; + OS << '\n'; } OS << " )\n"; } @@ -1909,7 +2369,7 @@ void CombineRuleBuilder::print(raw_ostream &OS) const { OS << " (PatFrags\n"; for (const auto *PF : SeenPatFrags) { PF->print(OS, /*Indent=*/" "); - OS << "\n"; + OS << '\n'; } OS << " )\n"; } @@ -1921,7 +2381,7 @@ void CombineRuleBuilder::print(raw_ostream &OS) const { return; } - OS << "\n"; + OS << '\n'; for (const auto &[Name, Pat] : Pats) { OS << " "; if (Pat.get() == MatchRoot) @@ -1931,7 +2391,7 @@ void CombineRuleBuilder::print(raw_ostream &OS) const { OS << ""; OS << Name << ":"; Pat->print(OS, /*PrintName=*/false); - OS << "\n"; + OS << '\n'; } OS << " )\n"; }; @@ -1972,9 +2432,9 @@ void CombineRuleBuilder::verify() const { // Both strings are allocated in the pool using insertStrRef. if (Name.data() != Pat->getName().data()) { dbgs() << "Map StringRef: '" << Name << "' @ " - << (const void *)Name.data() << "\n"; + << (const void *)Name.data() << '\n'; dbgs() << "Pat String: '" << Pat->getName() << "' @ " - << (const void *)Pat->getName().data() << "\n"; + << (const void *)Pat->getName().data() << '\n'; PrintFatalError("StringRef stored in the PatternMap is not referencing " "the same string as its Pattern!"); } @@ -2074,7 +2534,7 @@ void CombineRuleBuilder::addCXXPredicate(RuleMatcher &M, P.expandCode(CE, RuleDef.getLoc(), [&](raw_ostream &OS) { OS << "// Pattern Alternatives: "; print(OS, Alts); - OS << "\n"; + OS << '\n'; }); IM->addPredicate( ExpandedCode.getEnumNameWithPrefix(CXXPredPrefix)); @@ -2102,39 +2562,23 @@ bool CombineRuleBuilder::hasEraseRoot() const { } bool CombineRuleBuilder::typecheckPatterns() { - OperandTypeChecker OTC(RuleDef.getLoc()); - - const auto CheckMatchTypeOf = [&](const PatternType &) -> bool { - // We'll reject those after we're done inferring - return true; - }; + CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable); for (auto &Pat : values(MatchPats)) { if (auto *IP = dyn_cast(Pat.get())) { - if (!OTC.check(IP, CheckMatchTypeOf)) + if (!OTC.processMatchPattern(*IP)) return false; } } - const auto CheckApplyTypeOf = [&](const PatternType &Ty) { - // GITypeOf<"$x"> can only be used if "$x" is a matched operand. - const auto OpName = Ty.getTypeOfOpName(); - if (MatchOpTable.lookup(OpName).Found) - return true; - - PrintError("'" + OpName + "' ('" + Ty.str() + - "') does not refer to a matched operand!"); - return false; - }; - for (auto &Pat : values(ApplyPats)) { if (auto *IP = dyn_cast(Pat.get())) { - if (!OTC.check(IP, CheckApplyTypeOf)) + if (!OTC.processApplyPattern(*IP)) return false; } } - OTC.setAllOperandTypes(); + OTC.propagateAndInferTypes(); // Always check this after in case inference adds some special types to the // match patterns. @@ -2630,7 +3074,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand( // untyped immediate, e.g. 0 if (const auto *IntImm = dyn_cast(OpInit)) { std::string Name = OpName ? OpName->getAsUnquotedString() : ""; - IP.addOperand(IntImm->getValue(), Name, /*Type=*/nullptr); + IP.addOperand(IntImm->getValue(), Name, PatternType()); return true; } @@ -2640,13 +3084,11 @@ bool CombineRuleBuilder::parseInstructionPatternOperand( return ParseErr(); const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc()); - PatternType ImmTy(TyDef); - if (!ImmTy.isValidType()) { - PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() + - "', '" + TyDef->getName() + "' is not a ValueType or " + - SpecialTyClassName); + auto ImmTy = PatternType::get(RuleDef.getLoc(), TyDef, + "cannot parse immediate '" + + DagOp->getAsUnquotedString() + "'"); + if (!ImmTy) return false; - } if (!IP.hasAllDefs()) { PrintError("out operand of '" + IP.getInstName() + @@ -2659,7 +3101,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand( return ParseErr(); std::string Name = OpName ? OpName->getAsUnquotedString() : ""; - IP.addOperand(Val->getValue(), Name, ImmTy); + IP.addOperand(Val->getValue(), Name, *ImmTy); return true; } @@ -2671,20 +3113,18 @@ bool CombineRuleBuilder::parseInstructionPatternOperand( return false; } const Record *Def = DefI->getDef(); - PatternType Ty(Def); - if (!Ty.isValidType()) { - PrintError("invalid operand type: '" + Def->getName() + - "' is not a ValueType"); + auto Ty = + PatternType::get(RuleDef.getLoc(), Def, "cannot parse operand type"); + if (!Ty) return false; - } - IP.addOperand(OpName->getAsUnquotedString(), Ty); + IP.addOperand(OpName->getAsUnquotedString(), *Ty); return true; } // Untyped operand e.g. $x/$z in (G_FNEG $x, $z) if (isa(OpInit)) { assert(OpName && "Unset w/ no OpName?"); - IP.addOperand(OpName->getAsUnquotedString(), /*Type=*/nullptr); + IP.addOperand(OpName->getAsUnquotedString(), PatternType()); return true; } @@ -3746,7 +4186,7 @@ void GICombinerEmitter::emitRunCustomAction(raw_ostream &OS) { OS << " switch(ApplyID) {\n"; for (const auto &Apply : ApplyCode) { OS << " case " << Apply->getEnumNameWithPrefix(CXXApplyPrefix) << ":{\n" - << " " << join(split(Apply->Code, "\n"), "\n ") << "\n" + << " " << join(split(Apply->Code, '\n'), "\n ") << '\n' << " return;\n"; OS << " }\n"; }