Skip to content

Commit

Permalink
Arm64/Sve: Rewrite how ConditionalSelect wraps the embedded mask oper…
Browse files Browse the repository at this point in the history
…ations (#104248)

* All current embedded scenarios work

* Add test coverage for _SveUnaryOpTestTemplate

* wip: review comments

* fix the assert for GatherPrefetch

* review commetns

* jit format

* Add coverage for _SveBinaryOpTestTemplate.template

* review comments

* Add test coverage in _SveTernOpTestTemplate.template

* Add for _SveUnaryOpDifferentRetTypeTestTemplate

* Add _SveBinaryMaskOpTestTemplate.template

* minor feedback

* coverge for _SveBinaryOpDifferentTypesTestTemplate

* Add for _SveBinaryOpDifferentTypesTestTemplate.template

* Add for _SveImmBinaryOpTestTemplate.template

* Add for _SveImmTernOpFirstArgTestTemplate.template

* Add for _SveImmTernOpTestTemplate.template

* Add for _SveImmUnaryOpTestTemplate.template

* Add for _SveTernOpFirstArgTestTemplate.template

* Add for _SveTernOpMaskedOpTestTemplate.template

* Add for SveLoadNonFaultingUnOpTest.template

* Add for SveGatherVectorVectorBases.template

* Add for SveGatherVectorIndices.template

* Add for SveGatherVectorByteOffsets.template

* fix the typos

* minor test feedback

* jit format

* revert fix for GatherPrefetch*

* missed a change
  • Loading branch information
kunalspathak committed Jul 3, 2024
1 parent 7875486 commit a91700f
Show file tree
Hide file tree
Showing 20 changed files with 1,797 additions and 403 deletions.
26 changes: 26 additions & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,7 @@ struct GenTree
inline bool IsVectorAllBitsSet() const;
inline bool IsVectorBroadcast(var_types simdBaseType) const;
inline bool IsMaskAllBitsSet() const;
inline bool IsMaskZero() const;
inline bool IsVectorConst();

inline uint64_t GetIntegralVectorConstElement(size_t index, var_types simdBaseType);
Expand Down Expand Up @@ -9680,6 +9681,31 @@ inline bool GenTree::IsMaskAllBitsSet() const
return false;
}

inline bool GenTree::IsMaskZero() const
{
#ifdef TARGET_ARM64
static_assert_no_msg(AreContiguous(NI_Sve_CreateFalseMaskByte, NI_Sve_CreateFalseMaskDouble,
NI_Sve_CreateFalseMaskInt16, NI_Sve_CreateFalseMaskInt32,
NI_Sve_CreateFalseMaskInt64, NI_Sve_CreateFalseMaskSByte,
NI_Sve_CreateFalseMaskSingle, NI_Sve_CreateFalseMaskUInt16,
NI_Sve_CreateFalseMaskUInt32, NI_Sve_CreateFalseMaskUInt64));

if (OperIsHWIntrinsic())
{
NamedIntrinsic id = AsHWIntrinsic()->GetHWIntrinsicId();
if (id == NI_Sve_ConvertMaskToVector)
{
GenTree* op1 = AsHWIntrinsic()->Op(1);
assert(op1->OperIsHWIntrinsic());
id = op1->AsHWIntrinsic()->GetHWIntrinsicId();
}
return ((id >= NI_Sve_CreateFalseMaskByte) && (id <= NI_Sve_CreateFalseMaskUInt64));
}

#endif
return false;
}

