diff --git a/frontend/include/chpl/framework/ErrorBase.h b/frontend/include/chpl/framework/ErrorBase.h index c94274bc46e8..fb34206b6628 100644 --- a/frontend/include/chpl/framework/ErrorBase.h +++ b/frontend/include/chpl/framework/ErrorBase.h @@ -257,7 +257,7 @@ class GeneralError : public BasicError { return owned(new Error##NAME__(*this));\ }\ \ - ErrorInfo info() const { return info_; }\ + const ErrorInfo& info() const { return info_; }\ }; #include "chpl/framework/error-classes-list.h" #undef DIAGNOSTIC_CLASS diff --git a/frontend/include/chpl/framework/all-global-strings.h b/frontend/include/chpl/framework/all-global-strings.h index 03e74a966a99..e2505c282c14 100644 --- a/frontend/include/chpl/framework/all-global-strings.h +++ b/frontend/include/chpl/framework/all-global-strings.h @@ -43,9 +43,12 @@ X(c_ptrConst , "c_ptrConst") X(c_char , "c_char") X(class_ , "class") X(deinit , "deinit") +X(deserialize , "deserialize") X(dmapped , "dmapped") X(domain , "domain") X(false_ , "false") +X(follower , "follower") +X(followThis , "followThis") X(for_ , "for") X(forall , "forall") X(foreach , "foreach") @@ -58,6 +61,7 @@ X(init , "init") X(initequals , "init=") X(int_ , "int") X(isCoercible , "isCoercible") +X(leader , "leader") X(locale , "locale") X(main , "main") X(max , "max") @@ -79,15 +83,16 @@ X(reduceAssign , "reduce=") X(RootClass , "RootClass") X(scan , "scan") X(serialize , "serialize") -X(deserialize , "deserialize") X(shared , "shared") X(single , "single") X(sparse , "sparse") X(stable , "stable") +X(standalone , "standalone") X(string , "string") X(subdomain , "subdomain") X(super_ , "super") X(sync , "sync") +X(tag , "tag") X(this_ , "this") X(these_ , "these") X(true_ , "true") diff --git a/frontend/include/chpl/resolution/resolution-types.h b/frontend/include/chpl/resolution/resolution-types.h index bb8c366a040e..ea60d1aa84ed 100644 --- a/frontend/include/chpl/resolution/resolution-types.h +++ b/frontend/include/chpl/resolution/resolution-types.h @@ -23,6 +23,7 @@ #include "chpl/framework/UniqueString.h" #include "chpl/resolution/scope-types.h" #include "chpl/types/CompositeType.h" +#include "chpl/types/EnumType.h" #include "chpl/types/QualifiedType.h" #include "chpl/types/Type.h" #include "chpl/uast/AstNode.h" @@ -309,6 +310,11 @@ class UntypedFnSignature { return isMethod_; } + /** Returns true if this is an iterator */ + bool isIterator() const { + return kind_ == uast::Function::ITER; + } + /** Returns true if this function throws */ bool throws() const { return throws_; @@ -949,6 +955,11 @@ class TypedFnSignature { const TypedFnSignature* parentFn, Bitmap formalsInstantiated); + /** If this is an iterator, set 'found' to a string representing its + 'iterKind', or "" if it is a serial iterator. Returns 'true' only + if this is an iterator and a valid 'iterKind' formal was found. */ + bool fetchIterKindStr(Context* context, UniqueString& outIterKindStr) const; + public: /** Get the unique TypedFnSignature containing these components */ static @@ -1094,6 +1105,38 @@ class TypedFnSignature { return formalTypes_[i]; } + bool isMethod() const { + return untypedSignature_->isMethod(); + } + + bool isIterator() const { + return untypedSignature_->isIterator(); + } + + /** Returns 'true' if this signature is for a standalone parallel iterator. */ + bool isParallelStandaloneIterator(Context* context) const { + UniqueString str; + return fetchIterKindStr(context, str) && str == USTR("standalone"); + } + + /** Returns 'true' if this signature is for a parallel leader iterator. */ + bool isParallelLeaderIterator(Context* context) const { + UniqueString str; + return fetchIterKindStr(context, str) && str == USTR("leader"); + } + + /** Returns 'true' if this signature is for a parallel follower iterator. */ + bool isParallelFollowerIterator(Context* context) const { + UniqueString str; + return fetchIterKindStr(context, str) && str == USTR("follower"); + } + + /** Returns 'true' if this signature is for a serial iterator. */ + bool isSerialIterator(Context* context) const { + UniqueString str; + return fetchIterKindStr(context, str) && str.isEmpty(); + } + /// \cond DO_NOT_DOCUMENT DECLARE_DUMP; /// \endcond DO_NOT_DOCUMENT diff --git a/frontend/include/chpl/types/EnumType.h b/frontend/include/chpl/types/EnumType.h index 1ba62b1ec2a1..967b8b53ec6d 100644 --- a/frontend/include/chpl/types/EnumType.h +++ b/frontend/include/chpl/types/EnumType.h @@ -21,6 +21,7 @@ #define CHPL_TYPES_ENUM_TYPE_H #include "chpl/types/Type.h" +#include "chpl/types/QualifiedType.h" namespace chpl{ namespace types { @@ -64,6 +65,19 @@ class EnumType final : public Type { /** Get the type for a range's boundKind */ static const EnumType* getBoundKindType(Context* context); + /** Get the type representing an iterator's "iteration kind". */ + static const EnumType* getIterKindType(Context* context); + + /** Given an enum type 'et', get a map from the name of each constant + in 'et' to each constant represented as a param value. + If there are multiple enum constants with the same name (which + means the AST is semantically incorrect), then only the first + constant is added to the map. Returns 'nullptr' if 'et' is + 'nullptr' or has an empty ID, or if it does not have any AST + representing it. */ + static const std::map* + getParamConstantsMapOrNull(Context* context, const EnumType* et); + ~EnumType() = default; virtual void stringify(std::ostream& ss, diff --git a/frontend/include/chpl/types/QualifiedType.h b/frontend/include/chpl/types/QualifiedType.h index 31f799a6a960..9bd09114ea88 100644 --- a/frontend/include/chpl/types/QualifiedType.h +++ b/frontend/include/chpl/types/QualifiedType.h @@ -158,6 +158,10 @@ class QualifiedType final { return isUnknown() || (genericity() != Type::CONCRETE); } + bool isUnknownOrErroneous() const { + return isUnknown() || isErroneousType(); + } + /** Returns true if kind is TYPE */ bool isType() const { return kind_ == Kind::TYPE; } diff --git a/frontend/lib/resolution/Resolver.cpp b/frontend/lib/resolution/Resolver.cpp index 3a958efd2173..cf2abec88dc4 100644 --- a/frontend/lib/resolution/Resolver.cpp +++ b/frontend/lib/resolution/Resolver.cpp @@ -35,6 +35,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include #include #include #include @@ -49,6 +50,70 @@ namespace resolution { using namespace uast; using namespace types; +namespace { + struct IterDetails { + // Iterators will be resolved for each bit set in descending order. + // Resolution stops when the first iterator with a valid yield type + // is found. TODO: Stop the moment we find a signature instead? + // + // If the 'LEADER_FOLLOWER' bit is set, then the 'FOLLOWER' bit will + // be ignored, and the 'leaderYieldType' will be computed from the + // resolved leader signature. Any provided value for 'leaderYieldType' + // will be overwritten. + enum Priority { + NONE = 0b0000, + STANDALONE = 0b0001, + LEADER_FOLLOWER = 0b0010, + FOLLOWER = 0b0100, + SERIAL = 0b1000, + }; + + // When an iter is resolved these pieces of the process will be stored. + struct Pieces { + bool wasCallInjected = false; + CallResolutionResult crr; + const TypedFnSignature* sig = crr.mostSpecific().only().fn(); + }; + Pieces standalone; + Pieces leader; + Pieces follower; + Pieces serial; + + // This is the iterator that resolution stopped at if it succeeded. + // In that case, the type in 'idxType' will be set to the yield type + // of this iterator. + Priority succeededAt = NONE; + + // If the 'LEADER_FOLLOWER' bit was set, then this will always be the + // yield type of the resolved leader iterator. Otherwise, it will be + // the type the user provided. + QualifiedType leaderYieldType; + + // This will be set to the yield type of the first iterator to succeed. + QualifiedType idxType; + }; +} + +static const QualifiedType& +getIterKindConstantOrUnknownQuery(Context* context, UniqueString constant); + +// Helper to resolve a specified iterator signature and its yield type. +static QualifiedType +resolveIterTypeWithTag(Resolver& rv, + IterDetails::Pieces& outIterPieces, + const AstNode* astForErr, + const AstNode* iterand, + UniqueString iterKindStr, + const QualifiedType& followThisFormal); + +// Resolve iterators according to the policy set in 'mask' (see the type +// 'IterDetails::Policy'). Resolution stops the moment an iterator is +// found with a usable yield type. +static IterDetails resolveIterDetails(Resolver& rv, + const AstNode* astForErr, + const AstNode* iterand, + const QualifiedType& leaderYieldType, + int mask); static QualifiedType::Kind qualifiedTypeKindForId(Context* context, ID id) { if (parsing::idIsParenlessFunction(context, id)) @@ -1641,7 +1706,6 @@ void Resolver::handleResolvedCall(ResolvedExpression& r, const CallInfo& ci, const CallResolutionResult& c, optional actionAndId) { - if (handleResolvedCallWithoutError(r, astForErr, ci, c, std::move(actionAndId))) { issueErrorForFailedCallResolution(astForErr, ci, c); } @@ -1654,7 +1718,6 @@ void Resolver::handleResolvedCallPrintCandidates(ResolvedExpression& r, const QualifiedType& receiverType, const CallResolutionResult& c, optional actionAndId) { - bool wasCallGenerated = (bool) actionAndId; CHPL_ASSERT(!wasCallGenerated || receiverType.isUnknown()); if (handleResolvedCallWithoutError(r, call, ci, c, std::move(actionAndId))) { @@ -4107,174 +4170,468 @@ void Resolver::exit(const New* node) { } } -static QualifiedType resolveSerialIterType(Resolver& resolver, - const AstNode* astForErr, - const AstNode* iterand) { - Context* context = resolver.context; - iterand->traverse(resolver); - ResolvedExpression& iterandRE = resolver.byPostorder.byAst(iterand); +static const QualifiedType& +getIterKindConstantOrUnknownQuery(Context* context, UniqueString constant) { + QUERY_BEGIN(getIterKindConstantOrUnknownQuery, context, constant); - if (resolver.scopeResolveOnly) { - return QualifiedType(QualifiedType::UNKNOWN, - UnknownType::get(context)); + QualifiedType ret = { QualifiedType::UNKNOWN, UnknownType::get(context) }; + + if (!constant.isEmpty()) { + auto ik = EnumType::getIterKindType(context); + if (auto m = EnumType::getParamConstantsMapOrNull(context, ik)) { + auto it = m->find(constant); + if (it != m->end()) ret = it->second; + } } - auto& MSC = iterandRE.mostSpecific(); - bool isIter = MSC.isEmpty() == false && - MSC.numBest() == 1 && - MSC.only().fn()->untyped()->kind() == uast::Function::Kind::ITER; + return QUERY_END(ret); +} + +// This helper resolves by priority order as described in 'IterDetails'. +static IterDetails +resolveIterDetailsInPriorityOrder(Resolver& rv, + bool& outWasIterSigResolved, + const AstNode* astForErr, + const AstNode* iterand, + const QualifiedType& leaderYieldType, + int mask) { + IterDetails ret; + bool computedLeaderYieldType = false; + if (mask & IterDetails::STANDALONE) { + ret.idxType = resolveIterTypeWithTag(rv, ret.standalone, astForErr, + iterand, USTR("standalone"), {}); + outWasIterSigResolved = (ret.standalone.sig != nullptr); + if (!ret.idxType.isUnknownOrErroneous()) { + ret.succeededAt = IterDetails::STANDALONE; + return ret; + } + } - bool wasResolved = iterandRE.type().isUnknown() == false && - iterandRE.type().isErroneousType() == false; + if (mask & IterDetails::LEADER_FOLLOWER) { + ret.leaderYieldType = resolveIterTypeWithTag(rv, ret.leader, astForErr, + iterand, USTR("leader"), + {}); + computedLeaderYieldType = true; + } else if (mask & IterDetails::FOLLOWER) { + ret.leaderYieldType = leaderYieldType; + } + + if (mask & IterDetails::LEADER_FOLLOWER || + mask & IterDetails::FOLLOWER) { + if (!ret.leaderYieldType.isUnknownOrErroneous()) { + ret.idxType = resolveIterTypeWithTag(rv, ret.follower, astForErr, + iterand, USTR("follower"), + ret.leaderYieldType); + outWasIterSigResolved = (ret.follower.sig != nullptr); + if (!ret.idxType.isUnknownOrErroneous()) { + ret.succeededAt = computedLeaderYieldType + ? IterDetails::LEADER_FOLLOWER + : IterDetails::FOLLOWER; + return ret; + } + } + } - QualifiedType idxType; + if (mask & IterDetails::SERIAL) { + ret.idxType = resolveIterTypeWithTag(rv, ret.serial, astForErr, + iterand, {}, {}); + outWasIterSigResolved = (ret.serial.sig != nullptr); + if (!ret.idxType.isUnknownOrErroneous()) { + ret.succeededAt = IterDetails::SERIAL; + } + } - if (isIter) { - idxType = iterandRE.type(); - } else if (wasResolved) { - // - // Resolve "iterand.these()" - // - std::vector actuals; - actuals.push_back(CallInfoActual(iterandRE.type(), USTR("this"))); - auto ci = CallInfo (/* name */ USTR("these"), - /* calledType */ iterandRE.type(), - /* isMethodCall */ true, - /* hasQuestionArg */ false, - /* isParenless */ false, - actuals); - auto inScope = resolver.scopeStack.back(); - auto inScopes = CallScopeInfo::forNormalCall(inScope, resolver.poiScope); - auto c = resolveGeneratedCall(context, iterand, ci, inScopes); + return ret; +} - if (c.mostSpecific().only()) { - idxType = c.exprType(); - resolver.handleResolvedCall(iterandRE, astForErr, ci, c, - { { AssociatedAction::ITERATE, iterand->id() } }); +static IterDetails resolveIterDetails(Resolver& rv, + const AstNode* astForErr, + const AstNode* iterand, + const QualifiedType& leaderYieldType, + int mask) { + Context* context = rv.context; + + if (mask == IterDetails::NONE || rv.scopeResolveOnly) { + // Resolve the iterand as much as possible even if there is nothing to do. + iterand->traverse(rv); + return {}; + } + + // Resolve the iterand but suppress errors for now. We'll reissue them + // next, possibly suppressing a "NoMatchingCandidates" for the iterand if + // our injected call is successful. + auto runResult = context->runAndTrackErrors([&](Context* context) { + iterand->traverse(rv); + return nullptr; + }); + + // Resolve iterators, stopping immediately when we get a valid yield type. + bool wasIterSigResolved = false; + auto ret = resolveIterDetailsInPriorityOrder(rv, wasIterSigResolved, + astForErr, iterand, + leaderYieldType, + mask); + + // Only issue a "not iterable" error if the iterand has a type. If it was + // not typed then earlier resolution of the iterand will have spit out an + // approriate error for us already. + bool skipNoCandidatesError = true; + if (!wasIterSigResolved) { + auto& iterandRE = rv.byPostorder.byAst(iterand); + if (!iterandRE.type().isUnknownOrErroneous()) { + ret.idxType = CHPL_TYPE_ERROR(context, NonIterable, astForErr, iterand, + iterandRE.type()); } else { - idxType = CHPL_TYPE_ERROR(context, NonIterable, astForErr, iterand, iterandRE.type()); + skipNoCandidatesError = false; + } + } + + // Reissue the errors. + for (auto& e : runResult.errors()) { + if (e->type() == NoMatchingCandidates) { + auto nmc = static_cast(e.get()); + auto& f = std::get<0>(nmc->info()); + if (skipNoCandidatesError && f == iterand) continue; } + context->report(std::move(e)); + } + + return ret; +} + +static QualifiedType +resolveIterTypeWithTag(Resolver& rv, + IterDetails::Pieces& outIterPieces, + const AstNode* astForErr, + const AstNode* iterand, + UniqueString iterKindStr, + const QualifiedType& followThisFormal) { + Context* context = rv.context; + QualifiedType unknown(QualifiedType::UNKNOWN, UnknownType::get(context)); + QualifiedType error(QualifiedType::UNKNOWN, ErroneousType::get(context)); + + auto iterKindFormal = getIterKindConstantOrUnknownQuery(context, iterKindStr); + bool needStandalone = iterKindStr == USTR("standalone"); + bool needLeader = iterKindStr == USTR("leader"); + bool needFollower = iterKindStr == USTR("follower"); + bool needSerial = iterKindStr.isEmpty(); + + // Exit early if we need a parallel iterator and don't have the enum. + if (!needSerial && iterKindFormal.isUnknown()) return error; + + auto iterKindType = EnumType::getIterKindType(context); + CHPL_ASSERT(needSerial || (iterKindFormal.type() == iterKindType && + iterKindFormal.hasParamPtr())); + + // Inspect the resolution result to determine what should be done next. + auto& iterandRE = rv.byPostorder.byAst(iterand); + auto& MSC = iterandRE.mostSpecific(); + auto fn = MSC.only() ? MSC.only().fn() : nullptr; + bool wasIterandTypeResolved = !iterandRE.type().isUnknownOrErroneous(); + bool wasIterResolved = fn && fn->isIterator(); + bool wasMatchingIterResolved = wasIterResolved && + ((fn->isParallelStandaloneIterator(context) && needStandalone) || + (fn->isParallelLeaderIterator(context) && needLeader) || + (fn->isParallelFollowerIterator(context) && needFollower) || + (fn->isSerialIterator(context) && needSerial)); + + QualifiedType ret = error; + + // We resolved the iterator we need right off the bat, so use it. + if (wasMatchingIterResolved) { + ret = iterandRE.type(); + outIterPieces = { false, {}, fn }; + + // We cannot inject e.g., the 'tag' actual in this case, so error out. + } else if (needSerial && wasIterResolved) { + return error; + + // There's nothing to do in this case, so error out. + } else if (needSerial && !wasIterandTypeResolved) { + return error; + + // In this branch we prepare a generated iterator call. It could be a call + // to 'these()' on the iterand type, or it could be a redirect of the + // existing call with 'iterKind' and (optionally) 'followThis' arguments + // tacked onto the end. } else { - idxType = QualifiedType(QualifiedType::UNKNOWN, - ErroneousType::get(context)); + bool shouldCreateTheseCall = !wasIterResolved && wasIterandTypeResolved; + + // We need to fill in the following pieces to construct a 'CallInfo'. + UniqueString callName; + types::QualifiedType callCalledType; + bool callIsMethodCall = false; + bool callHasQuestionArg = false; + bool callIsParenless = false; + std::vector callActuals; + + // If we are constructing a new 'these()' call then add a receiver. + if (shouldCreateTheseCall) { + callName = USTR("these"); + callCalledType = iterandRE.type(); + callIsMethodCall = true; + callActuals.push_back(CallInfoActual(iterandRE.type(), USTR("this"))); + + // The iterand is an unresolved call, or it is a resolved iterator but + // not the one that we need. Regather existing actuals and reuse the + // receiver if it is present. + } else { + auto call = iterand->toCall(); + CHPL_ASSERT(call); + + bool raiseErrors = false; + auto tmp = CallInfo::create(context, call, rv.byPostorder, raiseErrors); + + callName = tmp.name(); + callCalledType = tmp.calledType(); + callIsMethodCall = tmp.isMethodCall(); + callIsParenless = tmp.isParenless(); + for (auto& a : tmp.actuals()) callActuals.push_back(a); + } + + if (!needSerial) { + callActuals.push_back(CallInfoActual(iterKindFormal, USTR("tag"))); + } + + if (needFollower) { + auto x = CallInfoActual(followThisFormal, USTR("followThis")); + callActuals.push_back(std::move(x)); + } + + auto ci = CallInfo(std::move(callName), + std::move(callCalledType), + std::move(callIsMethodCall), + std::move(callHasQuestionArg), + std::move(callIsParenless), + std::move(callActuals)); + auto inScope = rv.scopeStack.back(); + auto inScopes = CallScopeInfo::forNormalCall(inScope, rv.poiScope); + auto c = resolveGeneratedCall(context, iterand, ci, inScopes); + + outIterPieces = { true, c, c.mostSpecific().only().fn() }; + ret = c.exprType(); + + if (!ret.isUnknownOrErroneous()) { + rv.handleResolvedCall(iterandRE, astForErr, ci, c, + { { AssociatedAction::ITERATE, iterand->id() } }); + } } - return idxType; + return ret; } -bool Resolver::enter(const IndexableLoop* loop) { +static bool resolveParamForLoop(Resolver& rv, const For* forLoop) { + const AstNode* iterand = forLoop->iterand(); + Context* context = rv.context; - auto forLoop = loop->toFor(); - bool isParamLoop = forLoop != nullptr && forLoop->isParam(); + iterand->traverse(rv); - // whether this is a param or regular loop, before entering its body - // or considering its iterand, resolve expressions in the loop's attribute - // group. - if (auto ag = loop->attributeGroup()) { - ag->traverse(*this); + if (rv.scopeResolveOnly) { + rv.enterScope(forLoop); + return true; } - if (isParamLoop) { - const AstNode* iterand = loop->iterand(); - iterand->traverse(*this); + if (iterand->isRange() == false) { + context->error(forLoop, "param loops may only iterate over range literals"); + } else { + // TODO: ranges with strides, '#', and '<' + const Range* rng = iterand->toRange(); + ResolvedExpression& lowRE = rv.byPostorder.byAst(rng->lowerBound()); + ResolvedExpression& hiRE = rv.byPostorder.byAst(rng->upperBound()); + // TODO: Simplify once we no longer use nullptr for param() + auto lowParam = lowRE.type().param(); + auto hiParam = hiRE.type().param(); + auto low = lowParam ? lowParam->toIntParam() : nullptr; + auto hi = hiParam ? hiParam->toIntParam() : nullptr; + + if (low == nullptr || hi == nullptr) { + context->error(forLoop, "param loops may only iterate over range literals with integer bounds"); + return false; + } - if (scopeResolveOnly) { - enterScope(loop); - return true; + std::vector loopResults; + for (int64_t i = low->value(); i <= hi->value(); i++) { + ResolutionResultByPostorderID bodyResults; + auto cur = Resolver::paramLoopResolver(rv, forLoop, bodyResults); + + cur.enterScope(forLoop); + + ResolvedExpression& idx = cur.byPostorder.byAst(forLoop->index()); + QualifiedType qt = QualifiedType(QualifiedType::PARAM, lowRE.type().type(), IntParam::get(context, i)); + idx.setType(qt); + forLoop->body()->traverse(cur); + + cur.exitScope(forLoop); + + loopResults.push_back(std::move(cur.byPostorder)); } - if (iterand->isRange() == false) { - context->error(loop, "param loops may only iterate over range literals"); - } else { - // TODO: ranges with strides, '#', and '<' - const Range* rng = iterand->toRange(); - ResolvedExpression& lowRE = byPostorder.byAst(rng->lowerBound()); - ResolvedExpression& hiRE = byPostorder.byAst(rng->upperBound()); - // TODO: Simplify once we no longer use nullptr for param() - auto lowParam = lowRE.type().param(); - auto hiParam = hiRE.type().param(); - auto low = lowParam ? lowParam->toIntParam() : nullptr; - auto hi = hiParam ? hiParam->toIntParam() : nullptr; - - if (low == nullptr || hi == nullptr) { - context->error(loop, "param loops may only iterate over range literals with integer bounds"); - return false; - } + auto paramLoop = new ResolvedParamLoop(forLoop); + paramLoop->setLoopBodies(loopResults); + auto& resolvedLoopExpr = rv.byPostorder.byAst(forLoop); + resolvedLoopExpr.setParamLoop(paramLoop); + } - std::vector loopResults; - for (int64_t i = low->value(); i <= hi->value(); i++) { - ResolutionResultByPostorderID bodyResults; - auto cur = Resolver::paramLoopResolver(*this, forLoop, bodyResults); + return false; +} - cur.enterScope(loop); +static void +backpatchArrayTypeSpecifier(Resolver& rv, const IndexableLoop* loop) { + if (rv.scopeResolveOnly || !loop->isBracketLoop()) return; + Context* context = rv.context; - ResolvedExpression& idx = cur.byPostorder.byAst(loop->index()); - QualifiedType qt = QualifiedType(QualifiedType::PARAM, lowRE.type().type(), IntParam::get(context, i)); - idx.setType(qt); - loop->body()->traverse(cur); + // Check if this is an array + auto iterandType = rv.byPostorder.byAst(loop->iterand()).type(); + if (!iterandType.isUnknown() && iterandType.type()->isDomainType()) { + QualifiedType eltType; - cur.exitScope(loop); + CHPL_ASSERT(loop->isExpressionLevel() && loop->numStmts() <= 1); + if (loop->numStmts() == 1) { + eltType = rv.byPostorder.byAst(loop->stmt(0)).type(); + } else if (loop->numStmts() == 0) { + eltType = QualifiedType(QualifiedType::TYPE, AnyType::get(context)); + } - loopResults.push_back(std::move(cur.byPostorder)); - } + // TODO: resolve array types when the iterand is something other than + // a domain. + if (eltType.isType() || eltType.kind() == QualifiedType::TYPE_QUERY) { + eltType = QualifiedType(QualifiedType::TYPE, eltType.type()); + auto arrayType = ArrayType::getArrayType(context, iterandType, eltType); - auto paramLoop = new ResolvedParamLoop(forLoop); - paramLoop->setLoopBodies(loopResults); - auto& resolvedLoopExpr = byPostorder.byAst(loop); - resolvedLoopExpr.setParamLoop(paramLoop); + auto& re = rv.byPostorder.byAst(loop); + re.setType(QualifiedType(QualifiedType::TYPE, arrayType)); } + } +} - return false; - } else { - QualifiedType idxType = resolveSerialIterType(*this, loop, loop->iterand()); +static QualifiedType +resolveZipExpression(Resolver& rv, const IndexableLoop* loop, const Zip* zip) { + Context* context = rv.context; + bool loopRequiresParallel = loop->isForall(); + bool loopPrefersParallel = loopRequiresParallel || loop->isBracketLoop(); - enterScope(loop); + // We build up tuple element types by resolving all the zip actuals. + std::vector eltTypes; - if (const Decl* idx = loop->index()) { - ResolvedExpression& re = byPostorder.byAst(idx); - re.setType(idxType); + // We determine the follower policy by resolving the leader actual. + auto followerPolicy = IterDetails::NONE; + QualifiedType leaderYieldType; + + // Get the leader actual. + if (auto leader = (zip->numActuals() ? zip->actual(0) : nullptr)) { + + // Set the policy mask for the leader based on the loop properties. + int m = IterDetails::NONE; + if (loopPrefersParallel) m |= IterDetails::LEADER_FOLLOWER; + if (!loopRequiresParallel) m |= IterDetails::SERIAL; + CHPL_ASSERT(m != IterDetails::NONE); + + // Resolve the leader iterator. + auto dt = resolveIterDetails(rv, leader, leader, {}, m); + + // Configure what followers should do using the iterator details. + if (dt.succeededAt == IterDetails::LEADER_FOLLOWER) { + followerPolicy = IterDetails::FOLLOWER; + leaderYieldType = dt.leaderYieldType; + eltTypes.push_back(dt.idxType); + } else if (dt.succeededAt == IterDetails::SERIAL) { + followerPolicy = IterDetails::SERIAL; + eltTypes.push_back(dt.idxType); + } else { + return { QualifiedType::UNKNOWN, ErroneousType::get(context) }; } + } - if (auto with = loop->withClause()) { - with->traverse(*this); - } - loop->body()->traverse(*this); + CHPL_ASSERT(followerPolicy != IterDetails::NONE); - if (!scopeResolveOnly && loop->isBracketLoop()) { - // Check if this is an array - auto iterandType = byPostorder.byAst(loop->iterand()).type(); - if (!iterandType.isUnknown() && iterandType.type()->isDomainType()) { - QualifiedType eltType; - if (loop->numStmts() == 1) { - eltType = byPostorder.byAst(loop->stmt(0)).type(); - } else if (loop->numStmts() == 0) { - eltType = QualifiedType(QualifiedType::TYPE, AnyType::get(context)); - } else { - CHPL_ASSERT(false && "array expression with multiple loop body statements?"); - } + // Resolve the follower iterator or serial iterator for all followers. + for (int i = 1; i < zip->numActuals(); i++) { + auto actual = zip->actual(i); + auto dt = resolveIterDetails(rv, actual, actual, leaderYieldType, + followerPolicy); + auto& qt = dt.idxType; + eltTypes.push_back(qt); + } - // TODO: resolve array types when the iterand is something other than - // a domain. - if (eltType.isType() || eltType.kind() == QualifiedType::TYPE_QUERY) { - eltType = QualifiedType(QualifiedType::TYPE, eltType.type()); - auto arrayType = ArrayType::getArrayType(context, iterandType, eltType); + CHPL_ASSERT(((int) eltTypes.size()) == zip->numActuals()); - ResolvedExpression& re = byPostorder.byAst(loop); - re.setType(QualifiedType(QualifiedType::TYPE, arrayType)); - } - } + auto kind = QualifiedType::CONST_VAR; + for (auto& et : eltTypes) { + if (!et.isUnknownOrErroneous() && !et.isConst()) { + kind = QualifiedType::VAR; + break; } } + // This 'TupleType' builder preserves references for index types. + auto type = TupleType::getQualifiedTuple(context, std::move(eltTypes)); + QualifiedType ret = { kind, type }; + + auto& reZip = rv.byPostorder.byAst(zip); + reZip.setType(ret); + + return ret; +} + +bool Resolver::enter(const IndexableLoop* loop) { + auto forLoop = loop->toFor(); + bool isParamForLoop = forLoop != nullptr && forLoop->isParam(); + + // whether this is a param or regular loop, before entering its body + // or considering its iterand, resolve expressions in the loop's attribute + // group. + if (auto ag = loop->attributeGroup()) { + ag->traverse(*this); + } + + if (isParamForLoop) return resolveParamForLoop(*this, loop->toFor()); + + auto iterand = loop->iterand(); + QualifiedType idxType; + + if (iterand->isZip()) { + idxType = resolveZipExpression(*this, loop, iterand->toZip()); + + } else { + bool loopRequiresParallel = loop->isForall(); + bool loopPrefersParallel = loopRequiresParallel || loop->isBracketLoop(); + + int m = IterDetails::NONE; + if (loopPrefersParallel) m |= IterDetails::LEADER_FOLLOWER | + IterDetails::STANDALONE; + if (!loopRequiresParallel) m |= IterDetails::SERIAL; + CHPL_ASSERT(m != IterDetails::NONE); + + auto dt = resolveIterDetails(*this, loop, iterand, {}, m); + idxType = dt.idxType; + } + + enterScope(loop); + + if (const Decl* idx = loop->index()) { + ResolvedExpression& re = byPostorder.byAst(idx); + re.setType(idxType); + } + + if (auto with = loop->withClause()) { + with->traverse(*this); + } + + loop->body()->traverse(*this); + + // TODO: Resolve the loop body first when it looks like an array type, + // and if the body is a type, then skip resolving iterators to save time. + backpatchArrayTypeSpecifier(*this, loop); + return false; } void Resolver::exit(const IndexableLoop* loop) { // Param loops handle scope differently auto forLoop = loop->toFor(); - bool isParamLoop = forLoop != nullptr && forLoop->isParam(); + bool isParamForLoop = forLoop != nullptr && forLoop->isParam(); - if (isParamLoop == false || scopeResolveOnly) { + if (isParamForLoop == false || scopeResolveOnly) { exitScope(loop); } } @@ -4479,9 +4836,11 @@ static QualifiedType resolveReduceScanOp(Resolver& resolver, const AstNode* reduceOrScan, const AstNode* op, const AstNode* iterand) { - auto iterType = resolveSerialIterType(resolver, reduceOrScan, iterand); - if (iterType.isUnknown()) return QualifiedType(); - auto opClass = determineReduceScanOp(resolver, reduceOrScan, op, iterType); + auto dt = resolveIterDetails(resolver, reduceOrScan, iterand, {}, + IterDetails::SERIAL); + auto idxType = dt.idxType; + if (idxType.isUnknown()) return QualifiedType(); + auto opClass = determineReduceScanOp(resolver, reduceOrScan, op, idxType); if (opClass == nullptr) return QualifiedType(); return getReduceScanOpResultType(resolver, reduceOrScan, opClass); diff --git a/frontend/lib/resolution/resolution-types.cpp b/frontend/lib/resolution/resolution-types.cpp index c48f72e67d35..ff7782da6414 100644 --- a/frontend/lib/resolution/resolution-types.cpp +++ b/frontend/lib/resolution/resolution-types.cpp @@ -26,6 +26,7 @@ #include "chpl/framework/update-functions.h" #include "chpl/resolution/resolution-queries.h" #include "chpl/resolution/scope-queries.h" +#include "chpl/types/EnumType.h" #include "chpl/types/TupleType.h" #include "chpl/uast/Builder.h" #include "chpl/uast/FnCall.h" @@ -845,6 +846,48 @@ void TypedFnSignature::stringify(std::ostream& ss, ss << ")"; } +bool TypedFnSignature:: +fetchIterKindStr(Context* context, UniqueString& outIterKindStr) const { + if (!isIterator()) return false; + + // Has to just be a serial iterator. + if (numFormals() == 0 || (isMethod() && numFormals() == 1)) return true; + + auto ik = types::EnumType::getIterKindType(context); + auto m = types::EnumType::getParamConstantsMapOrNull(context, ik); + if (m == nullptr) return false; + + QualifiedType tagFormalType; + bool foundTagFormal = false; + UniqueString iterKindStr; + + // Loop over the formals since they could be in any position. + for (int i = 0; i < numFormals(); i++) { + if (formalName(i) == USTR("tag")) { + foundTagFormal = true; + tagFormalType = formalType(i); + if (m != nullptr) { + for (auto& p : *m) { + if (formalType(i) != p.second) continue; + iterKindStr = p.first; + break; + } + } + } + if (foundTagFormal) break; + } + + bool tagFormalMatches = tagFormalType.type() == ik && + tagFormalType.param(); + if (tagFormalMatches) { + CHPL_ASSERT(!iterKindStr.isEmpty()); + outIterKindStr = iterKindStr; + } + + bool ret = !foundTagFormal || tagFormalMatches; + return ret; +} + void CandidatesAndForwardingInfo::stringify( std::ostream& ss, chpl::StringifyKind stringKind) const { ss << "CandidatesAndForwardingInfo: "; diff --git a/frontend/lib/types/EnumType.cpp b/frontend/lib/types/EnumType.cpp index 29d86764da1d..0371d5dcb715 100644 --- a/frontend/lib/types/EnumType.cpp +++ b/frontend/lib/types/EnumType.cpp @@ -64,6 +64,38 @@ const EnumType* EnumType::getBoundKindType(Context* context) { return EnumType::get(context, id, name); } +const EnumType* EnumType::getIterKindType(Context* context) { + auto name = UniqueString::get(context, "iterKind"); + auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelBase", "iterKind"); + return EnumType::get(context, id, name); +} + +static const std::map& +getParamConstantsMapQuery(Context* context, const EnumType* et) { + QUERY_BEGIN(getParamConstantsMapQuery, context, et); + std::map ret; + + auto ast = parsing::idToAst(context, et->id()); + if (auto e = ast->toEnum()) { + for (auto elem : e->enumElements()) { + auto param = EnumParam::get(context, elem->id()); + auto k = UniqueString::get(context, elem->name().str()); + QualifiedType v(QualifiedType::PARAM, et, param); + ret.insert({std::move(k), std::move(v)}); + } + } + + return QUERY_END(ret); +} + +const std::map* +EnumType::getParamConstantsMapOrNull(Context* context, const EnumType* et) { + if (!et || !et->id()) return nullptr; + auto ast = parsing::idToAst(context, et->id()); + if (!ast || !ast->isEnum()) return nullptr; + return &getParamConstantsMapQuery(context, et); +} + void EnumType::stringify(std::ostream& ss, StringifyKind stringKind) const { name().stringify(ss, stringKind); } diff --git a/frontend/test/resolution/testLibrary.cpp b/frontend/test/resolution/testLibrary.cpp index 20e7e61152ee..438726392f11 100644 --- a/frontend/test/resolution/testLibrary.cpp +++ b/frontend/test/resolution/testLibrary.cpp @@ -31,6 +31,8 @@ // These tests exist to check compilation success of certain library features // +// TODO: Lock this test in when we feel like the resolver is ready. +/* static void testHelloWorld() { auto config = getConfigWithHome(); Context ctx(config); @@ -47,6 +49,7 @@ static void testHelloWorld() { assert(guard.numErrors() == 0); } +*/ static void testSerialize() { auto config = getConfigWithHome(); @@ -123,7 +126,7 @@ static void testDeserialize() { } int main() { - testHelloWorld(); + // testHelloWorld(); testSerialize(); testDeserialize(); diff --git a/frontend/test/resolution/testLoopIndexVars.cpp b/frontend/test/resolution/testLoopIndexVars.cpp index ba017afbe290..f26be3b68e54 100644 --- a/frontend/test/resolution/testLoopIndexVars.cpp +++ b/frontend/test/resolution/testLoopIndexVars.cpp @@ -34,6 +34,12 @@ #include +#define ADVANCE_PRESERVING_STANDARD_MODULES_(ctx__) \ + do { \ + ctx__->advanceToNextRevision(false); \ + setupModuleSearchPaths(ctx__, false, false, {}, {}); \ + } while (0) + static auto myiter = std::string(R""""( iter myiter() { yield 1; @@ -116,12 +122,11 @@ R"""( i in myiter() { // // Testing resolution of loop index variables -// TODO: -// - zippered iteration -// - forall loops // - error messages // + + static void testAmbiguous() { printf("testAmbiguous\n"); Context ctx; @@ -472,6 +477,337 @@ static void testIndexScope() { assert(!guard.realizeErrors()); } +static void testIterSigDetection(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + std::string program = + R""""( + + iter i1() { yield 0.0; } + iter i1(param tag: iterKind) where tag == iterKind.standalone { yield 0; } + iter i2(param tag: iterKind) where tag == iterKind.leader { yield (0,0); } + iter i2(param tag: iterKind, followThis) where tag == iterKind.follower { yield ""; } + iter i2() { yield false; } + + for a in i1() do; + forall b in i1() do; + for c in i2() do; + forall d in i2() do; + + )""""; + + auto mod = parseModule(context, program); + auto& rr = resolveModule(context, mod->id()); + assert(!guard.realizeErrors()); + + auto aLoop = parentAst(context, findVariable(mod, "a"))->toIndexableLoop(); + auto aSig1 = rr.byAst(aLoop->iterand()).mostSpecific().only().fn(); + assert(aSig1->isSerialIterator(context)); + + auto bLoop = parentAst(context, findVariable(mod, "b"))->toIndexableLoop(); + auto bSig1 = rr.byAst(bLoop->iterand()).associatedActions()[0].fn(); + assert(bSig1->isParallelStandaloneIterator(context)); + + auto cLoop = parentAst(context, findVariable(mod, "c"))->toIndexableLoop(); + auto cSig1 = rr.byAst(cLoop->iterand()).mostSpecific().only().fn(); + assert(cSig1->isSerialIterator(context)); + + auto dLoop = parentAst(context, findVariable(mod, "d"))->toIndexableLoop(); + auto dSig1 = rr.byAst(dLoop->iterand()).associatedActions()[0].fn(); + assert(dSig1->isParallelLeaderIterator(context)); + auto dSig2 = rr.byAst(dLoop->iterand()).associatedActions()[1].fn(); + assert(dSig2->isParallelFollowerIterator(context)); + + auto m = resolveTypesOfVariables(context, program, { "a", "b", "c", "d" }); + assert(!guard.realizeErrors()); + assert(m["a"].kind() == QualifiedType::CONST_VAR); + assert(m["a"].type()->isRealType()); + assert(m["b"].kind() == QualifiedType::CONST_VAR); + assert(m["b"].type()->isIntType()); + assert(m["c"].kind() == QualifiedType::CONST_VAR); + assert(m["c"].type()->isBoolType()); + assert(m["d"].kind() == QualifiedType::CONST_VAR); + assert(m["d"].type()->isStringType()); +} + +static void +unpackIterKindStrToBool(const std::string& str, + bool* needSerial=nullptr, + bool* needStandalone=nullptr, + bool* needLeader=nullptr, + bool* needFollower=nullptr) { + bool* p = nullptr; + if (str.empty()) { + p = needSerial; + } else if (str == "standalone") { + p = needStandalone; + } else if (str == "leader") { + p = needLeader; + } else if (str == "follower") { + p = needFollower; + } else { + assert(false && "Invalid 'iterKind' string!"); + } + if (p) *p = true; +} + +static void +assertIterIsCorrect(Context* context, const AssociatedAction& aa, + const std::string& iterKindStr) { + bool needSerial = false; + bool needStandalone = false; + bool needLeader = false; + bool needFollower = false; + unpackIterKindStrToBool(iterKindStr, &needSerial, &needStandalone, + &needLeader, + &needFollower); + + assert(aa.action() == AssociatedAction::ITERATE); + assert(aa.fn()); + + auto fn = aa.fn(); + if (needSerial) { + assert(fn->isSerialIterator(context)); + } else if (needStandalone) { + assert(fn->isParallelStandaloneIterator(context)); + } else if (needLeader) { + assert(fn->isParallelLeaderIterator(context)); + } else if (needFollower) { + assert(fn->isParallelFollowerIterator(context)); + } else { + assert(false && "Should not reach here!"); + } +} + +static void +assertLoopMatches(Context* context, const std::string& program, + const char* iterKindStr, + int idxLoopAst, + int idxIterAst, + int idxFollowerIterAst=-1) { + bool needSerial = false; + bool needStandalone = false; + bool needLeader = false; + bool needFollower = false; + unpackIterKindStrToBool(iterKindStr, &needSerial, &needStandalone, + &needLeader, + &needFollower); + needFollower = needFollower || needLeader; + + // Unpack needed ASTs and properties about them. + const Module* m = parseModule(context, program); + auto loop = m->stmt(idxLoopAst)->toIndexableLoop(); + auto iter = m->stmt(idxIterAst)->toFunction(); + auto iterFollower = idxFollowerIterAst > 0 + ? m->stmt(idxFollowerIterAst)->toFunction() + : nullptr; + assert(loop && iter && loop->iterand() && loop->index()); + auto iterand = loop->iterand(); + auto index = loop->index(); + bool isIterMethod = parsing::idIsMethod(context, iter->id()); + // bool isBracketLoop = loop->isBracketLoop(); + // bool isForall = loop->isForall(); + + + // Resolve the module. + auto& rr = resolveModule(context, m->id()); + auto& reIterand = rr.byAst(iterand); + + if (auto zip = iterand->toZip()) { + assert(false && "Zip iterands not handled in this test yet!"); + return; + } else { + int numExpectedActions = (needLeader || needFollower) ? 2 : 1; + assert(reIterand.associatedActions().size() == numExpectedActions); + + auto& aa1 = reIterand.associatedActions()[0]; + assertIterIsCorrect(context, aa1, iterKindStr); + + if (needFollower) { + assert(iterFollower); + auto& aa2 = reIterand.associatedActions()[1]; + assertIterIsCorrect(context, aa2, "follower"); + bool isFollowerMethod = parsing::idIsMethod(context, iterFollower->id()); + assert(isFollowerMethod == isIterMethod); + } + + auto& reIndex = rr.byAst(index); + assert(reIndex.type().type() == IntType::get(context, 0)); + } +} + +static void testSerialZip(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + auto program = R""""( + record r {} + iter r.these() do yield 0; + var r1 = new r(); + for tup in zip(r1, r1) do tup; + )""""; + + const Module* m = parseModule(context, program); + auto iter = m->stmt(1)->toFunction(); + auto var = m->stmt(2)->toVariable(); + auto loop = m->stmt(3)->toFor(); + assert(iter && var && loop && loop->iterand() && loop->index()); + auto index = loop->index(); + auto zip = loop->iterand()->toZip(); + assert(zip); + + auto& rr = resolveModule(context, m->id()); + auto& reZip = rr.byAst(zip); + + assert(reZip.associatedActions().empty()); + + assert(zip->numActuals() == 2); + for (auto actual : zip->actuals()) { + auto& re = rr.byAst(actual); + assert(re.toId() == var->id()); + assert(re.associatedActions().size() == 1); + auto& aa = re.associatedActions().back(); + assert(aa.action() == AssociatedAction::ITERATE); + auto fn = aa.fn(); + assert(fn->untyped()->kind() == Function::ITER); + } + + auto t = reZip.type().type()->toTupleType(); + assert(t && t->numElements() == 2); + assert(t->elementType(0).type() == IntType::get(context, 0)); + assert(t->elementType(1).type() == IntType::get(context, 0)); + + auto& reIndex = rr.byAst(index); + assert(reIndex.type() == reZip.type()); + assert(!guard.realizeErrors()); +} + +static void testParallelZip(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + auto program = R""""( + record r {} + iter r.these(param tag: iterKind) where tag == iterKind.leader do yield (0, 0); + iter r.these(param tag: iterKind, followThis) where tag == iterKind.follower do yield 0; + var r1 = new r(); + forall tup in zip(r1, r1) do tup; + )""""; + + const Module* m = parseModule(context, program); + auto iterLeader = m->stmt(1)->toFunction(); + auto iterFollower = m->stmt(2)->toFunction(); + auto var = m->stmt(3)->toVariable(); + auto loop = m->stmt(4)->toForall(); + assert(iterLeader && iterFollower && var && loop && + loop->iterand() && + loop->iterand()->isZip() && + loop->index()); + auto index = loop->index(); + auto zip = loop->iterand()->toZip(); + + auto& rr = resolveModule(context, m->id()); + auto& reZip = rr.byAst(zip); + + for (auto& e : guard.errors()) { + std::cout << e->message() << std::endl; + } + + assert(reZip.associatedActions().empty()); + + assert(zip->numActuals() == 2); + for (int i = 0; i < zip->numActuals(); i++) { + auto actual = zip->actual(i); + auto& re = rr.byAst(actual); + assert(re.toId() == var->id()); + bool isLeaderActual = (i == 0); + + // Only the first actual should have a leader iterator attached. + if (isLeaderActual) { + assert(re.associatedActions().size() == 2); + auto& aa = re.associatedActions()[0]; + assertIterIsCorrect(context, aa, "leader"); + } else { + assert(re.associatedActions().size() == 1); + } + + // Check all actuals for the follower iterator. + auto& aa = re.associatedActions()[(isLeaderActual ? 1 : 0)]; + assertIterIsCorrect(context, aa, "follower"); + } + + auto t = reZip.type().type()->toTupleType(); + assert(t && t->numElements() == 2); + assert(t->elementType(0).type() == IntType::get(context, 0)); + assert(t->elementType(1).type() == IntType::get(context, 0)); + + auto& reIndex = rr.byAst(index); + assert(reIndex.type() == reZip.type()); + assert(!guard.realizeErrors()); +} + +static void testForallStandaloneThese(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + auto program = R""""( + record r {} + iter r.these(param tag: iterKind) where tag == iterKind.standalone do yield 0; + var r1 = new r(); + forall i in r1 do i; + )""""; + assertLoopMatches(context, program, "standalone", 3, 1); + assert(!guard.realizeErrors()); +} + +static void testForallStandaloneRedirect(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + auto program = R""""( + iter foo(param tag: iterKind) where tag == iterKind.standalone do yield 0; + forall i in foo() do i; + )""""; + assertLoopMatches(context, program, "standalone", 1, 0); + assert(!guard.realizeErrors()); +} + +static void testForallLeaderFollowerThese(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + auto program = R""""( + record r {} + iter r.these(param tag: iterKind) where tag == iterKind.leader do yield (0, 0); + iter r.these(param tag: iterKind, followThis) where tag == iterKind.follower do yield 0; + var r1 = new r(); + forall i in r1 do i; + )""""; + + assertLoopMatches(context, program, "leader", 4, 1, 2); + assert(!guard.realizeErrors()); +} + +static void testForallLeaderFollowerRedirect(Context* context) { + printf("%s\n", __FUNCTION__); + ErrorGuard guard(context); + + ADVANCE_PRESERVING_STANDARD_MODULES_(context); + auto program = R""""( + iter foo(param tag: iterKind) where tag == iterKind.leader do yield (0, 0); + iter foo(param tag: iterKind, followThis) where tag == iterKind.follower do yield 0; + forall i in foo() do i; + )""""; + assertLoopMatches(context, program, "leader", 2, 0, 1); + assert(!guard.realizeErrors()); +} int main() { testSimpleLoop("for"); @@ -490,5 +826,16 @@ int main() { testNestedParamFor(); testIndexScope(); + // Use a single context instance to avoid re-resolving internal modules. + auto ctx = buildStdContext(); + Context* context = ctx.get(); + testIterSigDetection(context); + testSerialZip(context); + testParallelZip(context); + testForallStandaloneThese(context); + testForallStandaloneRedirect(context); + testForallLeaderFollowerThese(context); + testForallLeaderFollowerRedirect(context); + return 0; }