diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 079174f67e939b..2f3a2a7b2f573b 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -1770,6 +1770,7 @@ struct GenTree inline bool IsVectorZero() const; inline bool IsVectorCreate() const; inline bool IsVectorAllBitsSet() const; + inline bool IsMaskAllBitsSet() const; inline bool IsVectorConst(); inline uint64_t GetIntegralVectorConstElement(size_t index, var_types simdBaseType); @@ -9238,6 +9239,32 @@ inline bool GenTree::IsVectorAllBitsSet() const return false; } +inline bool GenTree::IsMaskAllBitsSet() const +{ +#ifdef TARGET_ARM64 + static_assert_no_msg(AreContiguous(NI_Sve_CreateTrueMaskByte, NI_Sve_CreateTrueMaskDouble, + NI_Sve_CreateTrueMaskInt16, NI_Sve_CreateTrueMaskInt32, + NI_Sve_CreateTrueMaskInt64, NI_Sve_CreateTrueMaskSByte, + NI_Sve_CreateTrueMaskSingle, NI_Sve_CreateTrueMaskUInt16, + NI_Sve_CreateTrueMaskUInt32, NI_Sve_CreateTrueMaskUInt64)); + + if (OperIsHWIntrinsic()) + { + NamedIntrinsic id = AsHWIntrinsic()->GetHWIntrinsicId(); + if (id == NI_Sve_ConvertMaskToVector) + { + GenTree* op1 = AsHWIntrinsic()->Op(1); + assert(op1->OperIsHWIntrinsic()); + id = op1->AsHWIntrinsic()->GetHWIntrinsicId(); + } + return ((id == NI_Sve_CreateTrueMaskAll) || + ((id >= NI_Sve_CreateTrueMaskByte) && (id <= NI_Sve_CreateTrueMaskUInt64))); + } + +#endif + return false; +} + //------------------------------------------------------------------- // IsVectorConst: returns true if this node is a HWIntrinsic that represents a constant. 
// diff --git a/src/coreclr/jit/hwintrinsic.cpp b/src/coreclr/jit/hwintrinsic.cpp index 96060b2beacb7d..3187c3e5c2e044 100644 --- a/src/coreclr/jit/hwintrinsic.cpp +++ b/src/coreclr/jit/hwintrinsic.cpp @@ -1622,7 +1622,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic, GenTree* op1 = retNode->AsHWIntrinsic()->Op(1); if (intrinsic == NI_Sve_ConditionalSelect) { - if (op1->IsVectorAllBitsSet()) + if (op1->IsVectorAllBitsSet() || op1->IsMaskAllBitsSet()) { return retNode->AsHWIntrinsic()->Op(2); } diff --git a/src/coreclr/jit/hwintrinsiccodegenarm64.cpp b/src/coreclr/jit/hwintrinsiccodegenarm64.cpp index edabd7030359f7..f4df29874dbbb3 100644 --- a/src/coreclr/jit/hwintrinsiccodegenarm64.cpp +++ b/src/coreclr/jit/hwintrinsiccodegenarm64.cpp @@ -406,8 +406,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) // Handle case where op2 is operation that needs embedded mask GenTree* op2 = intrin.op2; assert(intrin.id == NI_Sve_ConditionalSelect); - assert(op2->isContained()); assert(op2->OperIsHWIntrinsic()); + assert(op2->isContained()); // Get the registers and intrinsics that needs embedded mask const HWIntrinsic intrinEmbMask(op2->AsHWIntrinsic()); @@ -439,10 +439,54 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) { case 1: assert(!instrIsRMW); + if (targetReg != falseReg) { - GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg); + // If targetReg is not the same as `falseReg` then need to move + // the `falseReg` to `targetReg`. + + if (intrin.op3->isContained()) + { + assert(intrin.op3->IsVectorZero()); + if (intrin.op1->isContained()) + { + // We already skip importing ConditionalSelect if op1 == trueAll, however + // if we still see it here, it is because we wrapped the predicated instruction + // inside ConditionalSelect. + // As such, no need to move the `falseReg` to `targetReg` + // because the predicated instruction will eventually set it. 
+ assert(intrin.op1->IsMaskAllBitsSet()); + } + else + { + // If falseValue is zero, just zero out those lanes of targetReg using `movprfx` + // and /Z + GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg, targetReg, + opt); + } + } + else if (targetReg == embMaskOp1Reg) + { + // target != falseValue, but we do not want to overwrite target with `embMaskOp1Reg`. + // We will first do the predicated operation and then select the inactive + // elements from falseValue. + + // We cannot use `movprfx` here to move falseReg to targetReg because that will + // overwrite the value of embMaskOp1Reg which is present in targetReg. + GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt); + + GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg, + falseReg, opt, INS_SCALABLE_OPTS_UNPREDICATED); + break; + } + else + { + // At this point, target != embMaskOp1Reg != falseReg, so just go ahead + // and move the falseReg unpredicated into targetReg. + GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg); + } } + GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt); break;