//-------------------------------------------------------------------
// IsVectorConst: returns true if this node is a HWIntrinsic that represents a constant.
//
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,11 @@ class Lowering final : public Phase
GenTree* LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cmpOp);
GenTree* LowerHWIntrinsicCreate(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicDot(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicCndSel(GenTreeHWIntrinsic* node);
#if defined(TARGET_XARCH)
void LowerFusedMultiplyAdd(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicToScalar(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicCndSel(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicTernaryLogic(GenTreeHWIntrinsic* node);
GenTree* LowerHWIntrinsicWithElement(GenTreeHWIntrinsic* node);
GenTree* TryLowerAndOpToResetLowestSetBit(GenTreeOp* andNode);
Expand Down
116 changes: 93 additions & 23 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1312,37 +1312,45 @@ GenTree* Lowering::LowerHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_AdvSimd_FusedMultiplyAddScalar:
LowerHWIntrinsicFusedMultiplyAddScalar(node);
break;

case NI_Sve_ConditionalSelect:
return LowerHWIntrinsicCndSel(node);
default:
break;
}

if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsicId))
{
LIR::Use use;
if (BlockRange().TryGetUse(node, &use))
bool foundUse = BlockRange().TryGetUse(node, &use);
JITDUMP("lowering EmbeddedMasked HWIntrinisic (before):\n");
DISPTREERANGE(BlockRange(), node);
JITDUMP("\n");

CorInfoType simdBaseJitType = node->GetSimdBaseJitType();
unsigned simdSize = node->GetSimdSize();
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
GenTree* trueMask = comp->gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
GenTree* falseVal = comp->gtNewZeroConNode(simdType);

BlockRange().InsertBefore(node, trueMask);
BlockRange().InsertBefore(node, falseVal);

GenTreeHWIntrinsic* condSelNode =
comp->gtNewSimdHWIntrinsicNode(simdType, trueMask, node, falseVal, NI_Sve_ConditionalSelect,
simdBaseJitType, simdSize);
BlockRange().InsertAfter(node, condSelNode);
if (foundUse)
{
GenTree* user = use.User();
// Wrap the intrinsic in ConditionalSelect only if it is not already inside another ConditionalSelect
if (!user->OperIsHWIntrinsic() || (user->AsHWIntrinsic()->GetHWIntrinsicId() != NI_Sve_ConditionalSelect))
{
CorInfoType simdBaseJitType = node->GetSimdBaseJitType();
unsigned simdSize = node->GetSimdSize();
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
GenTree* trueMask = comp->gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
GenTree* trueVal = node;
GenTree* falseVal = comp->gtNewZeroConNode(simdType);

GenTreeHWIntrinsic* condSelNode =
comp->gtNewSimdHWIntrinsicNode(simdType, trueMask, trueVal, falseVal, NI_Sve_ConditionalSelect,
simdBaseJitType, simdSize);

BlockRange().InsertBefore(node, trueMask);
BlockRange().InsertBefore(node, falseVal);
BlockRange().InsertAfter(node, condSelNode);
use.ReplaceWith(condSelNode);
}
use.ReplaceWith(condSelNode);
}
else
{
condSelNode->SetUnusedValue();
}

JITDUMP("lowering EmbeddedMasked HWIntrinisic (after):\n");
DISPTREERANGE(BlockRange(), condSelNode);
JITDUMP("\n");
}

ContainCheckHWIntrinsic(node);
Expand Down Expand Up @@ -3369,7 +3377,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
}

// Handle op2
if (op2->OperIsHWIntrinsic())
if (op2->OperIsHWIntrinsic() && !op2->IsEmbMaskOp())
{
const GenTreeHWIntrinsic* embOp = op2->AsHWIntrinsic();

Expand Down Expand Up @@ -3492,6 +3500,68 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
}
}
}
//----------------------------------------------------------------------------------------------
// Lowering::LowerHWIntrinsicCndSel: Lowers a Sve ConditionalSelect call
//
// Arguments:
// node - The hardware intrinsic node of the form
// ConditionalSelect(mask, trueValue, falseValue)
//
GenTree* Lowering::LowerHWIntrinsicCndSel(GenTreeHWIntrinsic* cndSelNode)
{
assert(cndSelNode->OperIsHWIntrinsic(NI_Sve_ConditionalSelect));

GenTree* op1 = cndSelNode->Op(1);
GenTree* op2 = cndSelNode->Op(2);
GenTree* op3 = cndSelNode->Op(3);
GenTree* lowerCndSel = cndSelNode;

if (op2->OperIsHWIntrinsic(NI_Sve_ConditionalSelect))
{
// Handle cases where there is a nested ConditionalSelect for
// `trueValue`
GenTreeHWIntrinsic* nestedCndSel = op2->AsHWIntrinsic();
GenTree* nestedOp1 = nestedCndSel->Op(1);
assert(varTypeIsMask(nestedOp1));

if (nestedOp1->IsMaskAllBitsSet())
{
GenTree* nestedOp2 = nestedCndSel->Op(2);
GenTree* nestedOp3 = nestedCndSel->Op(3);

JITDUMP("lowering ConditionalSelect HWIntrinisic (before):\n");
DISPTREERANGE(BlockRange(), cndSelNode);
JITDUMP("\n");

// Transform:
//
// CndSel(mask, CndSel(AllTrue, embeddedMask(trueValOp2), trueValOp3), op3) to
// CndSel(mask, embedded(trueValOp2), op3)
//
cndSelNode->Op(2) = nestedCndSel->Op(2);
if (nestedOp3->IsMaskZero())
{
BlockRange().Remove(nestedOp3);
}
else
{
nestedOp3->SetUnusedValue();
}

BlockRange().Remove(nestedOp1);
BlockRange().Remove(nestedCndSel);

JITDUMP("lowering ConditionalSelect HWIntrinisic (after):\n");
DISPTREERANGE(BlockRange(), cndSelNode);
JITDUMP("\n");

return cndSelNode;
}
}

ContainCheckHWIntrinsic(cndSelNode);
return cndSelNode->gtNext;
}
#endif // FEATURE_HW_INTRINSICS

#endif // TARGET_ARMARCH
Loading

0 comments on commit a91700f

Please sign in to comment.