diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs index feca877be92..53f2f991b44 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs @@ -44,12 +44,13 @@ public static AstPipeline Optimize(AstPipeline pipeline) #endregion private readonly AccumulatorSet _accumulators = new AccumulatorSet(); + private AstExpression _element; // normally either "$$ROOT" or "$_v" private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStage groupStage) { try { - if (IsOptimizableGroupStage(groupStage)) + if (IsOptimizableGroupStage(groupStage, out _element)) { var followingStages = GetFollowingStagesToOptimize(pipeline, i + 1); if (followingStages == null) @@ -71,22 +72,22 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag return pipeline; - static bool IsOptimizableGroupStage(AstGroupStage groupStage) + static bool IsOptimizableGroupStage(AstGroupStage groupStage, out AstExpression element) { - // { $group : { _id : ?, _elements : { $push : "$$ROOT" } } } + // { $group : { _id : ?, _elements : { $push : element } } } if (groupStage.Fields.Count == 1) { var field = groupStage.Fields[0]; if (field.Path == "_elements" && field.Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression && - unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push && - unaryAccumulatorExpression.Arg is AstVarExpression varExpression && - varExpression.Name == "ROOT") + unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push) { + element = unaryAccumulatorExpression.Arg; return true; } } + element = null; return false; } @@ -173,7 +174,7 @@ private AstStage OptimizeLimitStage(AstLimitStage stage) private AstStage OptimizeMatchStage(AstMatchStage stage) { - var optimizedFilter = AccumulatorMover.MoveAccumulators(_accumulators, stage.Filter); + var optimizedFilter = AccumulatorMover.MoveAccumulators(_accumulators, _element, stage.Filter); return stage.Update(optimizedFilter); } @@ -201,7 +202,7 @@ private AstProjectStageSpecification OptimizeProjectStageSpecification(AstProjec private AstProjectStageSpecification OptimizeProjectStageSetFieldSpecification(AstProjectStageSetFieldSpecification specification) { - var optimizedValue = AccumulatorMover.MoveAccumulators(_accumulators, specification.Value); + var optimizedValue = AccumulatorMover.MoveAccumulators(_accumulators, _element, specification.Value); return specification.Update(optimizedValue); } @@ -249,27 +250,29 @@ public string AddAccumulatorExpression(AstAccumulatorExpression value) private class AccumulatorMover : AstNodeVisitor { #region static - public static TNode MoveAccumulators(AccumulatorSet accumulators, TNode node) + public static TNode MoveAccumulators(AccumulatorSet accumulators, AstExpression element, TNode node) where TNode : AstNode { - var mover = new AccumulatorMover(accumulators); + var mover = new AccumulatorMover(accumulators, element); return mover.VisitAndConvert(node); } #endregion private readonly AccumulatorSet _accumulators; + private readonly AstExpression _element; - private AccumulatorMover(AccumulatorSet accumulator) + private AccumulatorMover(AccumulatorSet accumulator, AstExpression element) { _accumulators = accumulator; + _element = element; } public override AstNode VisitFilterField(AstFilterField node) { - // "_elements.0.X" => { __agg0 : { $first : "$$ROOT" } } + "__agg0.X" + // "_elements.0.X" => { __agg0 : { $first : element } } + "__agg0.X" if (node.Path.StartsWith("_elements.0.")) { - var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.First, AstExpression.Var("ROOT")); + var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.First, _element); var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); var restOfPath = node.Path.Substring("_elements.0.".Length); var rewrittenPath = $"{accumulatorFieldName}.{restOfPath}"; @@ -288,9 +291,7 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node) { if (node.FieldName is AstConstantExpression constantFieldName && constantFieldName.Value.IsString && - constantFieldName.Value.AsString == "_elements" && - node.Input is AstVarExpression varExpression && - varExpression.Name == "ROOT") + constantFieldName.Value.AsString == "_elements") { throw new UnableToRemoveReferenceToElementsException(); } @@ -300,7 +301,7 @@ node.Input is AstVarExpression varExpression && public override AstNode VisitMapExpression(AstMapExpression node) { - // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => root) } } + "$__agg0" + // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0" if (node.Input is AstGetFieldExpression mapInputGetFieldExpression && mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression && mapInputconstantFieldExpression.Value.IsString && @@ -308,10 +309,10 @@ mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFi mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression && mapInputGetFieldVarExpression.Name == "ROOT") { - var root = AstExpression.Var("ROOT", isCurrent: true); - var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, root)); + var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element)); var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg); var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); + var root = AstExpression.Var("ROOT", isCurrent: true); return AstExpression.GetField(root, accumulatorFieldName); } @@ -321,7 +322,7 @@ mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpressi public override AstNode VisitPickExpression(AstPickExpression node) { // { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } } - // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => root) } } } + "$__agg0" + // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0" if (node.Source is AstGetFieldExpression getFieldExpression && getFieldExpression.Input is AstVarExpression varExpression && varExpression.Name == "ROOT" && @@ -330,10 +331,10 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio constantFieldNameExpression.Value.AsString == "_elements") { var @operator = node.Operator.ToAccumulatorOperator(); - var root = AstExpression.Var("ROOT", isCurrent: true); - var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, root)); + var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element)); var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N); var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); + var root = AstExpression.Var("ROOT", isCurrent: true); return AstExpression.GetField(root, accumulatorFieldName); } @@ -384,7 +385,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression) { - // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : "$$ROOT" } } + "$__agg0" + // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0" if (node.Operator.IsAccumulator(out var accumulatorOperator) && node.Arg is AstGetFieldExpression getFieldExpression && getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression && @@ -393,7 +394,7 @@ getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameE getFieldExpression.Input is AstVarExpression getFieldInputVarExpression && getFieldInputVarExpression.Name == "ROOT") { - var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, root); + var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element); var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); optimizedExpression = AstExpression.GetField(root, accumulatorFieldName); return true; @@ -406,7 +407,7 @@ getFieldExpression.Input is AstVarExpression getFieldInputVarExpression && bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpression) { - // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => root) } } + "$__agg0" + // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0" if (node.Operator.IsAccumulator(out var accumulatorOperator) && node.Arg is AstMapExpression mapExpression && mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression && @@ -416,7 +417,7 @@ mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFi mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression && mapInputGetFieldVarExpression.Name == "ROOT") { - var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, root)); + var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element)); var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg); var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); optimizedExpression = AstExpression.GetField(root, accumulatorFieldName); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs index a40cc7cc3ac..9f411b0c323 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs @@ -1537,8 +1537,8 @@ group f by f.D into g 4, "{ $project : { _v : '$G', _id : 0 } }", "{ $unwind : '$_v' }", - "{ $group : { _id : '$_v.D', _elements : { $push : '$_v' } } }", - "{ $project : { Key : '$_id', SumF : { $sum : '$_elements.E.F' }, _id : 0 } }"); + "{ $group : { _id : '$_v.D', __agg0 : { $sum : '$_v.E.F' } } }", + "{ $project : { Key : '$_id', SumF : '$__agg0', _id : 0 } }"); } [Fact] @@ -1567,8 +1567,8 @@ group s by s.D into g "{ $unwind : '$_v' }", "{ $project : { '_v' : '$_v.S', '_id' : 0 } }", "{ $unwind : '$_v' }", - "{ $group : { _id : '$_v.D', _elements : { $push : '$_v' } } }", - "{ $project : { Key : '$_id', SumF : { $sum : '$_elements.E.F' }, _id : 0 } }"); + "{ $group : { _id : '$_v.D', __agg0 : { $sum : '$_v.E.F' } } }", + "{ $project : { Key : '$_id', SumF : '$__agg0', _id : 0 } }"); } [Fact] diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4048Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4048Tests.cs index 1cff2edaa99..b4b6a3c4ed1 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4048Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4048Tests.cs @@ -311,8 +311,8 @@ public void IGrouping_All_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push: '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $allElementsTrue : { $map : { input : '$_elements', as : 'e', in : { $gt : ['$$e', 0] } } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $push: { $gt : ['$X', 0] } } } }", // MQL could be optimized further + "{ $project : { Id : '$_id', Result : { $allElementsTrue : '$__agg0' }, _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -361,8 +361,8 @@ public void IGrouping_Any_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $gt : [{ $size : '$_elements' }, 0] }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $sum : 1 } } }", + "{ $project : { Id : '$_id', Result : { $gt : ['$__agg0', 0] }, _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -411,8 +411,8 @@ public void IGrouping_Any_with_predicate_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $anyElementTrue : { $map : { input : '$_elements', as : 'e', in : { $gt : ['$$e', 0] } } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $push : { $gt : ['$X', 0] } } } }", // MQL could be optimized further + "{ $project : { Id : '$_id', Result : { $anyElementTrue : '$__agg0' }, _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -461,8 +461,8 @@ public void IGrouping_Average_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $avg : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $avg : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -511,8 +511,8 @@ public void IGrouping_Average_with_selector_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $avg : { $map : { input : '$_elements', as : 'e', in : '$$e' } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $avg : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -661,8 +661,8 @@ public void IGrouping_Count_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $size : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $sum : 1 } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -989,8 +989,8 @@ public void IGrouping_First_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $arrayElemAt : ['$_elements', 0] }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $first : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1195,8 +1195,8 @@ public void IGrouping_Last_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $arrayElemAt : ['$_elements', -1] }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $last : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1351,8 +1351,8 @@ public void IGrouping_LongCount_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $size : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $sum : 1 } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1451,8 +1451,8 @@ public void IGrouping_Max_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $max : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $max : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1501,8 +1501,8 @@ public void IGrouping_Max_with_selector_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $max : { $map : { input : '$_elements', as : 'e', in : '$$e' } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $max : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1551,8 +1551,8 @@ public void IGrouping_Min_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $min : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $min : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1601,8 +1601,8 @@ public void IGrouping_Min_with_selector_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $min : { $map : { input : '$_elements', as : 'e', in : '$$e' } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $min : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1701,8 +1701,8 @@ public void IGrouping_Select_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", - "{ $project : { Id : '$_id', Result : { $map : { input : '$_elements', as : 'e', in : '$$e' } }, _id : 0 } }", // MQL could be optimized further + "{ $group : { _id : '$_id', __agg0 : { $push : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1751,8 +1751,8 @@ public void IGrouping_StandardDeviationPopulation_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $stdDevPop : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $stdDevPop : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1801,8 +1801,8 @@ public void IGrouping_StandardDeviationPopulation_with_selector_of_scalar_should var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $stdDevPop : { $map : { input : '$_elements', as : 'e', in : '$$e' } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $stdDevPop : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1851,8 +1851,8 @@ public void IGrouping_StandardDeviationSample_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $stdDevSamp : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $stdDevSamp : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1901,8 +1901,8 @@ public void IGrouping_StandardDeviationSample_with_selector_of_scalar_should_wor var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $stdDevSamp : { $map : { input : '$_elements', as : 'e', in : '$$e' } } }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $stdDevSamp : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); @@ -1951,8 +1951,8 @@ public void IGrouping_Sum_of_scalar_should_work() var stages = Translate(collection, queryable); var expectedStages = new[] { - "{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further - "{ $project : { Id : '$_id', Result : { $sum : '$_elements' }, _id : 0 } }", + "{ $group : { _id : '$_id', __agg0 : { $sum : '$X' } } }", + "{ $project : { Id : '$_id', Result : '$__agg0', _id : 0 } }", "{ $sort : { Id : 1 } }" }; AssertStages(stages, expectedStages); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4468Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4468Tests.cs new file mode 100644 index 00000000000..330e9c44801 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4468Tests.cs @@ -0,0 +1,131 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Linq; +using MongoDB.Driver.Linq; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira +{ + public class CSharp4468Tests : Linq3IntegrationTest + { + [Theory] + [ParameterAttributeData] + public void Query1_should_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = CreateCollection(linqProvider); + + var queryable = + collection.AsQueryable() + .SelectMany(i => i.Lines) + .GroupBy(l => l.ItemId) + .Select(g => new ItemSummary + { + Id = g.Key, + TotalAmount = g.Sum(l => l.TotalAmount) + }); + + var stages = Translate(collection, queryable); + string[] expectedStages; + if (linqProvider == LinqProvider.V2) + { + expectedStages = new[] + { + "{ $unwind : '$Lines' }", + "{ $project : { Lines : '$Lines', _id : 0 } }", + "{ $group : { _id : '$Lines.ItemId', __agg0 : { $sum : '$Lines.TotalAmount' } } }", + "{ $project : { Id : '$_id', TotalAmount : '$__agg0', _id : 0 } }" + }; + } + else + { + expectedStages = new[] + { + "{ $project : { _v : '$Lines', _id : 0 } }", + "{ $unwind : '$_v' }", + "{ $group : { _id : '$_v.ItemId', __agg0 : { $sum : '$_v.TotalAmount' } } }", + "{ $project : { _id : '$_id', TotalAmount : '$__agg0' } }" + }; + } + AssertStages(stages, expectedStages); + } + + [Theory] + [ParameterAttributeData] + public void Query2_should_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = CreateCollection(linqProvider); + + var queryable = + collection.AsQueryable() + .GroupBy(l => l.Id) + .Select(g => new ItemSummary + { + Id = g.Key, + TotalAmount = g.Sum(l => l.TotalAmount) + }); + + var stages = Translate(collection, queryable); + string[] expectedStages; + if (linqProvider == LinqProvider.V2) + { + expectedStages = new[] + { + "{ $group : { _id : '$_id', __agg0 : { $sum : '$TotalAmount' } } }", + "{ $project : { Id : '$_id', TotalAmount : '$__agg0', _id : 0 } }" + }; + } + else + { + expectedStages = new[] + { + "{ $group : { _id : '$_id', __agg0 : { $sum : '$TotalAmount' } } }", + "{ $project : { _id : '$_id', TotalAmount : '$__agg0' } }" // only difference from LINQ2 is "_id" vs "Id" + }; + } + AssertStages(stages, expectedStages); + } + + private IMongoCollection CreateCollection(LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider: linqProvider); + return collection; + } + + public class OrderDao + { + public OrderLineDao[] Lines { get; set; } + + public decimal TotalAmount { get; set; } + public Guid Id { get; set; } + } + + public class OrderLineDao + { + public decimal TotalAmount { get; set; } + public Guid ItemId { get; set; } + } + + public class ItemSummary + { + public Guid Id { get; set; } + public decimal TotalAmount { get; set; } + } + } +}