From 15d83a7d08f2f70d98653355e0d219e1a878b475 Mon Sep 17 00:00:00 2001 From: leminh98 Date: Mon, 1 Apr 2024 12:18:15 -0700 Subject: [PATCH] Query: Adds translation support for single key single value select GROUP BY LINQ queries (#4074) * preliminary change * Add some more boiler plate code * move all linq test to the same folder; add some groupBy test * fix references error in test refactoring add code for group by substitution. Still need to adjust binding post groupby * preliminary for the groupby functions with key and value selector * trying to change collection inputs for group by * WIP bookmark * Successfully ignore "key" * clean up code * Successfully bind the case of group by with only key selector and no value selector followed by an optional select clause * enable one group by test * add support for aggregate value selector * added baseline * working on adding support for multivalue value selector and key selector * code clean up * more clean up * more clean up * update test * Move test to separate file * code clean up * remove baseline file that got moved * fix merge issue * Changes test infrastructure to reflect changes from Master * address code review part 1 * Address code review 2 and adds code coverage * Addressed code review and added tests. 
Still a couple of bugs to iron out * Fix group by translation issue and add more test * update comments * address pr comment --------- Co-authored-by: Minh Le Co-authored-by: Aditya --- .../src/Linq/ExpressionToSQL.cs | 157 +- .../src/Linq/QueryUnderConstruction.cs | 143 +- .../src/Linq/TranslationContext.cs | 35 +- Microsoft.Azure.Cosmos/src/Linq/Utilities.cs | 7 +- .../Visitors/SqlObjectTextSerializer.cs | 1 + ...alBaselineTests.TestGroupByTranslation.xml | 449 +++++ .../Linq/LinqGeneralBaselineTests.cs | 149 ++ .../Linq/LinqTestsCommon.cs | 1774 ++++++++--------- ...icrosoft.Azure.Cosmos.EmulatorTests.csproj | 3 + ...upByClauseSqlParserBaselineTests.Tests.xml | 6 +- ...lObjectVisitorBaselineTests.SqlQueries.xml | 28 +- 11 files changed, 1788 insertions(+), 964 deletions(-) create mode 100644 Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqGeneralBaselineTests.TestGroupByTranslation.xml diff --git a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs index 62b9132a8..5dac768dc 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs @@ -8,12 +8,15 @@ namespace Microsoft.Azure.Cosmos.Linq using System.Collections.Generic; using System.Collections.Immutable; using System.Collections.ObjectModel; + using System.Data.Common; using System.Diagnostics; using System.Globalization; using System.Linq; using System.Linq.Expressions; using System.Reflection; + using System.Text.RegularExpressions; using Microsoft.Azure.Cosmos.CosmosElements; + using Microsoft.Azure.Cosmos.Serialization.HybridRow; using Microsoft.Azure.Cosmos.Serializer; using Microsoft.Azure.Cosmos.Spatial; using Microsoft.Azure.Cosmos.SqlObjects; @@ -64,6 +67,7 @@ namespace Microsoft.Azure.Cosmos.Linq public const string FirstOrDefault = "FirstOrDefault"; public const string Max = "Max"; public const string Min = "Min"; + public const string 
GroupBy = "GroupBy"; public const string OrderBy = "OrderBy"; public const string OrderByDescending = "OrderByDescending"; public const string Select = "Select"; @@ -109,7 +113,7 @@ namespace Microsoft.Azure.Cosmos.Linq /// /// Translate an expression into a query. - /// Query is constructed as a side-effect in context.currentQuery. + /// Query is constructed as a side-effect in context.CurrentQuery. /// /// Expression to translate. /// Context for translation. @@ -805,8 +809,8 @@ namespace Microsoft.Azure.Cosmos.Linq if (usePropertyRef) { - SqlIdentifier propertyIdnetifier = SqlIdentifier.Create(memberName); - SqlPropertyRefScalarExpression propertyRefExpression = SqlPropertyRefScalarExpression.Create(memberExpression, propertyIdnetifier); + SqlIdentifier propertyIdentifier = SqlIdentifier.Create(memberName); + SqlPropertyRefScalarExpression propertyRefExpression = SqlPropertyRefScalarExpression.Create(memberExpression, propertyIdentifier); return propertyRefExpression; } else @@ -997,7 +1001,7 @@ namespace Microsoft.Azure.Cosmos.Linq SqlQuery query = context.CurrentQuery.FlattenAsPossible().GetSqlQuery(); SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); - ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); Binding binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); @@ -1111,7 +1115,7 @@ namespace Microsoft.Azure.Cosmos.Linq Collection collection = ExpressionToSql.ConvertToCollection(body); context.PushCollection(collection); - ParameterExpression parameter = context.GenFreshParameter(type, parameterName); + ParameterExpression parameter = context.GenerateFreshParameter(type, parameterName); 
context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); context.PopParameter(); context.PopCollection(); @@ -1120,7 +1124,7 @@ namespace Microsoft.Azure.Cosmos.Linq } /// - /// Visit a method call, construct the corresponding query in context.currentQuery. + /// Visit a method call, construct the corresponding query in context.CurrentQuery. /// At ExpressionToSql point only LINQ method calls are allowed. /// These methods are static extension methods of IQueryable or IEnumerable. /// @@ -1149,11 +1153,18 @@ namespace Microsoft.Azure.Cosmos.Linq Type inputElementType = TypeSystem.GetElementType(inputCollection.Type); Collection collection = ExpressionToSql.Translate(inputCollection, context); + context.PushCollection(collection); Collection result = new Collection(inputExpression.Method.Name); bool shouldBeOnNewQuery = context.CurrentQuery.ShouldBeOnNewQuery(inputExpression.Method.Name, inputExpression.Arguments.Count); context.PushSubqueryBinding(shouldBeOnNewQuery); + + if (context.LastExpressionIsGroupBy) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Group By cannot be followed by other methods")); + } + switch (inputExpression.Method.Name) { case LinqMethods.Any: @@ -1219,6 +1230,13 @@ namespace Microsoft.Azure.Cosmos.Linq context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } + case LinqMethods.GroupBy: + { + context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); + result = ExpressionToSql.VisitGroupBy(returnElementType, inputExpression.Arguments, context); + context.LastExpressionIsGroupBy = true; + break; + } case LinqMethods.OrderBy: { SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); @@ -1376,6 +1394,7 @@ namespace Microsoft.Azure.Cosmos.Linq case LinqMethods.Skip: case LinqMethods.Take: case LinqMethods.Distinct: + case LinqMethods.GroupBy: isSubqueryExpression = true; expressionObjKind = 
SubqueryKind.ArrayScalarExpression; break; @@ -1405,7 +1424,7 @@ namespace Microsoft.Azure.Cosmos.Linq } /// - /// Visit an lambda expression which is in side a lambda and translate it to a scalar expression or a collection scalar expression. + /// Visit an lambda expression which is inside a lambda and translate it to a scalar expression or a collection scalar expression. /// If it is a collection scalar expression, e.g. should be translated to subquery such as SELECT VALUE ARRAY, SELECT VALUE EXISTS, /// SELECT VALUE [aggregate], the subquery will be aliased to a new binding for the FROM clause. E.g. consider /// Select(family => family.Children.Select(child => child.Grade)). Since the inner Select corresponds to a subquery, this method would @@ -1508,7 +1527,7 @@ namespace Microsoft.Azure.Cosmos.Linq { SqlQuery query = ExpressionToSql.CreateSubquery(expression, parameters, context); - ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); SqlCollection subqueryCollection = ExpressionToSql.CreateSubquerySqlCollection( query, isMinMaxAvgMethod ? 
SubqueryKind.ArrayScalarExpression : expressionObjKind.Value); @@ -1585,7 +1604,7 @@ namespace Microsoft.Azure.Cosmos.Linq QueryUnderConstruction queryBeforeVisit = context.CurrentQuery; QueryUnderConstruction packagedQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc(), context.CurrentQuery); - packagedQuery.fromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope); + packagedQuery.FromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope); context.CurrentQuery = packagedQuery; if (shouldBeOnNewQuery) context.CurrentSubqueryBinding.ShouldBeOnNewQuery = false; @@ -1663,9 +1682,108 @@ namespace Microsoft.Azure.Cosmos.Linq Binding binding; SqlQuery query = ExpressionToSql.CreateSubquery(lambda.Body, lambda.Parameters, context); SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); - ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); - context.CurrentQuery.fromParameters.Add(binding); + context.CurrentQuery.FromParameters.Add(binding); + } + + return collection; + } + + private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 3) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.GroupBy, 3, arguments.Count)); + } + + // bind the parameters in the value selector to the current input + foreach (ParameterExpression par in Utilities.GetLambda(arguments[2]).Parameters) + { + context.PushParameter(par, 
context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + } + + // First argument is input, second is key selector and third is value selector + LambdaExpression keySelectorLambda = Utilities.GetLambda(arguments[1]); + + // Current GroupBy doesn't allow subquery, so we need to visit non subquery scalar lambda + SqlScalarExpression keySelectorFunc = ExpressionToSql.VisitNonSubqueryScalarLambda(keySelectorLambda, context); + + SqlGroupByClause groupby = SqlGroupByClause.Create(keySelectorFunc); + + context.CurrentQuery = context.CurrentQuery.AddGroupByClause(groupby, context); + + // Create a GroupBy collection and bind the new GroupBy collection to the new parameters created from the key + Collection collection = ExpressionToSql.ConvertToCollection(keySelectorFunc); + collection.isOuter = true; + collection.Name = "GroupBy"; + + ParameterExpression parameterExpression = context.GenerateFreshParameter(returnElementType, keySelectorFunc.ToString(), includeSuffix: false); + Binding binding = new Binding(parameterExpression, collection.inner, isInCollection: false, isInputParameter: true); + + context.CurrentQuery.GroupByParameter = new FromParameterBindings(); + context.CurrentQuery.GroupByParameter.Add(binding); + + // The alias for the key in the value selector lambda is the first parameter of that lambda - we bound it to the parameter expression, which already has substitution + ParameterExpression valueSelectorKeyExpressionAlias = Utilities.GetLambda(arguments[2]).Parameters[0]; + context.GroupByKeySubstitution.AddSubstitution(valueSelectorKeyExpressionAlias, parameterExpression/*Utilities.GetLambda(arguments[1]).Body*/); + + // Translate the body of the value selector lambda + Expression valueSelectorExpression = Utilities.GetLambda(arguments[2]).Body; + + // The value selector function needs to be either a MethodCall or an AnonymousType + switch (valueSelectorExpression.NodeType) + { + case ExpressionType.Constant: + { + ConstantExpression constantExpression = 
(ConstantExpression)valueSelectorExpression; + SqlScalarExpression selectExpression = ExpressionToSql.VisitConstant(constantExpression, context); + + SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(selectExpression); + SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case ExpressionType.Parameter: + { + ParameterExpression parameterValueExpression = (ParameterExpression)valueSelectorExpression; + SqlScalarExpression selectExpression = ExpressionToSql.VisitParameter(parameterValueExpression, context); + + SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(selectExpression); + SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case ExpressionType.Call: + { + // Single Value Selector + MethodCallExpression methodCallExpression = (MethodCallExpression)valueSelectorExpression; + switch (methodCallExpression.Method.Name) + { + case LinqMethods.Max: + case LinqMethods.Min: + case LinqMethods.Average: + case LinqMethods.Count: + case LinqMethods.Sum: + ExpressionToSql.VisitMethodCall(methodCallExpression, context); + break; + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, methodCallExpression.Method.Name)); + } + + break; + } + case ExpressionType.New: + // TODO: Multi Value Selector + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, ExpressionType.New)); + + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, valueSelectorExpression.NodeType)); + } + + foreach (ParameterExpression par in Utilities.GetLambda(arguments[2]).Parameters) + { + context.PopParameter(); } return collection; @@ -1700,7 +1818,7 @@ namespace 
Microsoft.Azure.Cosmos.Linq // it is necessary to trigger the binding because Skip is just a spec with no binding on its own. // This can be done by pushing and popping a temporary parameter. E.g. In SelectMany(f => f.Children.Skip(1)), // it's necessary to consider Skip as Skip(x => x, 1) to bind x to f.Children. Similarly for Top and Limit. - ParameterExpression parameter = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); context.PopParameter(); @@ -1848,16 +1966,21 @@ namespace Microsoft.Azure.Cosmos.Linq SqlScalarExpression aggregateExpression; if (arguments.Count == 1) { - // Need to trigger parameter binding for cases where a aggregate function immediately follows a member access. - ParameterExpression parameter = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + // Need to trigger parameter binding for cases where an aggregate function immediately follows a member access. + ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + + // If there is a groupby, since there is no argument to the aggregate, we consider it to be invoked on the source collection, and not the group by keys aggregateExpression = ExpressionToSql.VisitParameter(parameter, context); context.PopParameter(); } else if (arguments.Count == 2) - { + { LambdaExpression lambda = Utilities.GetLambda(arguments[1]); - aggregateExpression = ExpressionToSql.VisitScalarExpression(lambda, context); + + aggregateExpression = context.CurrentQuery.GroupByParameter != null + ? 
ExpressionToSql.VisitNonSubqueryScalarLambda(lambda, context) + : ExpressionToSql.VisitScalarExpression(lambda, context); } else { @@ -1884,7 +2007,7 @@ namespace Microsoft.Azure.Cosmos.Linq // We consider Distinct as Distinct(v0 => v0) // It's necessary to visit this identity method to replace the parameters names - ParameterExpression parameter = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); LambdaExpression identityLambda = Expression.Lambda(parameter, parameter); SqlScalarExpression sqlfunc = ExpressionToSql.VisitNonSubqueryScalarLambda(identityLambda, context); SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(sqlfunc); diff --git a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs index d2e19046e..25129d7ee 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs @@ -27,7 +27,16 @@ namespace Microsoft.Azure.Cosmos.Linq /// /// Binding for the FROM parameters. /// - public FromParameterBindings fromParameters + public FromParameterBindings FromParameters + { + get; + set; + } + + /// + /// Binding for the Group By clause. + /// + public FromParameterBindings GroupByParameter { get; set; @@ -51,6 +60,7 @@ namespace Microsoft.Azure.Cosmos.Linq private SqlSelectClause selectClause; private SqlWhereClause whereClause; private SqlOrderByClause orderByClause; + private SqlGroupByClause groupByClause; // The specs could be in clauses to reflect the SqlQuery. // However, they are separated to avoid update recreation of the readonly DOMs and lengthy code. @@ -61,7 +71,7 @@ namespace Microsoft.Azure.Cosmos.Linq private Lazy alias; /// - /// Input subquery. + /// Input subquery / query to the left of the current query. 
/// private QueryUnderConstruction inputQuery; @@ -72,7 +82,7 @@ namespace Microsoft.Azure.Cosmos.Linq public QueryUnderConstruction(Func aliasCreatorFunc, QueryUnderConstruction inputQuery) { - this.fromParameters = new FromParameterBindings(); + this.FromParameters = new FromParameterBindings(); this.aliasCreatorFunc = aliasCreatorFunc; this.inputQuery = inputQuery; this.alias = new Lazy(() => aliasCreatorFunc(QueryUnderConstruction.DefaultSubqueryRoot)); @@ -85,22 +95,22 @@ namespace Microsoft.Azure.Cosmos.Linq public void AddBinding(Binding binding) { - this.fromParameters.Add(binding); + this.FromParameters.Add(binding); } public ParameterExpression GetInputParameterInContext(bool isInNewQuery) { - return isInNewQuery ? this.Alias : this.fromParameters.GetInputParameter(); + return isInNewQuery ? this.Alias : this.FromParameters.GetInputParameter(); } /// /// Create a FROM clause from a set of FROM parameter bindings. /// /// The created FROM clause. - private SqlFromClause CreateFrom(SqlCollectionExpression inputCollectionExpression) + private SqlFromClause CreateFromClause(SqlCollectionExpression inputCollectionExpression) { bool first = true; - foreach (Binding paramDef in this.fromParameters.GetBindings()) + foreach (Binding paramDef in this.FromParameters.GetBindings()) { // If input collection expression is provided, the first binding, // which is the input paramter name, should be omitted. 
@@ -147,7 +157,7 @@ namespace Microsoft.Azure.Cosmos.Linq ParameterExpression inputParam = this.inputQuery.Alias; SqlIdentifier identifier = SqlIdentifier.Create(inputParam.Name); SqlAliasedCollectionExpression colExp = SqlAliasedCollectionExpression.Create(collection, identifier); - SqlFromClause fromClause = this.CreateFrom(colExp); + SqlFromClause fromClause = this.CreateFromClause(colExp); return fromClause; } @@ -169,7 +179,7 @@ namespace Microsoft.Azure.Cosmos.Linq } else { - fromClause = this.CreateFrom(inputCollectionExpression: null); + fromClause = this.CreateFromClause(inputCollectionExpression: null); } // Create a SqlSelectClause with the topSpec. @@ -178,7 +188,7 @@ namespace Microsoft.Azure.Cosmos.Linq SqlSelectClause selectClause = this.selectClause; if (selectClause == null) { - string parameterName = this.fromParameters.GetInputParameter().Name; + string parameterName = this.FromParameters.GetInputParameter().Name; SqlScalarExpression parameterExpression = SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterName)); selectClause = this.selectClause = SqlSelectClause.Create(SqlSelectValueSpec.Create(parameterExpression)); } @@ -186,7 +196,7 @@ namespace Microsoft.Azure.Cosmos.Linq SqlOffsetLimitClause offsetLimitClause = (this.offsetSpec != null) ? SqlOffsetLimitClause.Create(this.offsetSpec, this.limitSpec ?? 
SqlLimitSpec.Create(SqlNumberLiteral.Create(int.MaxValue))) : offsetLimitClause = default(SqlOffsetLimitClause); - SqlQuery result = SqlQuery.Create(selectClause, fromClause, this.whereClause, /*GroupBy*/ null, this.orderByClause, offsetLimitClause); + SqlQuery result = SqlQuery.Create(selectClause, fromClause, this.whereClause, this.groupByClause, this.orderByClause, offsetLimitClause); return result; } @@ -198,7 +208,7 @@ namespace Microsoft.Azure.Cosmos.Linq public QueryUnderConstruction PackageQuery(HashSet inScope) { QueryUnderConstruction result = new QueryUnderConstruction(this.aliasCreatorFunc); - result.fromParameters.SetInputParameter(typeof(object), this.Alias.Name, inScope); + result.FromParameters.SetInputParameter(typeof(object), this.Alias.Name, inScope); result.inputQuery = this; return result; } @@ -214,13 +224,14 @@ namespace Microsoft.Azure.Cosmos.Linq // 1. Select clause appears after Distinct // 2. There are any operations after Take that is not a pure Select. // 3. There are nested Select, Where or OrderBy + // 4. Group by clause appears after Select QueryUnderConstruction parentQuery = null; QueryUnderConstruction flattenQuery = null; bool seenSelect = false; bool seenAnyNonSelectOp = false; for (QueryUnderConstruction query = this; query != null; query = query.inputQuery) { - foreach (Binding binding in query.fromParameters.GetBindings()) + foreach (Binding binding in query.FromParameters.GetBindings()) { if ((binding.ParameterDefinition != null) && (binding.ParameterDefinition is SqlSubqueryCollection)) { @@ -232,8 +243,15 @@ namespace Microsoft.Azure.Cosmos.Linq // In Select -> SelectMany cases, fromParameter substitution is not yet supported . // Therefore these are un-flattenable. 
if (query.inputQuery != null && - (query.fromParameters.GetBindings().First().Parameter.Name == query.inputQuery.Alias.Name) && - query.fromParameters.GetBindings().Any(b => b.ParameterDefinition != null)) + (query.FromParameters.GetBindings().First().Parameter.Name == query.inputQuery.Alias.Name) && + query.FromParameters.GetBindings().Any(b => b.ParameterDefinition != null)) + { + flattenQuery = this; + break; + } + + // In case of Select -> Group by cases, the Select query should not be flattened and kept as a subquery + if ((query.inputQuery?.selectClause != null) && (query.groupByClause != null)) { flattenQuery = this; break; @@ -253,10 +271,12 @@ namespace Microsoft.Azure.Cosmos.Linq seenAnyNonSelectOp |= (query.whereClause != null) || (query.orderByClause != null) || + (query.groupByClause != null) || (query.topSpec != null) || (query.offsetSpec != null) || - query.fromParameters.GetBindings().Any(b => b.ParameterDefinition != null) || - ((query.selectClause != null) && (query.selectClause.HasDistinct || this.HasSelectAggregate())); + query.FromParameters.GetBindings().Any(b => b.ParameterDefinition != null) || + ((query.selectClause != null) && (query.selectClause.HasDistinct || + this.HasSelectAggregate())); parentQuery = query; } @@ -272,7 +292,7 @@ namespace Microsoft.Azure.Cosmos.Linq private QueryUnderConstruction Flatten() { // SELECT fo(y) FROM y IN (SELECT fi(x) FROM x WHERE gi(x)) WHERE go(y) - // is translated by substituting fi(x) for y in the outer query + // is translated by replacing y with fi(x) in the outer query // producing // SELECT fo(fi(x)) FROM x WHERE gi(x) AND (go(fi(x)) if (this.inputQuery == null) @@ -281,7 +301,8 @@ namespace Microsoft.Azure.Cosmos.Linq if (this.selectClause == null) { // If selectClause doesn't exists, use SELECT v0 where v0 is the input parameter, instead of SELECT *. 
- string parameterName = this.fromParameters.GetInputParameter().Name; + // If there is a groupby clause, the input parameter comes from the groupBy binding instead of the from clause binding + string parameterName = (this.GroupByParameter ?? this.FromParameters).GetInputParameter().Name; SqlScalarExpression parameterExpression = SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterName)); this.selectClause = SqlSelectClause.Create(SqlSelectValueSpec.Create(parameterExpression)); } @@ -302,12 +323,12 @@ namespace Microsoft.Azure.Cosmos.Linq // That is because if it has been binded before, it has global scope and should not be replaced. string paramName = null; HashSet inputQueryParams = new HashSet(); - foreach (Binding binding in this.inputQuery.fromParameters.GetBindings()) + foreach (Binding binding in this.inputQuery.FromParameters.GetBindings()) { inputQueryParams.Add(binding.Parameter.Name); } - foreach (Binding binding in this.fromParameters.GetBindings()) + foreach (Binding binding in this.FromParameters.GetBindings()) { if (binding.ParameterDefinition == null || inputQueryParams.Contains(binding.Parameter.Name)) { @@ -316,11 +337,14 @@ namespace Microsoft.Azure.Cosmos.Linq } SqlIdentifier replacement = SqlIdentifier.Create(paramName); - SqlSelectClause composedSelect = this.Substitute(inputSelect, inputSelect.TopSpec ?? this.topSpec, replacement, this.selectClause); + SqlSelectClause composedSelect; + + composedSelect = this.Substitute(inputSelect, inputSelect.TopSpec ?? 
this.topSpec, replacement, this.selectClause); SqlWhereClause composedWhere = this.Substitute(inputSelect.SelectSpec, replacement, this.whereClause); SqlOrderByClause composedOrderBy = this.Substitute(inputSelect.SelectSpec, replacement, this.orderByClause); + SqlGroupByClause composedGroupBy = this.Substitute(inputSelect.SelectSpec, replacement, this.groupByClause); SqlWhereClause and = QueryUnderConstruction.CombineWithConjunction(inputwhere, composedWhere); - FromParameterBindings fromParams = QueryUnderConstruction.CombineInputParameters(flatInput.fromParameters, this.fromParameters); + FromParameterBindings fromParams = QueryUnderConstruction.CombineInputParameters(flatInput.FromParameters, this.FromParameters); SqlOffsetSpec offsetSpec; SqlLimitSpec limitSpec; if (flatInput.offsetSpec != null) @@ -338,8 +362,9 @@ namespace Microsoft.Azure.Cosmos.Linq selectClause = composedSelect, whereClause = and, inputQuery = null, - fromParameters = flatInput.fromParameters, + FromParameters = flatInput.FromParameters, orderByClause = composedOrderBy ?? this.inputQuery.orderByClause, + groupByClause = composedGroupBy ?? this.inputQuery.groupByClause, offsetSpec = offsetSpec, limitSpec = limitSpec, alias = new Lazy(() => this.Alias) @@ -349,25 +374,25 @@ namespace Microsoft.Azure.Cosmos.Linq private SqlSelectClause Substitute(SqlSelectClause inputSelectClause, SqlTopSpec topSpec, SqlIdentifier inputParam, SqlSelectClause selectClause) { - SqlSelectSpec selectSpec = inputSelectClause.SelectSpec; + SqlSelectSpec inputSelectSpec = inputSelectClause.SelectSpec; if (selectClause == null) { - return selectSpec != null ? SqlSelectClause.Create(selectSpec, topSpec, inputSelectClause.HasDistinct) : null; + return inputSelectSpec != null ? 
SqlSelectClause.Create(inputSelectSpec, topSpec, inputSelectClause.HasDistinct) : null; } - if (selectSpec is SqlSelectStarSpec) + if (inputSelectSpec is SqlSelectStarSpec) { - return SqlSelectClause.Create(selectSpec, topSpec, inputSelectClause.HasDistinct); + return SqlSelectClause.Create(inputSelectSpec, topSpec, inputSelectClause.HasDistinct); } - SqlSelectValueSpec selValue = selectSpec as SqlSelectValueSpec; + SqlSelectValueSpec selValue = inputSelectSpec as SqlSelectValueSpec; if (selValue != null) { SqlSelectSpec intoSpec = selectClause.SelectSpec; if (intoSpec is SqlSelectStarSpec) { - return SqlSelectClause.Create(selectSpec, topSpec, selectClause.HasDistinct || inputSelectClause.HasDistinct); + return SqlSelectClause.Create(inputSelectSpec, topSpec, selectClause.HasDistinct || inputSelectClause.HasDistinct); } SqlSelectValueSpec intoSelValue = intoSpec as SqlSelectValueSpec; @@ -381,7 +406,7 @@ namespace Microsoft.Azure.Cosmos.Linq throw new DocumentQueryException("Unexpected SQL select clause type: " + intoSpec.GetType()); } - throw new DocumentQueryException("Unexpected SQL select clause type: " + selectSpec.GetType()); + throw new DocumentQueryException("Unexpected SQL select clause type: " + inputSelectSpec.GetType()); } private SqlWhereClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam, SqlWhereClause whereClause) @@ -440,6 +465,30 @@ namespace Microsoft.Azure.Cosmos.Linq throw new DocumentQueryException("Unexpected SQL select clause type: " + spec.GetType()); } + private SqlGroupByClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam, SqlGroupByClause groupByClause) + { + if (groupByClause == null) + { + return null; + } + + SqlSelectValueSpec selectValueSpec = spec as SqlSelectValueSpec; + if (selectValueSpec != null) + { + SqlScalarExpression replaced = selectValueSpec.Expression; + SqlScalarExpression[] substitutedItems = new SqlScalarExpression[groupByClause.Expressions.Length]; + for (int i = 0; i < 
substitutedItems.Length; ++i) + { + SqlScalarExpression substituted = SqlExpressionManipulation.Substitute(replaced, inputParam, groupByClause.Expressions[i]); + substitutedItems[i] = substituted; + } + SqlGroupByClause result = SqlGroupByClause.Create(substitutedItems); + return result; + } + + throw new DocumentQueryException("Unexpected SQL select clause type: " + spec.GetType()); + } + /// /// Determine if the current method call should create a new QueryUnderConstruction node or not. /// @@ -449,10 +498,14 @@ namespace Microsoft.Azure.Cosmos.Linq public bool ShouldBeOnNewQuery(string methodName, int argumentCount) { // In the LINQ provider perspective, a SQL query (without subquery) the order of the execution of the operations is: - // Join -> Where -> Order By -> Aggregates/Distinct/Select -> Top/Offset Limit + // Join -> Where -> Order By -> Aggregates/Distinct/Select -> Top/Offset Limit + // | | + // |-> Group By->| // // The order for the corresponding LINQ operations is: - // SelectMany -> Where -> OrderBy -> Aggregates/Distinct/Select -> Skip/Take + // SelectMany -> Where -> OrderBy -> Aggregates/Distinct/Select -> Skip/Take + // | | + // |-> Group By->| // // In general, if an operation Op1 is being visited and the current query already has Op0 which // appear not before Op1 in the execution order, then this Op1 needs to be in a new query. This ensures @@ -495,7 +548,7 @@ namespace Microsoft.Azure.Cosmos.Linq break; case LinqMethods.Where: - // Where expression parameter needs to be substitued if necessary so + // Where expression parameter needs to be substituted if necessary so // It is not needed in Select distinct because the Select distinct would have the necessary parameter name adjustment. 
case LinqMethods.Any: case LinqMethods.OrderBy: @@ -506,7 +559,16 @@ namespace Microsoft.Azure.Cosmos.Linq // New query is needed when there is already a Take or a non-distinct Select shouldPackage = (this.topSpec != null) || (this.offsetSpec != null) || - (this.selectClause != null && !this.selectClause.HasDistinct); + (this.selectClause != null && !this.selectClause.HasDistinct) || + (this.groupByClause != null); + break; + + case LinqMethods.GroupBy: + // New query is needed when there is already a Take or a Select or a Group by clause + shouldPackage = (this.topSpec != null) || + (this.offsetSpec != null) || + (this.selectClause != null) || + (this.groupByClause != null); break; case LinqMethods.Skip: @@ -592,6 +654,16 @@ namespace Microsoft.Azure.Cosmos.Linq return context.CurrentQuery; } + public QueryUnderConstruction AddGroupByClause(SqlGroupByClause groupBy, TranslationContext context) + { + QueryUnderConstruction result = context.PackageCurrentQueryIfNeccessary(); + + result.groupByClause = groupBy; + foreach (Binding binding in context.CurrentSubqueryBinding.TakeBindings()) result.AddBinding(binding); + + return result; + } + public QueryUnderConstruction AddOffsetSpec(SqlOffsetSpec offsetSpec, TranslationContext context) { QueryUnderConstruction result = context.PackageCurrentQueryIfNeccessary(); @@ -826,6 +898,7 @@ namespace Microsoft.Azure.Cosmos.Linq private bool HasSelectAggregate() { string functionCallName = ((this.selectClause?.SelectSpec as SqlSelectValueSpec)?.Expression as SqlFunctionCallScalarExpression)?.Name.Value; + return (functionCallName != null) && ((functionCallName == SqlFunctionCallScalarExpression.Names.Max) || (functionCallName == SqlFunctionCallScalarExpression.Names.Min) || diff --git a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs index 82150b6d4..606ade1d8 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs +++ 
b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs @@ -43,6 +43,16 @@ namespace Microsoft.Azure.Cosmos.Linq /// public IDictionary Parameters; + /// + /// Dictionary for group by key substitution. + /// + public ParameterSubstitution GroupByKeySubstitution; + + /// + /// Boolean to indicate a GroupBy expression is the last expression to finish processing. + /// + public bool LastExpressionIsGroupBy; + /// /// If the FROM clause uses a parameter name, it will be substituted for the parameter used in /// the lambda expressions for the WHERE and SELECT clauses. @@ -86,6 +96,7 @@ namespace Microsoft.Azure.Cosmos.Linq this.subqueryBindingStack = new Stack(); this.Parameters = parameters; this.clientOperation = null; + this.LastExpressionIsGroupBy = false; if (linqSerializerOptionsInternal?.CustomCosmosLinqSerializer != null) { @@ -104,6 +115,8 @@ namespace Microsoft.Azure.Cosmos.Linq this.CosmosLinqSerializer = TranslationContext.DefaultLinqSerializer; this.MemberNames = TranslationContext.DefaultMemberNames; } + + this.GroupByKeySubstitution = new ParameterSubstitution(); } public ScalarOperationKind ClientOperation => this.clientOperation ??
ScalarOperationKind.None; @@ -120,17 +133,25 @@ namespace Microsoft.Azure.Cosmos.Linq public Expression LookupSubstitution(ParameterExpression parameter) { + if (this.CurrentQuery.GroupByParameter != null) + { + Expression groupBySubstitutionExpression = this.GroupByKeySubstitution.Lookup(parameter); + if (groupBySubstitutionExpression != null) + { + return groupBySubstitutionExpression; + } + } return this.substitutions.Lookup(parameter); } - public ParameterExpression GenFreshParameter(Type parameterType, string baseParameterName) + public ParameterExpression GenerateFreshParameter(Type parameterType, string baseParameterName, bool includeSuffix = true) { - return Utilities.NewParameter(baseParameterName, parameterType, this.InScope); + return Utilities.NewParameter(baseParameterName, parameterType, this.InScope, includeSuffix); } public Func GetGenFreshParameterFunc() { - return (paramName) => this.GenFreshParameter(typeof(object), paramName); + return (paramName) => this.GenerateFreshParameter(typeof(object), paramName); } /// @@ -211,12 +232,12 @@ namespace Microsoft.Azure.Cosmos.Linq throw new ArgumentNullException("collection"); } - this.collectionStack.Add(collection); + if (this.CurrentQuery.GroupByParameter == null) this.collectionStack.Add(collection); } public void PopCollection() { - this.collectionStack.RemoveAt(this.collectionStack.Count - 1); + if (this.CurrentQuery.GroupByParameter == null) this.collectionStack.RemoveAt(this.collectionStack.Count - 1); } /// @@ -226,7 +247,7 @@ namespace Microsoft.Azure.Cosmos.Linq /// Suggested name for the input parameter. 
public ParameterExpression SetInputParameter(Type type, string name) { - return this.CurrentQuery.fromParameters.SetInputParameter(type, name, this.InScope); + return this.CurrentQuery.FromParameters.SetInputParameter(type, name, this.InScope); } /// @@ -237,7 +258,7 @@ namespace Microsoft.Azure.Cosmos.Linq public void SetFromParameter(ParameterExpression parameter, SqlCollection collection) { Binding binding = new Binding(parameter, collection, isInCollection: true); - this.CurrentQuery.fromParameters.Add(binding); + this.CurrentQuery.FromParameters.Add(binding); } /// diff --git a/Microsoft.Azure.Cosmos/src/Linq/Utilities.cs b/Microsoft.Azure.Cosmos/src/Linq/Utilities.cs index f7870d208..ca39671f7 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/Utilities.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/Utilities.cs @@ -45,15 +45,16 @@ namespace Microsoft.Azure.Cosmos.Linq /// Prefix for the parameter name. /// Parameter type. /// Names to avoid. + /// Enable suffix to parameter name /// The new parameter. - public static ParameterExpression NewParameter(string prefix, Type type, HashSet inScope) + public static ParameterExpression NewParameter(string prefix, Type type, HashSet inScope, bool includeSuffix = true) { int suffix = 0; while (true) { - string name = prefix + suffix.ToString(CultureInfo.InvariantCulture); + string name = prefix + (includeSuffix ? 
suffix.ToString(CultureInfo.InvariantCulture) : string.Empty); ParameterExpression param = Expression.Parameter(type, name); - if (!inScope.Any(p => p.Name.Equals(name))) + if (!inScope.Any(p => p.Name.Equals(name)) || !includeSuffix) { inScope.Add(param); return param; diff --git a/Microsoft.Azure.Cosmos/src/SqlObjects/Visitors/SqlObjectTextSerializer.cs b/Microsoft.Azure.Cosmos/src/SqlObjects/Visitors/SqlObjectTextSerializer.cs index a29435744..150be1427 100644 --- a/Microsoft.Azure.Cosmos/src/SqlObjects/Visitors/SqlObjectTextSerializer.cs +++ b/Microsoft.Azure.Cosmos/src/SqlObjects/Visitors/SqlObjectTextSerializer.cs @@ -495,6 +495,7 @@ namespace Microsoft.Azure.Cosmos.SqlObjects.Visitors if (sqlQuery.GroupByClause != null) { + this.WriteDelimiter(string.Empty); sqlQuery.GroupByClause.Accept(this); this.writer.Write(" "); } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqGeneralBaselineTests.TestGroupByTranslation.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqGeneralBaselineTests.TestGroupByTranslation.xml new file mode 100644 index 000000000..f5256c787 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqGeneralBaselineTests.TestGroupByTranslation.xml @@ -0,0 +1,449 @@ + + + + + k, (key, values) => key)]]> + + + + + + + + + k.Id, (key, values) => key)]]> + + + + + + + + + k.Id, (stringField, values) => stringField)]]> + + + + + + + + + k.Id, (key, values) => values.Min(value => value.Int))]]> + + + + + + + + + k.Id, (key, values) => values.Max(value => value.Int))]]> + + + + + + + + + k.Id, (key, values) => values.Count())]]> + + + + + + + + + k.Id, (key, values) => values.Average(value => value.Int))]]> + + + + + + + + + k.Int, (key, values) => values.Min())]]> + + + + + + + + + + k.Int, (key, values) => values.Max())]]> + + + + + + + + + + k.Int, (key, values) => "string")]]> + + + 
+ + + + + + k.Id)]]> + + + + + + + + + + k.Int, k2 => k2.Int, (key, values) => "string")]]> + + + + + + + + + + k.Id, (key, values) => values.Select(value => value.Int))]]> + + + + + + + + + + k.Id, (key, values) => values.OrderBy(f => f.FamilyId))]]> + + + + + + + + + + k.FamilyId, (key, values) => new AnonymousType(familyId = key, familyIdCount = values.Count()))]]> + + + + + + + + + + x.Id).GroupBy(k => k, (key, values) => key)]]> + + + + + + + + + new AnonymousType(Id1 = x.Id, family1 = x.FamilyId, childrenN1 = x.Children)).GroupBy(k => k.family1, (key, values) => key)]]> + + + + + + + + + x.Children).GroupBy(k => k.Grade, (key, values) => key)]]> + + + + + + + + + f.Children).Where(c => (c.Pets.Count() > 0)).SelectMany(c => c.Pets.Select(p => p.GivenName)).GroupBy(k => k, (key, values) => key)]]> + + + 0)) AS r0 + GROUP BY r0 +]]> + + + + + + k.Id, (key, values) => key)]]> + + + + + + + + + + k.Id, (key, values) => key)]]> + + + + + + + + + + k.Id, (key, values) => key)]]> + + + + + + + + + + (x.Id != "a")).GroupBy(k => k.Id, (key, values) => key)]]> + + + + + + + + + x.Int).GroupBy(k => k.Id, (key, values) => key)]]> + + + + + + + + + + x.Id).GroupBy(k => k.Id, (key, values) => key)]]> + + + + + + + + + (x.Id != "a")).OrderBy(x => x.Id).GroupBy(k => k.Id, (key, values) => key)]]> + + + + + + + + + (x.Id != "a")).Where(x => (x.Children.Min(y => y.Grade) > 10)).GroupBy(k => k.Id, (key, values) => key)]]> + + + 10)) + GROUP BY root["id"] +]]> + + + + + + + k.Id, (key, values) => key).Select(x => x)]]> + + + + + + + + + + k.Id, (key, values) => key).Skip(10)]]> + + + + + + + + + + k.Id, (key, values) => key).Take(10)]]> + + + + + + + + + + k.Id, (key, values) => key).Skip(10).Take(10)]]> + + + + + + + + + + k.Id, (key, values) => key).Where(x => (x == "a"))]]> + + + + + + + + + + k.Id, (key, values) => key).OrderBy(x => x)]]> + + + + + + + + + + k.Id, (key, values) => key).OrderByDescending(x => x)]]> + + + + + + + + + + k.Id, (key, values) => key).Where(x => (x 
== "a")).Skip(10).Take(10)]]> + + + + + + + + + + k.Id, (key, values) => key).GroupBy(k => k, (key, values) => key)]]> + + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqGeneralBaselineTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqGeneralBaselineTests.cs index e76b2827b..21172812f 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqGeneralBaselineTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqGeneralBaselineTests.cs @@ -697,6 +697,155 @@ namespace Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests this.ExecuteTestSuite(inputs); } + [TestMethod] + public void TestGroupByTranslation() + { + List inputs = new List(); + inputs.Add(new LinqTestInput("GroupBy Single Value Select Key", b => getQuery(b).GroupBy(k => k /*keySelector*/, + (key, values) => key /*return the group by key */))); + inputs.Add(new LinqTestInput("GroupBy Single Value Select Key", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => key /*return the group by key */))); + inputs.Add(new LinqTestInput("GroupBy Single Value Select Key Alias", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (stringField, values) => stringField /*return the group by key */))); + + + inputs.Add(new LinqTestInput("GroupBy Single Value With Min", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => values.Min(value => value.Int) /*return the Min of each group */))); + inputs.Add(new LinqTestInput("GroupBy Single Value With Max", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => values.Max(value => value.Int) /*return the Max of each group */))); + inputs.Add(new LinqTestInput("GroupBy Single Value With Count", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => values.Count() /*return the Count of each group */))); + inputs.Add(new 
LinqTestInput("GroupBy Single Value With Average", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => values.Average(value => value.Int) /*return the Average of each group */))); + + // Negative cases + + // The translation is correct (SELECT VALUE MIN(root) FROM root GROUP BY root["Number"]) + // but the behavior between LINQ and SQL is different + // In Linq, it requires the object to have comparer traits, whereas in CosmosDB, we will return null + inputs.Add(new LinqTestInput("GroupBy Single Value With Min", b => getQuery(b).GroupBy(k => k.Int /*keySelector*/, + (key, values) => values.Min() /*return the Min of each group */))); + inputs.Add(new LinqTestInput("GroupBy Single Value With Max", b => getQuery(b).GroupBy(k => k.Int /*keySelector*/, + (key, values) => values.Max() /*return the Max of each group */))); + + // Unsupported node type + inputs.Add(new LinqTestInput("GroupBy Single Value With Min", b => getQuery(b).GroupBy(k => k.Int /*keySelector*/, + (key, values) => "string" /* Unsupported Nodetype*/ ))); + + // Incorrect number of arguments + inputs.Add(new LinqTestInput("GroupBy Single Value With Count", b => getQuery(b).GroupBy(k => k.Id))); + inputs.Add(new LinqTestInput("GroupBy Single Value With Min", b => getQuery(b).GroupBy( + k => k.Int, + k2 => k2.Int, + (key, values) => "string" /* Unsupported Nodetype*/ ))); + + // Non-aggregate method calls + inputs.Add(new LinqTestInput("GroupBy Single Value With Count", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => values.Select(value => value.Int) /*Not an aggregate*/))); + inputs.Add(new LinqTestInput("GroupBy Single Value With Count", b => getQuery(b).GroupBy(k => k.Id /*keySelector*/, + (key, values) => values.OrderBy(f => f.FamilyId) /*Not an aggregate*/))); + + // Currently unsupported case + inputs.Add(new LinqTestInput("GroupBy Single Value With Min", b => getQuery(b).GroupBy(k => k.FamilyId /*keySelector*/, + (key, values) => new { familyId = key,
familyIdCount = values.Count() } /*multi-value select */))); + + // Other methods followed by GroupBy + + inputs.Add(new LinqTestInput("Select + GroupBy", b => getQuery(b) + .Select(x => x.Id) + .GroupBy(k => k /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("Select + GroupBy 2", b => getQuery(b) + .Select(x => new { Id1 = x.Id, family1 = x.FamilyId, childrenN1 = x.Children }) + .GroupBy(k => k.family1 /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("SelectMany + GroupBy", b => getQuery(b) + .SelectMany(x => x.Children) + .GroupBy(k => k.Grade /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("SelectMany + GroupBy 2", b => getQuery(b) + .SelectMany(f => f.Children) + .Where(c => c.Pets.Count() > 0) + .SelectMany(c => c.Pets.Select(p => p.GivenName)) + .GroupBy(k => k /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("Skip + GroupBy", b => getQuery(b) + .Skip(10) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("Take + GroupBy", b => getQuery(b) + .Take(10) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("Skip + Take + GroupBy", b => getQuery(b) + .Skip(10).Take(10) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("Filter + GroupBy", b => getQuery(b) + .Where(x => x.Id != "a") + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + // should this become a subquery with order by then group by? 
+ inputs.Add(new LinqTestInput("OrderBy + GroupBy", b => getQuery(b) + .OrderBy(x => x.Int) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("OrderBy Descending + GroupBy", b => getQuery(b) + .OrderByDescending(x => x.Id) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + inputs.Add(new LinqTestInput("Combination + GroupBy", b => getQuery(b) + .Where(x => x.Id != "a") + .OrderBy(x => x.Id) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + // The result for this is not correct yet - the select clause is wrong + inputs.Add(new LinqTestInput("Combination 2 + GroupBy", b => getQuery(b) + .Where(x => x.Id != "a") + .Where(x => x.Children.Min(y => y.Grade) > 10) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */))); + + // GroupBy followed by other methods + inputs.Add(new LinqTestInput("GroupBy + Select", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .Select(x => x))); + + //We should support skip take + inputs.Add(new LinqTestInput("GroupBy + Skip", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .Skip(10))); + + inputs.Add(new LinqTestInput("GroupBy + Take", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .Take(10))); + + inputs.Add(new LinqTestInput("GroupBy + Skip + Take", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .Skip(10).Take(10))); + + inputs.Add(new LinqTestInput("GroupBy + Filter", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .Where(x => x == "a"))); + + inputs.Add(new LinqTestInput("GroupBy + OrderBy", b => getQuery(b) + .GroupBy(k => k.Id 
/*keySelector*/, (key, values) => key /*return the group by key */) + .OrderBy(x => x))); + + inputs.Add(new LinqTestInput("GroupBy + OrderBy Descending", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .OrderByDescending(x => x))); + + inputs.Add(new LinqTestInput("GroupBy + Combination", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .Where(x => x == "a").Skip(10).Take(10))); + + inputs.Add(new LinqTestInput("GroupBy + GroupBy", b => getQuery(b) + .GroupBy(k => k.Id /*keySelector*/, (key, values) => key /*return the group by key */) + .GroupBy(k => k /*keySelector*/, (key, values) => key /*return the group by key */))); + + this.ExecuteTestSuite(inputs); + } + [TestMethod] [Ignore] public void DebuggingTest() diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs index f8fea9959..ab5155cbf 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs @@ -1,574 +1,574 @@ -//----------------------------------------------------------------------- -// -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// -//----------------------------------------------------------------------- -namespace Microsoft.Azure.Cosmos.Services.Management.Tests -{ - using System; - using System.Collections; - using System.Collections.Generic; - using System.Collections.ObjectModel; - using System.Diagnostics; - using System.IO; - using System.Linq; - using System.Linq.Expressions; - using System.Reflection; - using System.Runtime.CompilerServices; - using System.Text; - using System.Text.Json.Serialization; - using System.Text.Json; - using System.Text.RegularExpressions; - using System.Xml; - using global::Azure.Core.Serialization; - using Microsoft.Azure.Cosmos.Services.Management.Tests.BaselineTest; - using Microsoft.Azure.Documents; - using Microsoft.VisualStudio.TestTools.UnitTesting; - using Newtonsoft.Json; +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Collections.ObjectModel; + using System.Diagnostics; + using System.IO; + using System.Linq; + using System.Linq.Expressions; + using System.Reflection; + using System.Runtime.CompilerServices; + using System.Text; + using System.Text.Json.Serialization; + using System.Text.Json; + using System.Text.RegularExpressions; + using System.Xml; + using global::Azure.Core.Serialization; + using Microsoft.Azure.Cosmos.Services.Management.Tests.BaselineTest; + using Microsoft.Azure.Documents; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Newtonsoft.Json; using Newtonsoft.Json.Linq; - internal class LinqTestsCommon - { - /// - /// Compare two list of anonymous objects - /// - /// - /// - /// - private static bool CompareListOfAnonymousType(List queryResults, List dataResults) - { - return 
queryResults.SequenceEqual(dataResults); - } - - /// - /// Compare 2 IEnumerable which may contain IEnumerable themselves. - /// - /// The query results from Cosmos DB - /// The query results from actual data - /// True if the two IEbumerable equal - private static bool NestedListsSequenceEqual(IEnumerable queryResults, IEnumerable dataResults) - { - IEnumerator queryIter, dataIter; - for (queryIter = queryResults.GetEnumerator(), dataIter = dataResults.GetEnumerator(); - queryIter.MoveNext() && dataIter.MoveNext();) - { - IEnumerable queryEnumerable = queryIter.Current as IEnumerable; - IEnumerable dataEnumerable = dataIter.Current as IEnumerable; - if (queryEnumerable == null && dataEnumerable == null) - { - if (!queryIter.Current.Equals(dataIter.Current)) return false; - - } - - else if (queryEnumerable == null || dataEnumerable == null) - { - return false; - } - - else - { - if (!LinqTestsCommon.NestedListsSequenceEqual(queryEnumerable, dataEnumerable)) return false; - } - } - - return !(queryIter.MoveNext() || dataIter.MoveNext()); - } - - /// - /// Compare the list of results from CosmosDB query and the list of results from LinQ query on the original data - /// Similar to Collections.SequenceEqual with the assumption that these lists are non-empty - /// - /// A list representing the query restuls from CosmosDB - /// A list representing the linQ query results from the original data - /// true if the two - private static bool CompareListOfArrays(List queryResults, List dataResults) - { - if (NestedListsSequenceEqual(queryResults, dataResults)) return true; - - bool resultMatched = true; - - // dataResults contains type ConcatIterator whereas queryResults may contain IEnumerable - // therefore it's simpler to just cast them into List> manually for simplify the verification - List> l1 = new List>(); - foreach (IEnumerable list in dataResults) - { - List l = new List(); - IEnumerator iterator = list.GetEnumerator(); - while (iterator.MoveNext()) - { - 
l.Add(iterator.Current); - } - - l1.Add(l); - } - - List> l2 = new List>(); - foreach (IEnumerable list in queryResults) - { - List l = new List(); - IEnumerator iterator = list.GetEnumerator(); - while (iterator.MoveNext()) - { - l.Add(iterator.Current); - } - - l2.Add(l); - } - - foreach (IEnumerable list in l1) - { - if (!l2.Any(a => a.SequenceEqual(list))) - { - resultMatched = false; - return false; - } - } - - foreach (IEnumerable list in l2) - { - if (!l1.Any(a => a.SequenceEqual(list))) - { - resultMatched = false; - break; - } - } - - return resultMatched; - } - - private static bool IsNumber(dynamic value) - { - return value is sbyte - || value is byte - || value is short - || value is ushort - || value is int - || value is uint - || value is long - || value is ulong - || value is float - || value is double - || value is decimal; - } - - public static Boolean IsAnonymousType(Type type) - { - Boolean hasCompilerGeneratedAttribute = type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), false).Count() > 0; - Boolean nameContainsAnonymousType = type.FullName.Contains("AnonymousType"); - Boolean isAnonymousType = hasCompilerGeneratedAttribute && nameContainsAnonymousType; - - return isAnonymousType; - } - - /// - /// Gets the results of CosmosDB query and the results of LINQ query on the original data - /// - /// - /// - public static (List queryResults, List dataResults) GetResults(IQueryable queryResults, IQueryable dataResults) - { - // execution validation - IEnumerator queryEnumerator = queryResults.GetEnumerator(); - List queryResultsList = new List(); - while (queryEnumerator.MoveNext()) - { - queryResultsList.Add(queryEnumerator.Current); - } - - List dataResultsList = dataResults?.Cast()?.ToList(); - - return (queryResultsList, dataResultsList); - } - - /// - /// Validates the results of CosmosDB query and the results of LINQ query on the original data - /// Using Assert, will fail the unit test if the two results list are not SequenceEqual - 
/// - /// - /// - private static void ValidateResults(List queryResultsList, List dataResultsList) - { - bool resultMatched = true; - string actualStr = null; - string expectedStr = null; - if (dataResultsList.Count == 0 || queryResultsList.Count == 0) - { - resultMatched &= dataResultsList.Count == queryResultsList.Count; - } - else - { - dynamic firstElem = dataResultsList.FirstOrDefault(); - if (firstElem is IEnumerable) - { - resultMatched &= CompareListOfArrays(queryResultsList, dataResultsList); - } - else if (LinqTestsCommon.IsAnonymousType(firstElem.GetType())) - { - resultMatched &= CompareListOfAnonymousType(queryResultsList, dataResultsList); - } - else if (LinqTestsCommon.IsNumber(firstElem)) - { - const double Epsilon = 1E-6; - Type dataType = firstElem.GetType(); - List dataSortedList = dataResultsList.OrderBy(x => x).ToList(); - List querySortedList = queryResultsList.OrderBy(x => x).ToList(); - if (dataSortedList.Count != querySortedList.Count) - { - resultMatched = false; - } - else - { - for (int i = 0; i < dataSortedList.Count; ++i) - { - if (Math.Abs(dataSortedList[i] - (dynamic)querySortedList[i]) > (dynamic)Convert.ChangeType(Epsilon, dataType)) - { - resultMatched = false; - break; - } - } - } - - if (!resultMatched) - { - actualStr = JsonConvert.SerializeObject(querySortedList); - expectedStr = JsonConvert.SerializeObject(dataSortedList); - } - } - else - { - List dataNotQuery = dataResultsList.Except(queryResultsList).ToList(); - List queryNotData = queryResultsList.Except(dataResultsList).ToList(); - resultMatched &= !dataNotQuery.Any() && !queryNotData.Any(); - } - } - - string assertMsg = string.Empty; - if (!resultMatched) - { - actualStr ??= JsonConvert.SerializeObject(queryResultsList); - expectedStr ??= JsonConvert.SerializeObject(dataResultsList); - - resultMatched |= actualStr.Equals(expectedStr); - if (!resultMatched) - { - assertMsg = $"Expected: {expectedStr}, Actual: {actualStr}, RandomSeed: {LinqTestInput.RandomSeed}"; - } - } 
- - Assert.IsTrue(resultMatched, assertMsg); - } - - /// - /// Generate a random string containing alphabetical characters - /// - /// - /// - /// a random string - public static string RandomString(Random random, int length) - { - const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefghijklmnopqrstuvwxyz "; - return new string(Enumerable.Repeat(chars, length).Select(s => s[random.Next(s.Length)]).ToArray()); - } - - /// - /// Generate a random DateTime object from a DateTime, - /// with the variance of the time span between the provided DateTime to the current time - /// - /// - /// - /// - public static DateTime RandomDateTime(Random random, DateTime midDateTime) - { - TimeSpan timeSpan = DateTime.Now - midDateTime; - TimeSpan newSpan = new TimeSpan(0, random.Next(0, (int)timeSpan.TotalMinutes * 2) - (int)timeSpan.TotalMinutes, 0); - DateTime newDate = midDateTime + newSpan; - return newDate; - } - - /// - /// Generate test data for most LINQ tests - /// - /// the object type - /// the lamda to create an instance of test data - /// number of test data to be created - /// the target container - /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable - public static Func> GenerateTestCosmosData(Func func, int count, Container container) - { - List data = new List(); - int seed = DateTime.Now.Millisecond; - Random random = new Random(seed); - Debug.WriteLine("Random seed: {0}", seed); - LinqTestInput.RandomSeed = seed; - for (int i = 0; i < count; ++i) - { - data.Add(func(random)); - } - - foreach (T obj in data) - { - ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; - } - - FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; - QueryRequestOptions requestOptions = new QueryRequestOptions(); - - IOrderedQueryable query = 
container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); - - // To cover both query against backend and queries on the original data using LINQ nicely, - // the LINQ expression should be written once and they should be compiled and executed against the two sources. - // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query - // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values - // to this getQuery method. - IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); - - return getQuery; - } - - /// - /// Generate a non-random payload for serializer LINQ tests. - /// - /// the object type - /// the lamda to create an instance of test data - /// number of test data to be created - /// the target container - /// if theCosmosLinqSerializerOption of camelCaseSerialization should be applied - /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable. 
- public static Func> GenerateSerializationTestCosmosData(Func func, int count, Container container, CosmosLinqSerializerOptions linqSerializerOptions) - { - List data = new List(); - for (int i = 0; i < count; i++) - { - data.Add(func(i, linqSerializerOptions.PropertyNamingPolicy == CosmosPropertyNamingPolicy.CamelCase)); - } - - foreach (T obj in data) - { - ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; - } - - FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; - QueryRequestOptions requestOptions = new QueryRequestOptions(); - - IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions, linqSerializerOptions: linqSerializerOptions); - - IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); - - return getQuery; - } - - public static Func> GenerateFamilyCosmosData( - Cosmos.Database cosmosDatabase, out Container container) - { - // The test collection should have range index on string properties - // for the orderby tests - PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; - ContainerProperties newCol = new ContainerProperties() - { - Id = Guid.NewGuid().ToString(), - PartitionKey = partitionKeyDefinition, - IndexingPolicy = new Microsoft.Azure.Cosmos.IndexingPolicy() - { - IncludedPaths = new Collection() - { - new Cosmos.IncludedPath() - { - Path = "/*", - Indexes = new System.Collections.ObjectModel.Collection() - { - Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.Number, -1), - Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.String, -1) - } - } - }, - CompositeIndexes = new Collection>() - { - new Collection() - { - new Cosmos.CompositePath() { Path = "/FamilyId", Order = 
Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } - } - } - } - }; - container = cosmosDatabase.CreateContainerAsync(newCol).Result; - const int Records = 100; - const int MaxNameLength = 100; - const int MaxThingStringLength = 50; - const int MaxChild = 5; - const int MaxPets = MaxChild; - const int MaxThings = MaxChild; - const int MaxGrade = 101; - const int MaxTransaction = 20; - const int MaxTransactionMinuteRange = 200; - int MaxTransactionType = Enum.GetValues(typeof(TransactionType)).Length; - Family createDataObj(Random random) - { - Family obj = new Family - { - FamilyId = random.NextDouble() < 0.05 ? "some id" : Guid.NewGuid().ToString(), - IsRegistered = random.NextDouble() < 0.5, - NullableInt = random.NextDouble() < 0.5 ? (int?)random.Next() : null, - Int = random.NextDouble() < 0.5 ? 
5 : random.Next(), - Id = Guid.NewGuid().ToString(), - Pk = "Test", - Parents = new Parent[random.Next(2) + 1] - }; - for (int i = 0; i < obj.Parents.Length; ++i) - { - obj.Parents[i] = new Parent() - { - FamilyName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), - GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) - }; - } - - obj.Tags = new string[random.Next(MaxChild)]; - for (int i = 0; i < obj.Tags.Length; ++i) - { - obj.Tags[i] = (i + random.Next(30, 36)).ToString(); - } - - obj.Children = new Child[random.Next(MaxChild)]; - for (int i = 0; i < obj.Children.Length; ++i) - { - obj.Children[i] = new Child() - { - Gender = random.NextDouble() < 0.5 ? "male" : "female", - FamilyName = obj.Parents[random.Next(obj.Parents.Length)].FamilyName, - GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), - Grade = random.Next(MaxGrade) - }; - - obj.Children[i].Pets = new List(); - for (int j = 0; j < random.Next(MaxPets); ++j) - { - obj.Children[i].Pets.Add(new Pet() - { - GivenName = random.NextDouble() < 0.5 ? - LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) : - "Fluffy" - }); - } - - obj.Children[i].Things = new Dictionary(); - for (int j = 0; j < random.Next(MaxThings) + 1; ++j) - { - obj.Children[i].Things.Add( - j == 0 ? 
"A" : $"{j}-{random.Next()}", - LinqTestsCommon.RandomString(random, random.Next(MaxThingStringLength))); - } - } - - obj.Records = new Logs - { - LogId = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), - Transactions = new Transaction[random.Next(MaxTransaction)] - }; - for (int i = 0; i < obj.Records.Transactions.Length; ++i) - { - Transaction transaction = new Transaction() - { - Amount = random.Next(), - Date = DateTime.Now.AddMinutes(random.Next(MaxTransactionMinuteRange)), - Type = (TransactionType)random.Next(MaxTransactionType) - }; - obj.Records.Transactions[i] = transaction; - } - - return obj; - } - - Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, container); - return getQuery; - } - + internal class LinqTestsCommon + { + /// + /// Compare two list of anonymous objects + /// + /// + /// + /// + private static bool CompareListOfAnonymousType(List queryResults, List dataResults) + { + return queryResults.SequenceEqual(dataResults); + } + + /// + /// Compare 2 IEnumerable which may contain IEnumerable themselves. 
+ /// + /// The query results from Cosmos DB + /// The query results from actual data + /// True if the two IEbumerable equal + private static bool NestedListsSequenceEqual(IEnumerable queryResults, IEnumerable dataResults) + { + IEnumerator queryIter, dataIter; + for (queryIter = queryResults.GetEnumerator(), dataIter = dataResults.GetEnumerator(); + queryIter.MoveNext() && dataIter.MoveNext();) + { + IEnumerable queryEnumerable = queryIter.Current as IEnumerable; + IEnumerable dataEnumerable = dataIter.Current as IEnumerable; + if (queryEnumerable == null && dataEnumerable == null) + { + if (!queryIter.Current.Equals(dataIter.Current)) return false; + + } + + else if (queryEnumerable == null || dataEnumerable == null) + { + return false; + } + + else + { + if (!LinqTestsCommon.NestedListsSequenceEqual(queryEnumerable, dataEnumerable)) return false; + } + } + + return !(queryIter.MoveNext() || dataIter.MoveNext()); + } + + /// + /// Compare the list of results from CosmosDB query and the list of results from LinQ query on the original data + /// Similar to Collections.SequenceEqual with the assumption that these lists are non-empty + /// + /// A list representing the query restuls from CosmosDB + /// A list representing the linQ query results from the original data + /// true if the two + private static bool CompareListOfArrays(List queryResults, List dataResults) + { + if (NestedListsSequenceEqual(queryResults, dataResults)) return true; + + bool resultMatched = true; + + // dataResults contains type ConcatIterator whereas queryResults may contain IEnumerable + // therefore it's simpler to just cast them into List> manually for simplify the verification + List> l1 = new List>(); + foreach (IEnumerable list in dataResults) + { + List l = new List(); + IEnumerator iterator = list.GetEnumerator(); + while (iterator.MoveNext()) + { + l.Add(iterator.Current); + } + + l1.Add(l); + } + + List> l2 = new List>(); + foreach (IEnumerable list in queryResults) + { + List l = 
new List(); + IEnumerator iterator = list.GetEnumerator(); + while (iterator.MoveNext()) + { + l.Add(iterator.Current); + } + + l2.Add(l); + } + + foreach (IEnumerable list in l1) + { + if (!l2.Any(a => a.SequenceEqual(list))) + { + resultMatched = false; + return false; + } + } + + foreach (IEnumerable list in l2) + { + if (!l1.Any(a => a.SequenceEqual(list))) + { + resultMatched = false; + break; + } + } + + return resultMatched; + } + + private static bool IsNumber(dynamic value) + { + return value is sbyte + || value is byte + || value is short + || value is ushort + || value is int + || value is uint + || value is long + || value is ulong + || value is float + || value is double + || value is decimal; + } + + public static Boolean IsAnonymousType(Type type) + { + Boolean hasCompilerGeneratedAttribute = type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), false).Count() > 0; + Boolean nameContainsAnonymousType = type.FullName.Contains("AnonymousType"); + Boolean isAnonymousType = hasCompilerGeneratedAttribute && nameContainsAnonymousType; + + return isAnonymousType; + } + + /// + /// Gets the results of CosmosDB query and the results of LINQ query on the original data + /// + /// + /// + public static (List queryResults, List dataResults) GetResults(IQueryable queryResults, IQueryable dataResults) + { + // execution validation + IEnumerator queryEnumerator = queryResults.GetEnumerator(); + List queryResultsList = new List(); + while (queryEnumerator.MoveNext()) + { + queryResultsList.Add(queryEnumerator.Current); + } + + List dataResultsList = dataResults?.Cast()?.ToList(); + + return (queryResultsList, dataResultsList); + } + + /// + /// Validates the results of CosmosDB query and the results of LINQ query on the original data + /// Using Assert, will fail the unit test if the two results list are not SequenceEqual + /// + /// + /// + private static void ValidateResults(List queryResultsList, List dataResultsList) + { + bool resultMatched = true; + 
string actualStr = null; + string expectedStr = null; + if (dataResultsList.Count == 0 || queryResultsList.Count == 0) + { + resultMatched &= dataResultsList.Count == queryResultsList.Count; + } + else + { + dynamic firstElem = dataResultsList.FirstOrDefault(); + if (firstElem is IEnumerable) + { + resultMatched &= CompareListOfArrays(queryResultsList, dataResultsList); + } + else if (LinqTestsCommon.IsAnonymousType(firstElem.GetType())) + { + resultMatched &= CompareListOfAnonymousType(queryResultsList, dataResultsList); + } + else if (LinqTestsCommon.IsNumber(firstElem)) + { + const double Epsilon = 1E-6; + Type dataType = firstElem.GetType(); + List dataSortedList = dataResultsList.OrderBy(x => x).ToList(); + List querySortedList = queryResultsList.OrderBy(x => x).ToList(); + if (dataSortedList.Count != querySortedList.Count) + { + resultMatched = false; + } + else + { + for (int i = 0; i < dataSortedList.Count; ++i) + { + if (Math.Abs(dataSortedList[i] - (dynamic)querySortedList[i]) > (dynamic)Convert.ChangeType(Epsilon, dataType)) + { + resultMatched = false; + break; + } + } + } + + if (!resultMatched) + { + actualStr = JsonConvert.SerializeObject(querySortedList); + expectedStr = JsonConvert.SerializeObject(dataSortedList); + } + } + else + { + List dataNotQuery = dataResultsList.Except(queryResultsList).ToList(); + List queryNotData = queryResultsList.Except(dataResultsList).ToList(); + resultMatched &= !dataNotQuery.Any() && !queryNotData.Any(); + } + } + + string assertMsg = string.Empty; + if (!resultMatched) + { + actualStr ??= JsonConvert.SerializeObject(queryResultsList); + expectedStr ??= JsonConvert.SerializeObject(dataResultsList); + + resultMatched |= actualStr.Equals(expectedStr); + if (!resultMatched) + { + assertMsg = $"Expected: {expectedStr}, Actual: {actualStr}, RandomSeed: {LinqTestInput.RandomSeed}"; + } + } + + Assert.IsTrue(resultMatched, assertMsg); + } + + /// + /// Generate a random string containing alphabetical characters + /// + 
/// + /// + /// a random string + public static string RandomString(Random random, int length) + { + const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefghijklmnopqrstuvwxyz "; + return new string(Enumerable.Repeat(chars, length).Select(s => s[random.Next(s.Length)]).ToArray()); + } + + /// + /// Generate a random DateTime object from a DateTime, + /// with the variance of the time span between the provided DateTime to the current time + /// + /// + /// + /// + public static DateTime RandomDateTime(Random random, DateTime midDateTime) + { + TimeSpan timeSpan = DateTime.Now - midDateTime; + TimeSpan newSpan = new TimeSpan(0, random.Next(0, (int)timeSpan.TotalMinutes * 2) - (int)timeSpan.TotalMinutes, 0); + DateTime newDate = midDateTime + newSpan; + return newDate; + } + + /// + /// Generate test data for most LINQ tests + /// + /// the object type + /// the lamda to create an instance of test data + /// number of test data to be created + /// the target container + /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable + public static Func> GenerateTestCosmosData(Func func, int count, Container container) + { + List data = new List(); + int seed = DateTime.Now.Millisecond; + Random random = new Random(seed); + Debug.WriteLine("Random seed: {0}", seed); + LinqTestInput.RandomSeed = seed; + for (int i = 0; i < count; ++i) + { + data.Add(func(random)); + } + + foreach (T obj in data) + { + ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; + } + + FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; + QueryRequestOptions requestOptions = new QueryRequestOptions(); + + IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); + + // To cover both query against backend and queries on the 
original data using LINQ nicely, + // the LINQ expression should be written once and they should be compiled and executed against the two sources. + // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query + // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values + // to this getQuery method. + IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); + + return getQuery; + } + + /// + /// Generate a non-random payload for serializer LINQ tests. + /// + /// the object type + /// the lamda to create an instance of test data + /// number of test data to be created + /// the target container + /// if theCosmosLinqSerializerOption of camelCaseSerialization should be applied + /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable. + public static Func> GenerateSerializationTestCosmosData(Func func, int count, Container container, CosmosLinqSerializerOptions linqSerializerOptions) + { + List data = new List(); + for (int i = 0; i < count; i++) + { + data.Add(func(i, linqSerializerOptions.PropertyNamingPolicy == CosmosPropertyNamingPolicy.CamelCase)); + } + + foreach (T obj in data) + { + ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; + } + + FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; + QueryRequestOptions requestOptions = new QueryRequestOptions(); + + IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions, linqSerializerOptions: linqSerializerOptions); + + IQueryable getQuery(bool useQuery) => useQuery ? 
query : data.AsQueryable(); + + return getQuery; + } + + public static Func> GenerateFamilyCosmosData( + Cosmos.Database cosmosDatabase, out Container container) + { + // The test collection should have range index on string properties + // for the orderby tests + PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; + ContainerProperties newCol = new ContainerProperties() + { + Id = Guid.NewGuid().ToString(), + PartitionKey = partitionKeyDefinition, + IndexingPolicy = new Microsoft.Azure.Cosmos.IndexingPolicy() + { + IncludedPaths = new Collection() + { + new Cosmos.IncludedPath() + { + Path = "/*", + Indexes = new System.Collections.ObjectModel.Collection() + { + Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.Number, -1), + Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.String, -1) + } + } + }, + CompositeIndexes = new Collection>() + { + new Collection() + { + new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/IsRegistered", Order = 
Cosmos.CompositePathSortOrder.Descending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } + } + } + } + }; + container = cosmosDatabase.CreateContainerAsync(newCol).Result; + const int Records = 100; + const int MaxNameLength = 100; + const int MaxThingStringLength = 50; + const int MaxChild = 5; + const int MaxPets = MaxChild; + const int MaxThings = MaxChild; + const int MaxGrade = 101; + const int MaxTransaction = 20; + const int MaxTransactionMinuteRange = 200; + int MaxTransactionType = Enum.GetValues(typeof(TransactionType)).Length; + Family createDataObj(Random random) + { + Family obj = new Family + { + FamilyId = random.NextDouble() < 0.05 ? "some id" : Guid.NewGuid().ToString(), + IsRegistered = random.NextDouble() < 0.5, + NullableInt = random.NextDouble() < 0.5 ? (int?)random.Next() : null, + Int = random.NextDouble() < 0.5 ? 5 : random.Next(), + Id = Guid.NewGuid().ToString(), + Pk = "Test", + Parents = new Parent[random.Next(2) + 1] + }; + for (int i = 0; i < obj.Parents.Length; ++i) + { + obj.Parents[i] = new Parent() + { + FamilyName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), + GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) + }; + } + + obj.Tags = new string[random.Next(MaxChild)]; + for (int i = 0; i < obj.Tags.Length; ++i) + { + obj.Tags[i] = (i + random.Next(30, 36)).ToString(); + } + + obj.Children = new Child[random.Next(MaxChild)]; + for (int i = 0; i < obj.Children.Length; ++i) + { + obj.Children[i] = new Child() + { + Gender = random.NextDouble() < 0.5 ? 
"male" : "female", + FamilyName = obj.Parents[random.Next(obj.Parents.Length)].FamilyName, + GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), + Grade = random.Next(MaxGrade) + }; + + obj.Children[i].Pets = new List(); + for (int j = 0; j < random.Next(MaxPets); ++j) + { + obj.Children[i].Pets.Add(new Pet() + { + GivenName = random.NextDouble() < 0.5 ? + LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) : + "Fluffy" + }); + } + + obj.Children[i].Things = new Dictionary(); + for (int j = 0; j < random.Next(MaxThings) + 1; ++j) + { + obj.Children[i].Things.Add( + j == 0 ? "A" : $"{j}-{random.Next()}", + LinqTestsCommon.RandomString(random, random.Next(MaxThingStringLength))); + } + } + + obj.Records = new Logs + { + LogId = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), + Transactions = new Transaction[random.Next(MaxTransaction)] + }; + for (int i = 0; i < obj.Records.Transactions.Length; ++i) + { + Transaction transaction = new Transaction() + { + Amount = random.Next(), + Date = DateTime.Now.AddMinutes(random.Next(MaxTransactionMinuteRange)), + Type = (TransactionType)random.Next(MaxTransactionType) + }; + obj.Records.Transactions[i] = transaction; + } + + return obj; + } + + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, container); + return getQuery; + } + public static Func> GenerateSimpleCosmosData(Cosmos.Database cosmosDatabase, bool useRandomData = true) - { - const int DocumentCount = 10; - PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; - Container container = cosmosDatabase.CreateContainerAsync(new ContainerProperties { Id = Guid.NewGuid().ToString(), PartitionKey = partitionKeyDefinition }).Result; - + { + const int DocumentCount = 10; + PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new 
System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; + Container container = cosmosDatabase.CreateContainerAsync(new ContainerProperties { Id = Guid.NewGuid().ToString(), PartitionKey = partitionKeyDefinition }).Result; + ILinqTestDataGenerator dataGenerator = useRandomData ? new LinqTestRandomDataGenerator(DocumentCount) : new LinqTestDataGenerator(DocumentCount); List testData = new List(dataGenerator.GenerateData()); foreach (Data dataEntry in testData) - { - Data response = container.CreateItemAsync(dataEntry, new Cosmos.PartitionKey(dataEntry.Pk)).Result; - } - - FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; - QueryRequestOptions requestOptions = new QueryRequestOptions(); - - IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); - - // To cover both query against backend and queries on the original data using LINQ nicely, - // the LINQ expression should be written once and they should be compiled and executed against the two sources. - // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query - // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values - // to this getQuery method. - IQueryable getQuery(bool useQuery) => useQuery ? query : testData.AsQueryable(); - return getQuery; - } - - public static LinqTestOutput ExecuteTest(LinqTestInput input, bool serializeResultsInBaseline = false) - { - string querySqlStr = string.Empty; - try - { - Func compiledQuery = input.Expression.Compile(); - - IQueryable query = compiledQuery(true); - querySqlStr = JObject.Parse(query.ToString()).GetValue("query", StringComparison.Ordinal).ToString(); - - IQueryable dataQuery = input.skipVerification ? 
null : compiledQuery(false); - - (List queryResults, List dataResults) = GetResults(query, dataQuery); - - // we skip unordered query because the LINQ results vs actual query results are non-deterministic - if (!input.skipVerification) - { - LinqTestsCommon.ValidateResults(queryResults, dataResults); - } - - string serializedResults = serializeResultsInBaseline ? - JsonConvert.SerializeObject(queryResults.Select(item => item is LinqTestObject ? item.ToString() : item), new JsonSerializerSettings { Formatting = Newtonsoft.Json.Formatting.Indented}) : - null; - - return new LinqTestOutput(querySqlStr, serializedResults, errorMsg: null, input.inputData); - } - catch (Exception e) - { - return new LinqTestOutput(querySqlStr, serializedResults: null, errorMsg: LinqTestsCommon.BuildExceptionMessageForTest(e), inputData: input.inputData); - } - } - - public static string BuildExceptionMessageForTest(Exception ex) + { + Data response = container.CreateItemAsync(dataEntry, new Cosmos.PartitionKey(dataEntry.Pk)).Result; + } + + FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; + QueryRequestOptions requestOptions = new QueryRequestOptions(); + + IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); + + // To cover both query against backend and queries on the original data using LINQ nicely, + // the LINQ expression should be written once and they should be compiled and executed against the two sources. + // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query + // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values + // to this getQuery method. + IQueryable getQuery(bool useQuery) => useQuery ? 
query : testData.AsQueryable(); + return getQuery; + } + + public static LinqTestOutput ExecuteTest(LinqTestInput input, bool serializeResultsInBaseline = false) + { + string querySqlStr = string.Empty; + try + { + Func compiledQuery = input.Expression.Compile(); + + IQueryable query = compiledQuery(true); + querySqlStr = JObject.Parse(query.ToString()).GetValue("query", StringComparison.Ordinal).ToString(); + + IQueryable dataQuery = input.skipVerification ? null : compiledQuery(false); + + (List queryResults, List dataResults) = GetResults(query, dataQuery); + + // we skip unordered query because the LINQ results vs actual query results are non-deterministic + if (!input.skipVerification) + { + LinqTestsCommon.ValidateResults(queryResults, dataResults); + } + + string serializedResults = serializeResultsInBaseline ? + JsonConvert.SerializeObject(queryResults.Select(item => item is LinqTestObject ? item.ToString() : item), new JsonSerializerSettings { Formatting = Newtonsoft.Json.Formatting.Indented}) : + null; + + return new LinqTestOutput(querySqlStr, serializedResults, errorMsg: null, input.inputData); + } + catch (Exception e) + { + return new LinqTestOutput(querySqlStr, serializedResults: null, errorMsg: LinqTestsCommon.BuildExceptionMessageForTest(e), inputData: input.inputData); + } + } + + public static string BuildExceptionMessageForTest(Exception ex) { StringBuilder message = new StringBuilder(); - do - { - if (ex is CosmosException cosmosException) + do + { + if (ex is CosmosException cosmosException) { // ODE scenario: The backend generates an error response message with significant variations when compared to the Service Interop which gets called in the Non ODE scenario. // The objective is to standardize and normalize the backend response for consistency. 
@@ -590,290 +590,290 @@ namespace Microsoft.Azure.Cosmos.Services.Management.Tests else { message.Append($"Status Code: {cosmosException.StatusCode}"); - } - } - else if (ex is DocumentClientException documentClientException) - { - message.Append(documentClientException.RawErrorMessage); - } - else + } + } + else if (ex is DocumentClientException documentClientException) + { + message.Append(documentClientException.RawErrorMessage); + } + else { message.Append(ex.Message); - } + } - ex = ex.InnerException; - if (ex != null) - { - message.Append(","); - } - } + ex = ex.InnerException; + if (ex != null) + { + message.Append(","); + } + } while (ex != null); return message.ToString(); - } - } - - /// - /// A base class that determines equality based on its json representation - /// - public class LinqTestObject - { - private string json; - - protected virtual string SerializeForTestBaseline() - { - return JsonConvert.SerializeObject(this); - } - - public override string ToString() - { - // simple cached serialization - this.json ??= this.SerializeForTestBaseline(); - return this.json; - } - - public override bool Equals(object obj) - { - if (!(obj is LinqTestObject && - obj.GetType().IsAssignableFrom(this.GetType()) && - this.GetType().IsAssignableFrom(obj.GetType()))) return false; - if (obj == null) return false; - - return this.ToString().Equals(obj.ToString()); - } - - public override int GetHashCode() - { - return this.ToString().GetHashCode(); - } - } - - public class LinqTestInput : BaselineTestInput - { - internal static Regex classNameRegex = new Regex("(value\\(.+?\\+)?\\<\\>.+?__([A-Za-z]+)((\\d+_\\d+(`\\d+\\[.+?\\])?\\)(\\.value)?)|\\d+`\\d+)"); - internal static Regex invokeCompileRegex = new Regex("(Convert\\()?Invoke\\([^.]+\\.[^.,]+(\\.Compile\\(\\))?, b\\)(\\.Cast\\(\\))?(\\))?"); - - // As the tests are executed sequentially - // We can store the random seed in a static variable for diagnostics - internal static int RandomSeed = -1; - - internal int 
randomSeed = -1; - internal Expression> Expression { get; } - internal string expressionStr; - internal string inputData; - - // We skip the verification between Cosmos DB and actual query restuls in the following cases - // - unordered query since the results are not deterministics for LinQ results and actual query results - // - scenarios not supported in LINQ, e.g. sequence doesn't contain element. - internal bool skipVerification; - - internal LinqTestInput( - string description, - Expression> expr, - bool skipVerification = false, - string expressionStr = null, - string inputData = null) - : base(description) - { - this.Expression = expr ?? throw new ArgumentNullException($"{nameof(expr)} must not be null."); - this.skipVerification = skipVerification; - this.expressionStr = expressionStr; - this.inputData = inputData; - } - - public static string FilterInputExpression(string input) - { - StringBuilder expressionSb = new StringBuilder(input); - // simplify full qualified class name - // e.g. 
before: value(Microsoft.Azure.Documents.Services.Management.Tests.LinqSQLTranslationTest+<>c__DisplayClass7_0), after: DisplayClass - // before: <>f__AnonymousType14`2(, after: AnonymousType( - // value(Microsoft.Azure.Documents.Services.Management.Tests.LinqProviderTests.LinqTranslationBaselineTests +<> c__DisplayClass24_0`1[System.String]).value - Match match = classNameRegex.Match(expressionSb.ToString()); - while (match.Success) - { - expressionSb = expressionSb.Replace(match.Groups[0].Value, match.Groups[2].Value); - match = match.NextMatch(); - } - - // remove the Invoke().Compile() string from the Linq scanning tests - match = invokeCompileRegex.Match(expressionSb.ToString()); - while (match.Success) - { - expressionSb = expressionSb.Replace(match.Groups[0].Value, string.Empty); - match = match.NextMatch(); - } - - expressionSb.Insert(0, "query"); - - return expressionSb.ToString(); - } - - public override void SerializeAsXml(XmlWriter xmlWriter) - { - if (xmlWriter == null) - { - throw new ArgumentNullException($"{nameof(xmlWriter)} cannot be null."); - } - - this.expressionStr ??= LinqTestInput.FilterInputExpression(this.Expression.Body.ToString()); - - xmlWriter.WriteStartElement("Description"); - xmlWriter.WriteCData(this.Description); - xmlWriter.WriteEndElement(); - xmlWriter.WriteStartElement("Expression"); - xmlWriter.WriteCData(this.expressionStr); - xmlWriter.WriteEndElement(); - } - } - - public class LinqTestOutput : BaselineTestOutput - { - internal static Regex sdkVersion = new Regex("(,\\W*)?documentdb-dotnet-sdk[^]]+"); - internal static Regex activityId = new Regex("(,\\W*)?ActivityId:.+", RegexOptions.Multiline); - internal static Regex newLine = new Regex("(\r\n|\r|\n)"); - - internal string SqlQuery { get; } - internal string ErrorMessage { get; } - internal string Results { get; } - internal string InputData { get; } - - private static readonly Dictionary newlineKeywords = new Dictionary() { - { "SELECT", "\nSELECT" }, - { "FROM", 
"\nFROM" }, - { "WHERE", "\nWHERE" }, - { "JOIN", "\nJOIN" }, - { "ORDER BY", "\nORDER BY" }, + } + } + + /// + /// A base class that determines equality based on its json representation + /// + public class LinqTestObject + { + private string json; + + protected virtual string SerializeForTestBaseline() + { + return JsonConvert.SerializeObject(this); + } + + public override string ToString() + { + // simple cached serialization + this.json ??= this.SerializeForTestBaseline(); + return this.json; + } + + public override bool Equals(object obj) + { + if (!(obj is LinqTestObject && + obj.GetType().IsAssignableFrom(this.GetType()) && + this.GetType().IsAssignableFrom(obj.GetType()))) return false; + if (obj == null) return false; + + return this.ToString().Equals(obj.ToString()); + } + + public override int GetHashCode() + { + return this.ToString().GetHashCode(); + } + } + + public class LinqTestInput : BaselineTestInput + { + internal static Regex classNameRegex = new Regex("(value\\(.+?\\+)?\\<\\>.+?__([A-Za-z]+)((\\d+_\\d+(`\\d+\\[.+?\\])?\\)(\\.value)?)|\\d+`\\d+)"); + internal static Regex invokeCompileRegex = new Regex("(Convert\\()?Invoke\\([^.]+\\.[^.,]+(\\.Compile\\(\\))?, b\\)(\\.Cast\\(\\))?(\\))?"); + + // As the tests are executed sequentially + // We can store the random seed in a static variable for diagnostics + internal static int RandomSeed = -1; + + internal int randomSeed = -1; + internal Expression> Expression { get; } + internal string expressionStr; + internal string inputData; + + // We skip the verification between Cosmos DB and actual query restuls in the following cases + // - unordered query since the results are not deterministics for LinQ results and actual query results + // - scenarios not supported in LINQ, e.g. sequence doesn't contain element. 
+ internal bool skipVerification; + + internal LinqTestInput( + string description, + Expression> expr, + bool skipVerification = false, + string expressionStr = null, + string inputData = null) + : base(description) + { + this.Expression = expr ?? throw new ArgumentNullException($"{nameof(expr)} must not be null."); + this.skipVerification = skipVerification; + this.expressionStr = expressionStr; + this.inputData = inputData; + } + + public static string FilterInputExpression(string input) + { + StringBuilder expressionSb = new StringBuilder(input); + // simplify full qualified class name + // e.g. before: value(Microsoft.Azure.Documents.Services.Management.Tests.LinqSQLTranslationTest+<>c__DisplayClass7_0), after: DisplayClass + // before: <>f__AnonymousType14`2(, after: AnonymousType( + // value(Microsoft.Azure.Documents.Services.Management.Tests.LinqProviderTests.LinqTranslationBaselineTests +<> c__DisplayClass24_0`1[System.String]).value + Match match = classNameRegex.Match(expressionSb.ToString()); + while (match.Success) + { + expressionSb = expressionSb.Replace(match.Groups[0].Value, match.Groups[2].Value); + match = match.NextMatch(); + } + + // remove the Invoke().Compile() string from the Linq scanning tests + match = invokeCompileRegex.Match(expressionSb.ToString()); + while (match.Success) + { + expressionSb = expressionSb.Replace(match.Groups[0].Value, string.Empty); + match = match.NextMatch(); + } + + expressionSb.Insert(0, "query"); + + return expressionSb.ToString(); + } + + public override void SerializeAsXml(XmlWriter xmlWriter) + { + if (xmlWriter == null) + { + throw new ArgumentNullException($"{nameof(xmlWriter)} cannot be null."); + } + + this.expressionStr ??= LinqTestInput.FilterInputExpression(this.Expression.Body.ToString()); + + xmlWriter.WriteStartElement("Description"); + xmlWriter.WriteCData(this.Description); + xmlWriter.WriteEndElement(); + xmlWriter.WriteStartElement("Expression"); + xmlWriter.WriteCData(this.expressionStr); + 
xmlWriter.WriteEndElement(); + } + } + + public class LinqTestOutput : BaselineTestOutput + { + internal static Regex sdkVersion = new Regex("(,\\W*)?documentdb-dotnet-sdk[^]]+"); + internal static Regex activityId = new Regex("(,\\W*)?ActivityId:.+", RegexOptions.Multiline); + internal static Regex newLine = new Regex("(\r\n|\r|\n)"); + + internal string SqlQuery { get; } + internal string ErrorMessage { get; } + internal string Results { get; } + internal string InputData { get; } + + private static readonly Dictionary newlineKeywords = new Dictionary() { + { "SELECT", "\nSELECT" }, + { "FROM", "\nFROM" }, + { "WHERE", "\nWHERE" }, + { "JOIN", "\nJOIN" }, + { "ORDER BY", "\nORDER BY" }, { "OFFSET", "\nOFFSET" }, - { " )", "\n)" } - }; - - public static string FormatErrorMessage(string msg) - { - msg = newLine.Replace(msg, string.Empty); - - // remove sdk version in the error message which can change in the future. - // e.g. - msg = sdkVersion.Replace(msg, string.Empty); - - // remove activity Id - msg = activityId.Replace(msg, string.Empty); - - return msg; - } - - internal LinqTestOutput(string sqlQuery, string serializedResults, string errorMsg, string inputData) - { - this.SqlQuery = FormatSql(sqlQuery); - this.Results = serializedResults; - this.ErrorMessage = errorMsg; - this.InputData = inputData; - } - - public static String FormatSql(string sqlQuery) - { - const string subqueryCue = "(SELECT"; - bool hasSubquery = sqlQuery.IndexOf(subqueryCue, StringComparison.OrdinalIgnoreCase) > 0; - - StringBuilder sb = new StringBuilder(sqlQuery); - foreach (KeyValuePair kv in newlineKeywords) - { - sb.Replace(kv.Key, kv.Value); - } - - if (!hasSubquery) return sb.ToString(); - - const string oneTab = " "; - const string startCue = "SELECT"; - const string endCue = ")"; - - string[] tokens = sb.ToString().Split('\n'); - bool firstSelect = true; - sb.Length = 0; - StringBuilder indentSb = new StringBuilder(); - for (int i = 0; i < tokens.Length; ++i) - { - if 
(tokens[i].StartsWith(startCue, StringComparison.OrdinalIgnoreCase)) - { - if (!firstSelect) indentSb.Append(oneTab); else firstSelect = false; - - } - else if (tokens[i].StartsWith(endCue, StringComparison.OrdinalIgnoreCase)) - { - indentSb.Length -= oneTab.Length; - } - - sb.Append(indentSb).Append(tokens[i]).Append("\n"); - } - - return sb.ToString(); - } - - public override void SerializeAsXml(XmlWriter xmlWriter) - { - xmlWriter.WriteStartElement(nameof(this.SqlQuery)); - xmlWriter.WriteCData(this.SqlQuery); - xmlWriter.WriteEndElement(); - if (this.InputData != null) - { - xmlWriter.WriteStartElement("InputData"); - xmlWriter.WriteCData(this.InputData); - xmlWriter.WriteEndElement(); - } - if (this.Results != null) - { - xmlWriter.WriteStartElement("Results"); - xmlWriter.WriteCData(this.Results); - xmlWriter.WriteEndElement(); - } - if (this.ErrorMessage != null) - { - xmlWriter.WriteStartElement("ErrorMessage"); - xmlWriter.WriteCData(LinqTestOutput.FormatErrorMessage(this.ErrorMessage)); - xmlWriter.WriteEndElement(); - } - } - } - - class SystemTextJsonLinqSerializer : CosmosLinqSerializer - { - private readonly JsonObjectSerializer systemTextJsonSerializer; - - public SystemTextJsonLinqSerializer(JsonSerializerOptions jsonSerializerOptions) - { - this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); - } - - public override T FromStream(Stream stream) - { - if (stream == null) - throw new ArgumentNullException(nameof(stream)); - - using (stream) - { - if (stream.CanSeek && stream.Length == 0) - { - return default; - } - - if (typeof(Stream).IsAssignableFrom(typeof(T))) - { - return (T)(object)stream; - } - - return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); - } - } - - public override Stream ToStream(T input) - { - MemoryStream streamPayload = new MemoryStream(); - this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); - streamPayload.Position = 0; - return 
streamPayload; - } - - public override string SerializeMemberName(MemberInfo memberInfo) - { + { "GROUP BY", "\nGROUP BY" }, + { " )", "\n)" } + }; + + public static string FormatErrorMessage(string msg) + { + msg = newLine.Replace(msg, string.Empty); + + // remove sdk version in the error message which can change in the future. + // e.g. + msg = sdkVersion.Replace(msg, string.Empty); + + // remove activity Id + msg = activityId.Replace(msg, string.Empty); + + return msg; + } + + internal LinqTestOutput(string sqlQuery, string serializedResults, string errorMsg, string inputData) + { + this.SqlQuery = FormatSql(sqlQuery); + this.Results = serializedResults; + this.ErrorMessage = errorMsg; + this.InputData = inputData; + } + + public static String FormatSql(string sqlQuery) + { + const string subqueryCue = "(SELECT"; + bool hasSubquery = sqlQuery.IndexOf(subqueryCue, StringComparison.OrdinalIgnoreCase) > 0; + + StringBuilder sb = new StringBuilder(sqlQuery); + foreach (KeyValuePair kv in newlineKeywords) + { + sb.Replace(kv.Key, kv.Value); + } + + if (!hasSubquery) return sb.ToString(); + + const string oneTab = " "; + const string startCue = "SELECT"; + const string endCue = ")"; + string[] tokens = sb.ToString().Split('\n'); + bool firstSelect = true; + sb.Length = 0; + StringBuilder indentSb = new StringBuilder(); + for (int i = 0; i < tokens.Length; ++i) + { + if (tokens[i].StartsWith(startCue, StringComparison.OrdinalIgnoreCase)) + { + if (!firstSelect) indentSb.Append(oneTab); else firstSelect = false; + + } + else if (tokens[i].StartsWith(endCue, StringComparison.OrdinalIgnoreCase)) + { + indentSb.Length -= oneTab.Length; + } + + sb.Append(indentSb).Append(tokens[i]).Append("\n"); + } + + return sb.ToString(); + } + + public override void SerializeAsXml(XmlWriter xmlWriter) + { + xmlWriter.WriteStartElement(nameof(this.SqlQuery)); + xmlWriter.WriteCData(this.SqlQuery); + xmlWriter.WriteEndElement(); + if (this.InputData != null) + { + 
xmlWriter.WriteStartElement("InputData"); + xmlWriter.WriteCData(this.InputData); + xmlWriter.WriteEndElement(); + } + if (this.Results != null) + { + xmlWriter.WriteStartElement("Results"); + xmlWriter.WriteCData(this.Results); + xmlWriter.WriteEndElement(); + } + if (this.ErrorMessage != null) + { + xmlWriter.WriteStartElement("ErrorMessage"); + xmlWriter.WriteCData(LinqTestOutput.FormatErrorMessage(this.ErrorMessage)); + xmlWriter.WriteEndElement(); + } + } + } + + class SystemTextJsonLinqSerializer : CosmosLinqSerializer + { + private readonly JsonObjectSerializer systemTextJsonSerializer; + + public SystemTextJsonLinqSerializer(JsonSerializerOptions jsonSerializerOptions) + { + this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); + } + + public override T FromStream(Stream stream) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + using (stream) + { + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + + return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); + } + } + + public override Stream ToStream(T input) + { + MemoryStream streamPayload = new MemoryStream(); + this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); + streamPayload.Position = 0; + return streamPayload; + } + + public override string SerializeMemberName(MemberInfo memberInfo) + { System.Text.Json.Serialization.JsonExtensionDataAttribute jsonExtensionDataAttribute = memberInfo.GetCustomAttribute(true); if (jsonExtensionDataAttribute != null) @@ -881,52 +881,52 @@ namespace Microsoft.Azure.Cosmos.Services.Management.Tests return null; } - JsonPropertyNameAttribute jsonPropertyNameAttribute = memberInfo.GetCustomAttribute(true); - - string memberName = !string.IsNullOrEmpty(jsonPropertyNameAttribute?.Name) - ? 
jsonPropertyNameAttribute.Name - : memberInfo.Name; - - return memberName; - } + JsonPropertyNameAttribute jsonPropertyNameAttribute = memberInfo.GetCustomAttribute(true); + + string memberName = !string.IsNullOrEmpty(jsonPropertyNameAttribute?.Name) + ? jsonPropertyNameAttribute.Name + : memberInfo.Name; + + return memberName; + } } - class SystemTextJsonSerializer : CosmosSerializer - { - private readonly JsonObjectSerializer systemTextJsonSerializer; - - public SystemTextJsonSerializer(JsonSerializerOptions jsonSerializerOptions) - { - this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); - } - - public override T FromStream(Stream stream) - { - if (stream == null) - throw new ArgumentNullException(nameof(stream)); - - using (stream) - { - if (stream.CanSeek && stream.Length == 0) - { - return default; - } - - if (typeof(Stream).IsAssignableFrom(typeof(T))) - { - return (T)(object)stream; - } - - return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); - } - } - - public override Stream ToStream(T input) - { - MemoryStream streamPayload = new MemoryStream(); - this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); - streamPayload.Position = 0; - return streamPayload; - } - } -} + class SystemTextJsonSerializer : CosmosSerializer + { + private readonly JsonObjectSerializer systemTextJsonSerializer; + + public SystemTextJsonSerializer(JsonSerializerOptions jsonSerializerOptions) + { + this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); + } + + public override T FromStream(Stream stream) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + using (stream) + { + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + + return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); + } + } + + public override 
Stream ToStream(T input) + { + MemoryStream streamPayload = new MemoryStream(); + this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); + streamPayload.Position = 0; + return streamPayload; + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj index 0cb2585dc..509e89c0a 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj @@ -193,6 +193,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/GroupByClauseSqlParserBaselineTests.Tests.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/GroupByClauseSqlParserBaselineTests.Tests.xml index 0241dc23e..db690101d 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/GroupByClauseSqlParserBaselineTests.Tests.xml +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/GroupByClauseSqlParserBaselineTests.Tests.xml @@ -5,7 +5,7 @@ - + @@ -14,7 +14,7 @@ - + @@ -23,7 +23,7 @@ - + diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/SqlObjectVisitorBaselineTests.SqlQueries.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/SqlObjectVisitorBaselineTests.SqlQueries.xml index fe02fc060..94eeeffc0 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/SqlObjectVisitorBaselineTests.SqlQueries.xml +++ 
b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/BaselineTest/TestBaseline/SqlObjectVisitorBaselineTests.SqlQueries.xml @@ -989,16 +989,17 @@ OFFSET 0 LIMIT 0 }]]> - + -245344741 - + @@ -1127,18 +1128,19 @@ OFFSET 0 LIMIT 0 }]]> - + 51808704 - + @@ -1267,18 +1269,19 @@ OFFSET 0 LIMIT 0 }]]> - + -1922520573 - + @@ -1407,18 +1410,19 @@ ARRAY( }]]> - + 1317938775 - + \ No newline at end of file