Skip to content

Commit

Permalink
CSHARP-4453: Support Bucket and BucketAuto stages in LINQ3.
Browse files Browse the repository at this point in the history
  • Loading branch information
rstam committed Jan 26, 2023
1 parent ec46c34 commit 8993daa
Show file tree
Hide file tree
Showing 20 changed files with 741 additions and 514 deletions.
103 changes: 103 additions & 0 deletions src/MongoDB.Driver/AggregateBucketAutoResultIdSerializer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using MongoDB.Bson;
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver
{
/// <summary>
/// Static factory class for AggregateBucketAutoResultIdSerializer.
/// </summary>
public static class AggregateBucketAutoResultIdSerializer
{
/// <summary>
/// Creates an instance of AggregateBucketAutoResultIdSerializer.
/// </summary>
/// <typeparam name="TValue">The value type.</typeparam>
/// <param name="valueSerializer">The value serializer.</param>
/// <returns>A AggregateBucketAutoResultIdSerializer.</returns>
public static IBsonSerializer<AggregateBucketAutoResultId<TValue>> Create<TValue>(IBsonSerializer<TValue> valueSerializer)
{
return new AggregateBucketAutoResultIdSerializer<TValue>(valueSerializer);
}
}

/// <summary>
/// A serializer for AggregateBucketAutoResultId.
/// </summary>
/// <typeparam name="TValue">The type of the values.</typeparam>
public class AggregateBucketAutoResultIdSerializer<TValue> : ClassSerializerBase<AggregateBucketAutoResultId<TValue>>, IBsonDocumentSerializer
{
private readonly IBsonSerializer<TValue> _valueSerializer;

/// <summary>
/// Initializes a new instance of the <see cref="AggregateBucketAutoResultIdSerializer{TValue}"/> class.
/// </summary>
/// <param name="valueSerializer">The value serializer.</param>
public AggregateBucketAutoResultIdSerializer(IBsonSerializer<TValue> valueSerializer)
{
_valueSerializer = Ensure.IsNotNull(valueSerializer, nameof(valueSerializer));
}

/// <inheritdoc/>
protected override AggregateBucketAutoResultId<TValue> DeserializeValue(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var reader = context.Reader;
reader.ReadStartDocument();
TValue min = default;
TValue max = default;
while (reader.ReadBsonType() != 0)
{
var name = reader.ReadName();
switch (name)
{
case "min": min = _valueSerializer.Deserialize(context); break;
case "max": max = _valueSerializer.Deserialize(context); break;
default: throw new BsonSerializationException($"Invalid element name for AggregateBucketAutoResultId: {name}.");
}
}
reader.ReadEndDocument();
return new AggregateBucketAutoResultId<TValue>(min, max);
}

/// <inheritdoc/>
protected override void SerializeValue(BsonSerializationContext context, BsonSerializationArgs args, AggregateBucketAutoResultId<TValue> value)
{
var writer = context.Writer;
writer.WriteStartDocument();
writer.WriteName("min");
_valueSerializer.Serialize(context, value.Min);
writer.WriteName("max");
_valueSerializer.Serialize(context, value.Max);
writer.WriteEndDocument();
}

/// <inheritdoc/>
public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo)
{
serializationInfo = memberName switch
{
"Min" => new BsonSerializationInfo("min", _valueSerializer, _valueSerializer.ValueType),
"Max" => new BsonSerializationInfo("max", _valueSerializer, _valueSerializer.ValueType),
_ => null
};
return serializationInfo != null;
}
}
}
57 changes: 0 additions & 57 deletions src/MongoDB.Driver/GroupForLinq3Result.cs

This file was deleted.

47 changes: 36 additions & 11 deletions src/MongoDB.Driver/IAggregateFluentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public static IAggregateFluent<AggregateBucketAutoResult<TValue>> BucketAuto<TRe
}

/// <summary>
/// Appends a $bucketAuto stage to the pipeline.
/// Appends a $bucketAuto stage to the pipeline (this overload can only be used with LINQ3).
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TValue">The type of the value.</typeparam>
Expand All @@ -110,13 +110,46 @@ public static IAggregateFluent<TNewResult> BucketAuto<TResult, TValue, TNewResul
this IAggregateFluent<TResult> aggregate,
Expression<Func<TResult, TValue>> groupBy,
int buckets,
Expression<Func<IGrouping<TValue, TResult>, TNewResult>> output,
Expression<Func<IGrouping<AggregateBucketAutoResultId<TValue>, TResult>, TNewResult>> output,
AggregateBucketAutoOptions options = null)
{
Ensure.IsNotNull(aggregate, nameof(aggregate));
if (aggregate.Database.Client.Settings.LinqProvider != LinqProvider.V3)
{
throw new InvalidOperationException("This overload of BucketAuto can only be used with LINQ3.");
}

return aggregate.AppendStage(PipelineStageDefinitionBuilder.BucketAuto(groupBy, buckets, output, options));
}

/// <summary>
/// Appends a $bucketAuto stage to the pipeline (this method can only be used with LINQ2).
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TValue">The type of the value.</typeparam>
/// <typeparam name="TNewResult">The type of the new result.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="groupBy">The expression providing the value to group by.</param>
/// <param name="buckets">The number of buckets.</param>
/// <param name="output">The output projection.</param>
/// <param name="options">The options (optional).</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TNewResult> BucketAutoForLinq2<TResult, TValue, TNewResult>(
this IAggregateFluent<TResult> aggregate,
Expression<Func<TResult, TValue>> groupBy,
int buckets,
Expression<Func<IGrouping<TValue, TResult>, TNewResult>> output, // the IGrouping for BucketAuto has been wrong all along, only fixing it for LINQ3
AggregateBucketAutoOptions options = null)
{
Ensure.IsNotNull(aggregate, nameof(aggregate));
if (aggregate.Database.Client.Settings.LinqProvider != LinqProvider.V2)
{
throw new InvalidOperationException("The BucketAutoForLinq2 method can only be used with LINQ2.");
}

return aggregate.AppendStage(PipelineStageDefinitionBuilder.BucketAutoForLinq2(groupBy, buckets, output, options));
}

/// <summary>
/// Appends a $densify stage to the pipeline.
/// </summary>
Expand Down Expand Up @@ -396,15 +429,7 @@ public static IAggregateFluent<BsonDocument> Group<TResult>(this IAggregateFluen
public static IAggregateFluent<TNewResult> Group<TResult, TKey, TNewResult>(this IAggregateFluent<TResult> aggregate, Expression<Func<TResult, TKey>> id, Expression<Func<IGrouping<TKey, TResult>, TNewResult>> group)
{
Ensure.IsNotNull(aggregate, nameof(aggregate));
if (aggregate.Database.Client.Settings.LinqProvider == LinqProvider.V2)
{
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Group(id, group));
}
else
{
var (groupStage, projectStage) = PipelineStageDefinitionBuilder.GroupForLinq3(id, group);
return aggregate.AppendStage(groupStage).AppendStage(projectStage);
}
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Group(id, group));
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,50 @@

namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers
{
internal class AstGroupPipelineOptimizer
internal class AstGroupingPipelineOptimizer
{
#region static
public static AstPipeline Optimize(AstPipeline pipeline)
{
var optimizer = new AstGroupPipelineOptimizer();
var optimizer = new AstGroupingPipelineOptimizer();
for (var i = 0; i < pipeline.Stages.Count; i++)
{
var stage = pipeline.Stages[i];
if (stage is AstGroupStage groupStage)
if (IsGroupingStage(stage))
{
pipeline = optimizer.OptimizeGroupStage(pipeline, i, groupStage);
pipeline = optimizer.OptimizeGroupingStage(pipeline, i, stage);
}
}

return pipeline;

static bool IsGroupingStage(AstStage stage)
{
return stage.NodeType switch
{
AstNodeType.GroupStage or AstNodeType.BucketStage or AstNodeType.BucketAutoStage => true,
_ => false
};
}
}
#endregion

private readonly AccumulatorSet _accumulators = new AccumulatorSet();
private AstExpression _element; // normally either "$$ROOT" or "$_v"

private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStage groupStage)
private AstPipeline OptimizeGroupingStage(AstPipeline pipeline, int i, AstStage groupingStage)
{
try
{
if (IsOptimizableGroupStage(groupStage, out _element))
if (IsOptimizableGroupingStage(groupingStage, out _element))
{
var followingStages = GetFollowingStagesToOptimize(pipeline, i + 1);
if (followingStages == null)
{
return pipeline;
}

var mappings = OptimizeGroupAndFollowingStages(groupStage, followingStages);
var mappings = OptimizeGroupingAndFollowingStages(groupingStage, followingStages);
if (mappings.Length > 0)
{
return (AstPipeline)AstNodeReplacer.Replace(pipeline, mappings);
Expand All @@ -72,23 +81,57 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag

return pipeline;

static bool IsOptimizableGroupStage(AstGroupStage groupStage, out AstExpression element)
static bool IsOptimizableGroupingStage(AstStage groupingStage, out AstExpression element)
{
// { $group : { _id : ?, _elements : { $push : element } } }
if (groupStage.Fields.Count == 1)
if (groupingStage is AstGroupStage groupStage)
{
// { $group : { _id : ?, _elements : { $push : element } } }
if (groupStage.Fields.Count == 1)
{
var field = groupStage.Fields[0];
return IsElementsPush(field, out element);
}
}

if (groupingStage is AstBucketStage bucketStage)
{
// { $bucket : { groupBy : ?, boundaries : ?, default : ?, output : { _elements : { $push : element } } } }
if (bucketStage.Output.Count == 1)
{
var output = bucketStage.Output[0];
return IsElementsPush(output, out element);
}
}

if (groupingStage is AstBucketAutoStage bucketAutoStage)
{
var field = groupStage.Fields[0];
if (field.Path == "_elements" &&
// { $bucketAuto : { groupBy : ?, buckets : ?, granularity : ?, output : { _elements : { $push : element } } } }
if (bucketAutoStage.Output.Count == 1)
{
var output = bucketAutoStage.Output[0];
return IsElementsPush(output, out element);
}
}

element = null;
return false;

static bool IsElementsPush(AstAccumulatorField field, out AstExpression element)
{
if (
field.Path == "_elements" &&
field.Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push)
{
element = unaryAccumulatorExpression.Arg;
return true;
}
else
{
element = null;
return false;
}
}

element = null;
return false;
}

static List<AstStage> GetFollowingStagesToOptimize(AstPipeline pipeline, int from)
Expand Down Expand Up @@ -135,7 +178,7 @@ static bool IsLastStageThatCanBeOptimized(AstStage stage)
}
}

private (AstNode, AstNode)[] OptimizeGroupAndFollowingStages(AstGroupStage groupStage, List<AstStage> followingStages)
private (AstNode, AstNode)[] OptimizeGroupingAndFollowingStages(AstStage groupingStage, List<AstStage> followingStages)
{
var mappings = new List<(AstNode, AstNode)>();

Expand All @@ -148,10 +191,21 @@ static bool IsLastStageThatCanBeOptimized(AstStage stage)
}
}

var newGroupStage = AstStage.Group(groupStage.Id, _accumulators);
mappings.Add((groupStage, newGroupStage));
var newGroupingStage = CreateNewGroupingStage(groupingStage, _accumulators);
mappings.Add((groupingStage, newGroupingStage));

return mappings.ToArray();

static AstStage CreateNewGroupingStage(AstStage groupingStage, AccumulatorSet accumulators)
{
return groupingStage switch
{
AstGroupStage groupStage => AstStage.Group(groupStage.Id, accumulators),
AstBucketStage bucketStage => AstStage.Bucket(bucketStage.GroupBy, bucketStage.Boundaries, bucketStage.Default, accumulators),
AstBucketAutoStage bucketAutoStage => AstStage.BucketAuto(bucketAutoStage.GroupBy, bucketAutoStage.Buckets, bucketAutoStage.Granularity, accumulators),
_ => throw new Exception($"Unexpected {nameof(groupingStage)} node type: {groupingStage.NodeType}.")
};
}
}

private AstStage OptimizeFollowingStage(AstStage stage)
Expand Down
Loading

0 comments on commit 8993daa

Please sign in to comment.