From f3980e5bd5886f653d216a08e0e4f3814579a632 Mon Sep 17 00:00:00 2001 From: Christopher Jolly Date: Fri, 6 Dec 2024 01:18:11 +0800 Subject: [PATCH 1/3] Handle first class span for sql server --- .../CosmosSqlTranslatingExpressionVisitor.cs | 2 +- .../Internal/PrecompiledQueryCodeGenerator.cs | 4 +- ...yExpressionTranslatingExpressionVisitor.cs | 14 +- ...elationalQueryFilterRewritingConvention.cs | 2 +- .../ByteArraySequenceEqualTranslator.cs | 7 + .../Translators/ContainsTranslator.cs | 4 +- .../RelationalLiftableConstantProcessor.cs | 2 +- ...lationalSqlTranslatingExpressionVisitor.cs | 260 ++++++++++- ...qlServerSqlTranslatingExpressionVisitor.cs | 135 ++++++ .../SqlServerByteArrayMethodTranslator.cs | 47 +- .../Internal/ExpressionExtensions.cs | 4 +- .../Query/EvaluatableExpressionFilter.cs | 3 +- .../Internal/ExpressionTreeFuncletizer.cs | 26 +- src/EFCore/Query/LiftableConstantProcessor.cs | 2 +- src/EFCore/Query/QueryRootProcessor.cs | 13 +- .../ShapedQueryCompilingExpressionVisitor.cs | 13 +- src/Shared/ExpressionExtensions.cs | 6 +- src/Shared/MemoryExtensionsMethods.cs | 430 ++++++++++++++++++ src/Shared/SharedTypeExtensions.cs | 3 +- .../CSharpMigrationOperationGeneratorTest.cs | 4 +- .../PrimitiveCollectionsQueryTestBase.cs | 6 +- .../Internal/EntityTypeTest.BaseType.cs | 2 +- .../BytesToStringConverterTest.cs | 2 +- .../StringToBytesConverterTest.cs | 2 +- 24 files changed, 934 insertions(+), 59 deletions(-) create mode 100644 src/Shared/MemoryExtensionsMethods.cs diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 02329b5b826..9663e205dc2 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -1210,7 +1210,7 @@ private static bool TryEvaluateToConstant(Expression expression, [NotNullWhen(tr { sqlConstantExpression = new SqlConstantExpression( Expression.Lambda>(Expression.Convert(expression, typeof(object))) - .Compile(preferInterpretation: true) + .Compile(preferInterpretation: false) .Invoke(), expression.Type, null); diff --git a/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs index de2957a65d7..c8309dcdc27 100644 --- a/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs +++ b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs @@ -179,7 +179,7 @@ public virtual IReadOnlyList GeneratePrecompiledQueries( }; penultimateOperator = Expression.Lambda>(penultimateOperator) - .Compile(preferInterpretation: true)().Expression; + .Compile(preferInterpretation: false)().Expression; // Pass the query through EF's query pipeline; this returns the query's executor function, which can produce an enumerable // that invokes the query. @@ -817,7 +817,7 @@ void GenerateCapturedVariableExtractors( code .Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",") .AppendLine($"Expression.Lambda>(Expression.Convert({variableName}, typeof(object)))") - .AppendLine(".Compile(preferInterpretation: true)") + .AppendLine(".Compile(preferInterpretation: false)") .AppendLine(".Invoke());"); } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 660d0c79998..f77d9a63daa 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -196,11 +196,13 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) : anySubquery); } - static Expression RemoveConvert(Expression e) - => e is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary - ? RemoveConvert(unary.Operand) - : e; - } + static Expression? RemoveConvert(Expression? expression) + => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression + ? RemoveConvert(unaryExpression.Operand) + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Object) + : expression; +} } } @@ -1491,7 +1493,7 @@ when CanEvaluate(memberInitExpression): private static ConstantExpression GetValue(Expression expression) => Expression.Constant( Expression.Lambda>(Expression.Convert(expression, typeof(object))) - .Compile(preferInterpretation: true) + .Compile(preferInterpretation: false) .Invoke(), expression.Type); diff --git a/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs b/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs index 8453e441377..d7014bd0573 100644 --- a/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs +++ b/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs @@ -85,7 +85,7 @@ or nameof(RelationalQueryableExtensions.FromSql)) { var formattableString = Expression.Lambda>( Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))) - .Compile(preferInterpretation: true) + .Compile(preferInterpretation: false) .Invoke(); sql = formattableString.Format; diff --git a/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs b/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs index 45cb358e4f3..3492bb5266b 100644 --- a/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs @@ -44,6 +44,13 @@ public ByteArraySequenceEqualTranslator(ISqlExpressionFactory sqlExpressionFacto return _sqlExpressionFactory.Equal(arguments[0], arguments[1]); } + if (method.IsGenericMethod + && method.GetGenericMethodDefinition().Equals(MemoryExtensionsMethods.SequenceEqual) + && arguments[0].Type == typeof(byte[])) + { + return _sqlExpressionFactory.Equal(arguments[0], arguments[1]); + } + return null; } } diff --git a/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs b/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs index 04cddb50d17..44d666a6586 100644 --- a/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs @@ -45,14 +45,14 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) // Identify static Enumerable.Contains and instance List.Contains if (method.IsGenericMethod - && method.GetGenericMethodDefinition() == EnumerableMethods.Contains + && (method.GetGenericMethodDefinition() == EnumerableMethods.Contains || method.GetGenericMethodDefinition() == MemoryExtensionsMethods.Contains) && ValidateValues(arguments[0])) { (itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[1]), arguments[0]); } if (arguments.Count == 1 - && method.IsContainsMethod() + && (method.IsContainsMethod()) && instance != null && ValidateValues(instance)) { diff --git a/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs b/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs index 2d0b9fa8c0c..914b839b008 100644 --- a/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs +++ b/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs @@ -36,7 +36,7 @@ protected override ConstantExpression InlineConstant(LiftableConstantExpression if (liftableConstant.ResolverExpression is Expression> resolverExpression) { - var resolver = resolverExpression.Compile(preferInterpretation: true); + var resolver = resolverExpression.Compile(preferInterpretation: false); var value = resolver(_relationalMaterializerLiftableConstantContext); return Expression.Constant(value, liftableConstant.Type); } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 9d26798eade..78bdb292ad7 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -4,8 +4,10 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; using System.Text; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using MemoryExtensions = System.MemoryExtensions; namespace Microsoft.EntityFrameworkCore.Query; @@ -57,6 +59,7 @@ private static readonly MethodInfo StringEqualsWithStringComparisonStatic private readonly QueryCompilationContext _queryCompilationContext; private readonly IModel _model; private readonly ISqlExpressionFactory _sqlExpressionFactory; + private readonly SqlAliasManager _sqlAliasManager; private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; private bool _throwForNotTranslatedEfProperty; @@ -78,6 +81,7 @@ public RelationalSqlTranslatingExpressionVisitor( _model = queryCompilationContext.Model; _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; _throwForNotTranslatedEfProperty = true; + _sqlAliasManager = ((RelationalQueryCompilationContext)queryCompilationContext).SqlAliasManager; } /// @@ -273,11 +277,13 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) : anySubquery); } - static Expression RemoveConvert(Expression e) - => e is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary - ? RemoveConvert(unary.Operand) - : e; - } + static Expression? RemoveConvert(Expression? expression) + => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression + ? RemoveConvert(unaryExpression.Operand) + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Object) + : expression; +} } } @@ -645,9 +651,30 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result)) + if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result,out var property)) { + /*if (property is IProperty { IsPrimitiveCollection: true } regularProperty + && result is SqlExpression sqlExpression + && TranslatePrimitiveCollection( + sqlExpression, regularProperty, _sqlAliasManager.GenerateTableAlias(GenerateTableAlias(sqlExpression))) is + { } primitiveCollectionTranslation) + { + return primitiveCollectionTranslation; + }*/ return result; + + string GenerateTableAlias(SqlExpression sqlExpression) + => sqlExpression switch + { + ColumnExpression c => c.Name, + JsonScalarExpression jsonScalar + => jsonScalar.Path.LastOrDefault(s => s.PropertyName is not null) is PathSegment lastPropertyNameSegment + ? lastPropertyNameSegment.PropertyName! + : GenerateTableAlias(jsonScalar.Json), + ScalarSubqueryExpression scalarSubquery => scalarSubquery.Subquery.Projection[0].Alias, + + _ => "collection" + }; } var message = CoreStrings.QueryUnableToTranslateEFProperty(methodCallExpression.Print()); @@ -678,6 +705,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp switch (methodCallExpression) { case + { + Method.Name:"op_Implicit" + }: + return Visit(methodCallExpression.Arguments[0]); + case { Method.Name: nameof(object.Equals), Object: not null, @@ -784,6 +816,151 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp break; } + case + { + Method: + { + Name: "Contains", + IsGenericMethod: true + } + } when method.DeclaringType == typeof(MemoryExtensions): + { + if (arguments[0] is MethodCallExpression { Method.Name: "op_Implicit" } m2 + && m2.Arguments[0] is NewArrayExpression newArray) + { + var expressions = new List(); + foreach (var exp in newArray.Expressions) + { + expressions.Add((SqlExpression)Visit((exp))); + } + return _sqlExpressionFactory.In((SqlExpression)Visit((arguments[1])), expressions); + } + var enumerable = Visit(arguments[0]); + var item = Visit(arguments[1]); + + if (TryRewriteContainsEntity( + enumerable, + item == QueryCompilationContext.NotTranslatedExpression ? arguments[1] : item, out var result)) + { + return result; + } + + //Comes from Queryable pathway. TranslatePrimitiveCollection creates the OPENJSON if needed + //Currently doesn't work. + //Jsonreaderwriter needs to handle Span and ref structs + + if (enumerable is SqlParameterExpression sqlParameterExpression) + { + var primitiveCollectionsBehavior = RelationalOptionsExtension.Extract(_queryCompilationContext.ContextOptions) + .ParameterizedCollectionTranslationMode; + + var tableAlias = _sqlAliasManager.GenerateTableAlias(sqlParameterExpression.Name.TrimStart('_')); + if (sqlParameterExpression.ShouldBeConstantized + || (primitiveCollectionsBehavior == ParameterizedCollectionTranslationMode.Constantize)) + { + var valuesExpression = new ValuesExpression( + tableAlias, + sqlParameterExpression, + [RelationalQueryableMethodTranslatingExpressionVisitor.ValuesOrderingColumnName, RelationalQueryableMethodTranslatingExpressionVisitor.ValuesValueColumnName]); + return CreateShapedQueryExpressionForValuesExpression( + valuesExpression, + tableAlias, + sqlParameterExpression.TypeMapping!.ElementTypeMapping!.GetType(), + sqlParameterExpression.TypeMapping, + sqlParameterExpression.IsNullable); + } + + var param = sqlParameterExpression; + if (sqlParameterExpression.Type.IsGenericType && sqlParameterExpression.Type.GetGenericTypeDefinition() == typeof(Span<>)) + { + var newElement = sqlParameterExpression.Type.GetSequenceType(); + param = new SqlParameterExpression( + sqlParameterExpression.Name, newElement.MakeArrayType(), sqlParameterExpression.IsNullable, + sqlParameterExpression.ShouldBeConstantized, sqlParameterExpression.TypeMapping); + } + + var primitiveresult = TranslatePrimitiveCollection(param, property: null, tableAlias); + var shaperExpression = primitiveresult?.ShaperExpression; + // No need to check ConvertChecked since this is convert node which we may have added during projection + if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand.Type.IsNullableType() + && unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type) + { + shaperExpression = unaryExpression.Operand; + } + if (primitiveresult?.QueryExpression is SelectExpression selectExpression + && shaperExpression is ProjectionBindingExpression projectionBindingExpression + && selectExpression.GetProjection(projectionBindingExpression) is SqlExpression projection) + { + // Translate to IN with a subquery. + // Note that because of null semantics, this may get transformed to an EXISTS subquery in SqlNullabilityProcessor. + var subquery = (SelectExpression)primitiveresult.QueryExpression; + if (subquery.Limit == null + && subquery.Offset == null) + { + subquery.ClearOrdering(); + } + + subquery.IsDistinct = false; + + subquery.ReplaceProjection(new List { projection }); + subquery.ApplyProjection(); + + var translation1 = _sqlExpressionFactory.In((SqlExpression)item, subquery); + subquery = new SelectExpression(translation1, _sqlAliasManager); + return translation1; + } + } + + if (enumerable is ColumnExpression columnExpression && columnExpression.Type.IsArray && !(columnExpression.Type.GetElementType() == typeof(byte))) + { + var tableAlias = _sqlAliasManager.GenerateTableAlias(columnExpression.Name.TrimStart('_')); + var primitiveresult = TranslatePrimitiveCollection(columnExpression, property: null, tableAlias); + var shaperExpression = primitiveresult?.ShaperExpression; + // No need to check ConvertChecked since this is convert node which we may have added during projection + if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand.Type.IsNullableType() + && unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type) + { + shaperExpression = unaryExpression.Operand; + } + if (primitiveresult?.QueryExpression is SelectExpression selectExpression + && shaperExpression is ProjectionBindingExpression projectionBindingExpression + && selectExpression.GetProjection(projectionBindingExpression) is SqlExpression projection) + { + // Translate to IN with a subquery. + // Note that because of null semantics, this may get transformed to an EXISTS subquery in SqlNullabilityProcessor. + var subquery = (SelectExpression)primitiveresult.QueryExpression; + if (subquery.Limit == null + && subquery.Offset == null) + { + subquery.ClearOrdering(); + } + + subquery.IsDistinct = false; + + subquery.ReplaceProjection(new List { projection }); + subquery.ApplyProjection(); + + var translation1 = _sqlExpressionFactory.In((SqlExpression)item, subquery); + subquery = new SelectExpression(translation1, _sqlAliasManager); + return translation1; + } + } + + if (enumerable is SqlExpression sqlEnumerable + && item is SqlExpression sqlItem) + { + scalarArguments = [sqlEnumerable, sqlItem]; + } + else + { + return QueryCompilationContext.NotTranslatedExpression; + } + + break; + } + case { Arguments: [var argument] } when method.IsContainsMethod(): { var enumerable = Visit(methodCallExpression.Object); @@ -973,6 +1150,50 @@ Expression TranslateAsSubquery(Expression expression) } } + private ShapedQueryExpression CreateShapedQueryExpressionForValuesExpression( + ValuesExpression valuesExpression, + string tableAlias, + Type elementType, + RelationalTypeMapping? inferredTypeMapping, + bool encounteredNull) + { + // Note: we leave the element type mapping null, to allow it to get inferred based on queryable operators composed on top. + var valueColumn = new ColumnExpression( + RelationalQueryableMethodTranslatingExpressionVisitor.ValuesValueColumnName, + tableAlias, + elementType.UnwrapNullableType(), + typeMapping: inferredTypeMapping, + nullable: encounteredNull); + var orderingColumn = new ColumnExpression( + RelationalQueryableMethodTranslatingExpressionVisitor.ValuesOrderingColumnName, + tableAlias, + typeof(int), + typeMapping: Dependencies.TypeMappingSource.FindMapping(typeof(int), Dependencies.Model), + nullable: false); + + var selectExpression = new SelectExpression( + [valuesExpression], + valueColumn, + identifier: [(orderingColumn, orderingColumn.TypeMapping!.Comparer)], + _sqlAliasManager); + + selectExpression.AppendOrdering(new OrderingExpression(orderingColumn, ascending: true)); + + Expression shaperExpression = new ProjectionBindingExpression( + selectExpression, new ProjectionMember(), encounteredNull ? elementType.MakeNullable() : elementType); + + if (elementType != shaperExpression.Type) + { + Check.DebugAssert( + elementType.MakeNullable() == shaperExpression.Type, + "expression.Type must be nullable of targetType"); + + shaperExpression = Expression.Convert(shaperExpression, elementType); + } + + return new ShapedQueryExpression(selectExpression, shaperExpression); + } + /// protected override Expression VisitNew(NewExpression newExpression) => TryEvaluateToConstant(newExpression, out var sqlConstantExpression) @@ -1141,6 +1362,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return sqlOperand!; } + if (unaryExpression.Type.IsGenericType && unaryExpression.Type.GetGenericArguments()[0] == sqlOperand!.Type.GetElementType()) + { + return sqlOperand!; + } + // Introduce explicit cast only if the target type is mapped else we need to client eval if (unaryExpression.Type == typeof(object) || Dependencies.TypeMappingSource.FindMapping(unaryExpression.Type, Dependencies.Model) != null) @@ -1604,6 +1830,26 @@ private static Expression TryRemoveImplicitConvert(Expression expression) return expression; } + /// + /// Translates a parameter or column collection of primitive values. Providers can override this to translate e.g. int[] columns or + /// parameters to a queryable table (OPENJSON on SQL Server, unnest on PostgreSQL...). The default implementation always returns + /// (no translation). + /// + /// The expression to try to translate as a primitive collection expression. + /// + /// If the primitive collection is a property, contains the for that property. Otherwise, the collection + /// represents a parameter, and this contains . + /// + /// + /// Provides an alias to be used for the table returned from translation, which will represent the collection. + /// + /// A if the translation was successful, otherwise . + protected virtual ShapedQueryExpression? TranslatePrimitiveCollection( + SqlExpression sqlExpression, + IProperty? property, + string tableAlias) + => null; + private static Expression RemoveObjectConvert(Expression expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression && unaryExpression.Type == typeof(object) @@ -1627,7 +1873,7 @@ private static bool TryEvaluateToConstant(Expression expression, [NotNullWhen(tr { sqlConstantExpression = new SqlConstantExpression( Expression.Lambda>(Expression.Convert(expression, typeof(object))) - .Compile(preferInterpretation: true) + .Compile(preferInterpretation: false) .Invoke(), expression.Type, typeMapping: null); diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index c0a7d72d43c..938603057b3 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -6,6 +6,7 @@ using System.Text; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal; +using Microsoft.EntityFrameworkCore.SqlServer.Internal; using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions; namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; @@ -267,6 +268,19 @@ bool TryTranslateStartsEndsWithContains( StartsEndsWithContains methodType, [NotNullWhen(true)] out SqlExpression? translation) { + if (pattern is UnaryExpression + { + NodeType:ExpressionType.Convert + } unary) + { + if (unary.Type == typeof(Span)) + { + pattern = unary.Operand; + } + } + + instance = RemoveConvert(instance); + pattern = RemoveConvert(pattern); if (Visit(instance) is not SqlExpression translatedInstance || Visit(pattern) is not SqlExpression translatedPattern) { @@ -648,4 +662,125 @@ private static bool TranslationFailed(Expression? original, Expression? translat private static string? GetProviderType(SqlExpression expression) => expression.TypeMapping?.StoreType; + + [return: NotNullIfNotNull(nameof(expression))] + private static Expression? RemoveConvert(Expression? expression) + => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression + ? RemoveConvert(unaryExpression.Operand) + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Arguments[0]) + : expression; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslatePrimitiveCollection( + SqlExpression sqlExpression, + IProperty? property, + string tableAlias) + { + if (_sqlServerSingletonOptions.EngineType == SqlServerEngineType.SqlServer + && _sqlServerSingletonOptions.SqlServerCompatibilityLevel < 130) + { + AddTranslationErrorDetails( + SqlServerStrings.CompatibilityLevelTooLowForScalarCollections(_sqlServerSingletonOptions.SqlServerCompatibilityLevel)); + + return null; + } + + if (_sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSql + && _sqlServerSingletonOptions.AzureSqlCompatibilityLevel < 130) + { + AddTranslationErrorDetails( + SqlServerStrings.CompatibilityLevelTooLowForScalarCollections(_sqlServerSingletonOptions.AzureSqlCompatibilityLevel)); + + return null; + } + + // Generate the OPENJSON function expression, and wrap it in a SelectExpression. + + // Note that where the elementTypeMapping is known (i.e. collection columns), we immediately generate OPENJSON with a WITH clause + // (i.e. with a columnInfo), which determines the type conversion to apply to the JSON elements coming out. + // For parameter collections, the element type mapping will only be inferred and applied later (see + // SqlServerInferredTypeMappingApplier below), at which point the we'll apply it to add the WITH clause. + var elementTypeMapping = (RelationalTypeMapping?)sqlExpression.TypeMapping?.ElementTypeMapping; + + var openJsonExpression = elementTypeMapping is null + ? new SqlServerOpenJsonExpression(tableAlias, sqlExpression) + : new SqlServerOpenJsonExpression( + tableAlias, sqlExpression, + columnInfos: new[] + { + new SqlServerOpenJsonExpression.ColumnInfo + { + Name = "value", + TypeMapping = elementTypeMapping, + Path = [] + } + }); + + var elementClrType = sqlExpression.Type.GetSequenceType(); + + // If this is a collection property, get the element's nullability out of metadata. Otherwise, this is a parameter property, in + // which case we only have the CLR type (note that we cannot produce different SQLs based on the nullability of an *element* in + // a parameter collection - our caching mechanism only supports varying by the nullability of the parameter itself (i.e. the + // collection). + var isElementNullable = property?.GetElementType()!.IsNullable; + + var keyColumnTypeMapping = _typeMappingSource.FindMapping("nvarchar(4000)")!; +#pragma warning disable EF1001 // Internal EF Core API usage. + var selectExpression = new SelectExpression( + [openJsonExpression], + new ColumnExpression( + "value", + tableAlias, + elementClrType.UnwrapNullableType(), + elementTypeMapping, + isElementNullable ?? elementClrType.IsNullableType()), + identifier: + [ + (new ColumnExpression("key", tableAlias, typeof(string), keyColumnTypeMapping, nullable: false), + keyColumnTypeMapping.Comparer) + ], + _queryCompilationContext.SqlAliasManager); +#pragma warning restore EF1001 // Internal EF Core API usage. + + // OPENJSON doesn't guarantee the ordering of the elements coming out; when using OPENJSON without WITH, a [key] column is returned + // with the JSON array's ordering, which we can ORDER BY; this option doesn't exist with OPENJSON with WITH, unfortunately. + // However, OPENJSON with WITH has better performance, and also applies JSON-specific conversions we cannot be done otherwise + // (e.g. OPENJSON with WITH does base64 decoding for VARBINARY). + // Here we generate OPENJSON with WITH, but also add an ordering by [key] - this is a temporary invalid representation. + // In SqlServerQueryTranslationPostprocessor, we'll post-process the expression; if the ORDER BY was stripped (e.g. because of + // IN, EXISTS or a set operation), we'll just leave the OPENJSON with WITH. If not, we'll convert the OPENJSON with WITH to an + // OPENJSON without WITH. + // Note that the OPENJSON 'key' column is an nvarchar - we convert it to an int before sorting. + selectExpression.AppendOrdering( + new OrderingExpression( + _sqlExpressionFactory.Convert( + selectExpression.CreateColumnExpression( + openJsonExpression, + "key", + typeof(string), + typeMapping: _typeMappingSource.FindMapping("nvarchar(4000)"), + columnNullable: false), + typeof(int), + _typeMappingSource.FindMapping(typeof(int))), + ascending: true)); + + var shaperExpression = (Expression)new ProjectionBindingExpression( + selectExpression, new ProjectionMember(), elementClrType.MakeNullable()); + if (shaperExpression.Type != elementClrType) + { + Check.DebugAssert( + elementClrType.MakeNullable() == shaperExpression.Type, + "expression.Type must be nullable of targetType"); + + shaperExpression = Expression.Convert(shaperExpression, elementClrType); + } + + return new ShapedQueryExpression(selectExpression, shaperExpression); + } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs index 86c07a0aa4d..73ac4638367 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using MemoryExtensions = System.MemoryExtensions; // ReSharper disable once CheckNamespace namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; @@ -59,19 +60,41 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac typeof(int)), _sqlExpressionFactory.Constant(0)); } + } - if (methodDefinition.Equals(EnumerableMethods.FirstWithoutPredicate) - && arguments[0].Type == typeof(byte[])) - { - return _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Function( - "SUBSTRING", - [arguments[0], _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - typeof(byte[])), - method.ReturnType); - } + if (method.IsGenericMethod + && method.GetGenericMethodDefinition() == MemoryExtensionsMethods.Contains + && (arguments[0].Type == typeof(Span) || arguments[0].Type == typeof(byte[]))) + { + var source = arguments[0]; + var sourceTypeMapping = source.TypeMapping; + + var value = arguments[1] is SqlConstantExpression constantValue + ? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping) + : _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping); + + return _sqlExpressionFactory.GreaterThan( + _sqlExpressionFactory.Function( + "CHARINDEX", + [value, source], + nullable: true, + argumentsPropagateNullability: [true, true], + typeof(int)), + _sqlExpressionFactory.Constant(0)); + } + + if (method.IsGenericMethod + && method.GetGenericMethodDefinition().Equals(EnumerableMethods.FirstWithoutPredicate) + && arguments[0].Type == typeof(byte[])) + { + return _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( + "SUBSTRING", + [arguments[0], _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], + nullable: true, + argumentsPropagateNullability: [true, true, true], + typeof(byte[])), + method.ReturnType); } return null; diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index 4eb3c2e723b..0624e0a9eea 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -201,7 +201,9 @@ public static bool IsLogicalNot(this UnaryExpression sqlUnaryExpression) private static Expression? RemoveConvert(Expression? expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression ? RemoveConvert(unaryExpression.Operand) - : expression; + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Object) + : expression; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore/Query/EvaluatableExpressionFilter.cs b/src/EFCore/Query/EvaluatableExpressionFilter.cs index 917df7f8514..97a378d2986 100644 --- a/src/EFCore/Query/EvaluatableExpressionFilter.cs +++ b/src/EFCore/Query/EvaluatableExpressionFilter.cs @@ -79,7 +79,8 @@ public virtual bool IsEvaluatableExpression(Expression expression, IModel model) || Equals(method, RandomNextNoArgs) || Equals(method, RandomNextOneArg) || Equals(method, RandomNextTwoArgs) - || method.DeclaringType == typeof(DbFunctionsExtensions)) + || method.DeclaringType == typeof(DbFunctionsExtensions) + || method.Name == "op_Implicit")//should this be not evaluatble? { return false; } diff --git a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs index 1486ce90137..97b4b5af67d 100644 --- a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs +++ b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs @@ -1854,7 +1854,19 @@ private static StateType CombineStateTypes(StateType stateType1, StateType state return result; } - var value = Evaluate(evaluatableRoot, out var parameterName, out var isContextAccessor); + object? value = null; + string? parameterName = null; + bool isContextAccessor = false; + if (evaluatableRoot is MethodCallExpression { Method.Name: "op_Implicit" } evaluatableRootMethod) + { + //can we get the arguments state so that we can do notEvaluatableAsRootHandler + value = Evaluate(evaluatableRootMethod.Arguments[0], out parameterName, out isContextAccessor); + } + else + { + value = Evaluate(evaluatableRoot, out parameterName, out isContextAccessor); + } + switch (value) { @@ -2043,11 +2055,13 @@ bool PreserveConvertNode(Expression expression) return Lambda(visited, _contextParameterReplacer.ContextParameterExpression); } - static Expression RemoveConvert(Expression expression) + static Expression? RemoveConvert(Expression? expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression ? RemoveConvert(unaryExpression.Operand) - : expression; - } + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Object) + : expression; +} switch (expression) { @@ -2078,6 +2092,10 @@ static Expression RemoveConvert(Expression expression) case MethodCallExpression methodCallExpression: parameterName = methodCallExpression.Method.Name; + if (parameterName == "op_Implicit") + { + return expression; + } break; case UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression diff --git a/src/EFCore/Query/LiftableConstantProcessor.cs b/src/EFCore/Query/LiftableConstantProcessor.cs index 5e384638cd8..2a76d78a94c 100644 --- a/src/EFCore/Query/LiftableConstantProcessor.cs +++ b/src/EFCore/Query/LiftableConstantProcessor.cs @@ -198,7 +198,7 @@ protected virtual ConstantExpression InlineConstant(LiftableConstantExpression l // Make sure there aren't any problematic un-lifted constants within the resolver expression. _unsupportedConstantChecker.Check(resolverExpression); - var resolver = resolverExpression.Compile(preferInterpretation: true); + var resolver = resolverExpression.Compile(preferInterpretation: false); var value = resolver(_materializerLiftableConstantContext); return Expression.Constant(value, liftableConstant.Type); diff --git a/src/EFCore/Query/QueryRootProcessor.cs b/src/EFCore/Query/QueryRootProcessor.cs index 34a645d35ca..ab2dc888424 100644 --- a/src/EFCore/Query/QueryRootProcessor.cs +++ b/src/EFCore/Query/QueryRootProcessor.cs @@ -114,7 +114,18 @@ when listInit.Type.TryGetElementType(typeof(IList<>)) is not null && listInit.Initializers.All(x => x.Arguments.Count == 1) && ShouldConvertToInlineQueryRoot(listInit): return new InlineQueryRootExpression(listInit.Initializers.Select(x => x.Arguments[0]).ToList(), elementClrType); - + case MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression: + { + if (elementClrType.GetGenericTypeDefinition() == typeof(Span<>)) + { + var baseType = elementClrType.GetGenericArguments()[0]; + var newArgument = VisitQueryRootCandidate(methodCallExpression.Arguments[0], baseType); + return newArgument == methodCallExpression.Arguments[0] + ? methodCallExpression + : Expression.Call(methodCallExpression.Method, newArgument); + } + goto default; + } default: return Visit(expression); } diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index fbc596366bf..69584cc6fb8 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -344,14 +344,11 @@ protected override Expression VisitExtension(Expression extensionExpression) : base.VisitExtension(extensionExpression); private static Expression? RemoveConvert(Expression? expression) - { - while (expression is { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked }) - { - expression = RemoveConvert(((UnaryExpression)expression).Operand); - } - - return expression; - } + => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression + ? RemoveConvert(unaryExpression.Operand) + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Object) + : expression; } private sealed class EntityMaterializerInjectingExpressionVisitor( diff --git a/src/Shared/ExpressionExtensions.cs b/src/Shared/ExpressionExtensions.cs index 0835ed7eb3f..aaddf4e98dd 100644 --- a/src/Shared/ExpressionExtensions.cs +++ b/src/Shared/ExpressionExtensions.cs @@ -40,10 +40,12 @@ public static LambdaExpression UnwrapLambdaFromQuote(this Expression expression) return expression; } - private static Expression RemoveConvert(Expression expression) + private static Expression? RemoveConvert(Expression? expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression ? RemoveConvert(unaryExpression.Operand) - : expression; + : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? + RemoveConvert(methodCallExpression.Object) + : expression; public static T GetConstantValue(this Expression expression) => RemoveConvert(expression) switch diff --git a/src/Shared/MemoryExtensionsMethods.cs b/src/Shared/MemoryExtensionsMethods.cs new file mode 100644 index 00000000000..275b9751abd --- /dev/null +++ b/src/Shared/MemoryExtensionsMethods.cs @@ -0,0 +1,430 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; + +namespace Microsoft.EntityFrameworkCore; + +internal static class MemoryExtensionsMethods +{ + //public static MethodInfo AggregateWithoutSeed { get; } + + //public static MethodInfo AggregateWithSeedWithoutSelector { get; } + + public static MethodInfo AggregateWithSeedSelector { get; } + + public static MethodInfo All { get; } + + public static MethodInfo AnyWithoutPredicate { get; } + + public static MethodInfo AnyWithPredicate { get; } + + //public static Append { get; } + + public static MethodInfo AsEnumerable { get; } + + public static MethodInfo Cast { get; } + + public static MethodInfo Concat { get; } + + public static MethodInfo Contains { get; } + + //public static MethodInfo ContainsWithComparer { get; } + + public static MethodInfo CountWithoutPredicate { get; } + + public static MethodInfo CountWithPredicate { get; } + + public static MethodInfo DefaultIfEmptyWithoutArgument { get; } + + public static MethodInfo DefaultIfEmptyWithArgument { get; } + + public static MethodInfo LongCountWithoutPredicate { get; } + + public static MethodInfo LongCountWithPredicate { get; } + + public static MethodInfo MaxWithoutSelector { get; } + + public static MethodInfo MaxWithSelector { get; } + + public static MethodInfo MinWithoutSelector { get; } + + public static MethodInfo MinWithSelector { get; } + + public static MethodInfo OfType { get; } + + public static MethodInfo OrderBy { get; } + + //public static MethodInfo OrderByWithComparer { get; } + + public static MethodInfo OrderByDescending { get; } + + //public static MethodInfo OrderByDescendingWithComparer { get; } + + //public static MethodInfo Prepend { get; } + + //public static MethodInfo Range { get; } + + //public static MethodInfo Repeat { get; } + + public static MethodInfo Reverse { get; } + + public static MethodInfo Select { get; } + + public static MethodInfo SelectWithOrdinal { get; } + + public static MethodInfo SelectManyWithoutCollectionSelector { get; } + + //public static MethodInfo SelectManyWithoutCollectionSelectorOrdinal { get; } + + public static MethodInfo SelectManyWithCollectionSelector { get; } + + //public static MethodInfo SelectManyWithCollectionSelectorOrdinal { get; } + + public static MethodInfo SequenceEqual { get; } + + //public static MethodInfo SequenceEqualWithComparer { get; } + + public static MethodInfo SingleWithoutPredicate { get; } + + public static MethodInfo SingleWithPredicate { get; } + + public static MethodInfo SingleOrDefaultWithoutPredicate { get; } + + public static MethodInfo SingleOrDefaultWithPredicate { get; } + + public static MethodInfo Skip { get; } + + public static MethodInfo SkipWhile { get; } + + //public static MethodInfo SkipWhileOrdinal { get; } + + public static MethodInfo Take { get; } + + public static MethodInfo TakeWhile { get; } + + //public static MethodInfo TakeWhileOrdinal { get; } + + public static MethodInfo ThenBy { get; } + + //public static MethodInfo ThenByWithComparer { get; } + + public static MethodInfo ThenByDescending { get; } + + //public static MethodInfo ThenByDescendingWithComparer { get; } + + public static MethodInfo ToArray { get; } + + //public static MethodInfo ToDictionaryWithKeySelector { get; } + //public static MethodInfo ToDictionaryWithKeySelectorAndComparer { get; } + //public static MethodInfo ToDictionaryWithKeyElementSelector { get; } + //public static MethodInfo ToDictionaryWithKeyElementSelectorAndComparer { get; } + + //public static MethodInfo ToHashSet { get; } + //public static MethodInfo ToHashSetWithComparer { get; } + + public static MethodInfo ToList { get; } + + //public static MethodInfo ToLookupWithKeySelector { get; } + //public static MethodInfo ToLookupWithKeySelectorAndComparer { get; } + //public static MethodInfo ToLookupWithKeyElementSelector { get; } + //public static MethodInfo ToLookupWithKeyElementSelectorAndComparer { get; } + + public static MethodInfo Union { get; } + + //public static MethodInfo UnionWithComparer { get; } + + public static MethodInfo Where { get; } + + //public static MethodInfo WhereOrdinal { get; } + + public static MethodInfo ZipWithSelector { get; } + + // private static Dictionary SumWithoutSelectorMethods { get; } + private static Dictionary SumWithSelectorMethods { get; } + + // private static Dictionary AverageWithoutSelectorMethods { get; } + private static Dictionary AverageWithSelectorMethods { get; } + private static Dictionary MaxWithoutSelectorMethods { get; } + private static Dictionary MaxWithSelectorMethods { get; } + private static Dictionary MinWithoutSelectorMethods { get; } + private static Dictionary MinWithSelectorMethods { get; } + + // Not currently used + // + // public static bool IsSumWithoutSelector(MethodInfo methodInfo) + // => SumWithoutSelectorMethods.Values.Contains(methodInfo); + // + // public static bool IsSumWithSelector(MethodInfo methodInfo) + // => methodInfo.IsGenericMethod + // && SumWithSelectorMethods.Values.Contains(methodInfo.GetGenericMethodDefinition()); + // + // public static bool IsAverageWithoutSelector(MethodInfo methodInfo) + // => AverageWithoutSelectorMethods.Values.Contains(methodInfo); + // + // public static bool IsAverageWithSelector(MethodInfo methodInfo) + // => methodInfo.IsGenericMethod + // && AverageWithSelectorMethods.Values.Contains(methodInfo.GetGenericMethodDefinition()); + // + // public static MethodInfo GetSumWithoutSelector(Type type) + // => SumWithoutSelectorMethods[type]; + + public static MethodInfo GetSumWithSelector(Type type) + => SumWithSelectorMethods[type]; + + // public static MethodInfo GetAverageWithoutSelector(Type type) + // => AverageWithoutSelectorMethods[type]; + + public static MethodInfo GetAverageWithSelector(Type type) + => AverageWithSelectorMethods[type]; + + public static MethodInfo GetMaxWithoutSelector(Type type) + => MaxWithoutSelectorMethods.GetValueOrDefault(type, MaxWithoutSelector); + + public static MethodInfo GetMaxWithSelector(Type type) + => MaxWithSelectorMethods.GetValueOrDefault(type, MaxWithSelector); + + public static MethodInfo GetMinWithoutSelector(Type type) + => MinWithoutSelectorMethods.GetValueOrDefault(type, MinWithoutSelector); + + public static MethodInfo GetMinWithSelector(Type type) + => MinWithSelectorMethods.GetValueOrDefault(type, MinWithSelector); + + static MemoryExtensionsMethods() + { + var queryableMethodGroups = typeof(Enumerable) + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) + .GroupBy(mi => mi.Name) + .ToDictionary(e => e.Key, l => l.ToList()); + + AggregateWithSeedSelector = GetMethod( + nameof(Enumerable.Aggregate), 3, + types => + [ + typeof(IEnumerable<>).MakeGenericType(types[0]), + types[1], + typeof(Func<,,>).MakeGenericType(types[1], types[0], types[1]), + typeof(Func<,>).MakeGenericType(types[1], types[2]) + ]); + + All = GetMethod( + nameof(Enumerable.All), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + AnyWithoutPredicate = GetMethod( + nameof(Enumerable.Any), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + AnyWithPredicate = GetMethod( + nameof(Enumerable.Any), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + AsEnumerable = GetMethod( + nameof(Enumerable.AsEnumerable), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + Cast = GetMethod(nameof(Enumerable.Cast), 1, _ => [typeof(IEnumerable)]); + + Concat = GetMethod( + nameof(Enumerable.Concat), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]); + + Contains = typeof(MemoryExtensions).GetMethods(BindingFlags.Static | BindingFlags.Public) + .First(m => m is { Name: nameof(MemoryExtensions.Contains), IsGenericMethodDefinition: true }); + + CountWithoutPredicate = GetMethod( + nameof(Enumerable.Count), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + CountWithPredicate = GetMethod( + nameof(Enumerable.Count), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + DefaultIfEmptyWithoutArgument = GetMethod( + nameof(Enumerable.DefaultIfEmpty), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + DefaultIfEmptyWithArgument = GetMethod( + nameof(Enumerable.DefaultIfEmpty), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), types[0]]); + + LongCountWithoutPredicate = GetMethod( + nameof(Enumerable.LongCount), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + LongCountWithPredicate = GetMethod( + nameof(Enumerable.LongCount), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + MaxWithoutSelector = GetMethod(nameof(Enumerable.Max), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + MaxWithSelector = GetMethod( + nameof(Enumerable.Max), 2, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + MinWithoutSelector = GetMethod(nameof(Enumerable.Min), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + MinWithSelector = GetMethod( + nameof(Enumerable.Min), 2, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + OfType = GetMethod(nameof(Enumerable.OfType), 1, _ => [typeof(IEnumerable)]); + + OrderBy = GetMethod( + nameof(Enumerable.OrderBy), 2, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + OrderByDescending = GetMethod( + nameof(Enumerable.OrderByDescending), 2, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + Reverse = GetMethod(nameof(Enumerable.Reverse), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + Select = GetMethod( + nameof(Enumerable.Select), 2, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + SelectWithOrdinal = GetMethod( + nameof(Enumerable.Select), 2, + types => + [ + typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,,>).MakeGenericType(types[0], typeof(int), types[1]) + ]); + + SelectManyWithoutCollectionSelector = GetMethod( + nameof(Enumerable.SelectMany), 2, + types => + [ + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType( + types[0], typeof(IEnumerable<>).MakeGenericType(types[1])) + ]); + + SelectManyWithCollectionSelector = GetMethod( + nameof(Enumerable.SelectMany), 3, + types => + [ + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType( + types[0], typeof(IEnumerable<>).MakeGenericType(types[1])), + typeof(Func<,,>).MakeGenericType(types[0], types[1], types[2]) + ]); + + SequenceEqual = typeof(MemoryExtensions).GetMethods(BindingFlags.Static | BindingFlags.Public) + .First(m => m is { Name: nameof(MemoryExtensions.SequenceEqual), IsGenericMethodDefinition: true }); + + SingleWithoutPredicate = GetMethod( + nameof(Enumerable.Single), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + SingleWithPredicate = GetMethod( + nameof(Enumerable.Single), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + SingleOrDefaultWithoutPredicate = GetMethod( + nameof(Enumerable.SingleOrDefault), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + SingleOrDefaultWithPredicate = GetMethod( + nameof(Enumerable.SingleOrDefault), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + Skip = GetMethod( + nameof(Enumerable.Skip), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(int)]); + + SkipWhile = GetMethod( + nameof(Enumerable.SkipWhile), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + ToArray = GetMethod(nameof(Enumerable.ToArray), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + ToList = GetMethod(nameof(Enumerable.ToList), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); + + Take = GetMethod( + nameof(Enumerable.Take), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(int)]); + + TakeWhile = GetMethod( + nameof(Enumerable.TakeWhile), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + ThenBy = GetMethod( + nameof(Enumerable.ThenBy), 2, + types => [typeof(IOrderedEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + ThenByDescending = GetMethod( + nameof(Enumerable.ThenByDescending), 2, + types => [typeof(IOrderedEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1])]); + + Union = GetMethod( + nameof(Enumerable.Union), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]); + + Where = GetMethod( + nameof(Enumerable.Where), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool))]); + + ZipWithSelector = GetMethod( + nameof(Enumerable.Zip), 3, + types => + [ + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(IEnumerable<>).MakeGenericType(types[1]), + typeof(Func<,,>).MakeGenericType(types[0], types[1], types[2]) + ]); + + var numericTypes = new[] + { + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?) + }; + + // AverageWithoutSelectorMethods = new Dictionary(); + AverageWithSelectorMethods = new Dictionary(); + MaxWithoutSelectorMethods = new Dictionary(); + MaxWithSelectorMethods = new Dictionary(); + MinWithoutSelectorMethods = new Dictionary(); + MinWithSelectorMethods = new Dictionary(); + // SumWithoutSelectorMethods = new Dictionary(); + SumWithSelectorMethods = new Dictionary(); + + foreach (var type in numericTypes) + { + // AverageWithoutSelectorMethods[type] = GetMethod( + // nameof(Enumerable.Average), 0, types => new[] { typeof(IEnumerable<>).MakeGenericType(type) }); + AverageWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Average), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type)]); + MaxWithoutSelectorMethods[type] = GetMethod( + nameof(Enumerable.Max), 0, _ => [typeof(IEnumerable<>).MakeGenericType(type)]); + MaxWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Max), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type)]); + MinWithoutSelectorMethods[type] = GetMethod( + nameof(Enumerable.Min), 0, _ => [typeof(IEnumerable<>).MakeGenericType(type)]); + MinWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Min), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type)]); + // SumWithoutSelectorMethods[type] = GetMethod( + // nameof(Enumerable.Sum), 0, types => new[] { typeof(IEnumerable<>).MakeGenericType(type) }); + SumWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Sum), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type)]); + } + + MethodInfo GetMethod(string name, int genericParameterCount, Func parameterGenerator) + => queryableMethodGroups[name].Single( + mi => ((genericParameterCount == 0 && !mi.IsGenericMethod) + || (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount)) + && mi.GetParameters().Select(e => e.ParameterType).SequenceEqual( + parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : []))); + } +} diff --git a/src/Shared/SharedTypeExtensions.cs b/src/Shared/SharedTypeExtensions.cs index 803525b98c9..dacbe4b35ba 100644 --- a/src/Shared/SharedTypeExtensions.cs +++ b/src/Shared/SharedTypeExtensions.cs @@ -154,7 +154,8 @@ public static Type GetSequenceType([DynamicallyAccessedMembers(DynamicallyAccess public static Type? TryGetSequenceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) => type.TryGetElementType(typeof(IEnumerable<>)) - ?? type.TryGetElementType(typeof(IAsyncEnumerable<>)); + ?? type.TryGetElementType(typeof(IAsyncEnumerable<>)) + ?? type.TryGetElementType(typeof(Span<>)); public static Type? TryGetElementType( [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type, diff --git a/test/EFCore.Design.Tests/Migrations/Design/CSharpMigrationOperationGeneratorTest.cs b/test/EFCore.Design.Tests/Migrations/Design/CSharpMigrationOperationGeneratorTest.cs index ba445da389c..8faaa141f04 100644 --- a/test/EFCore.Design.Tests/Migrations/Design/CSharpMigrationOperationGeneratorTest.cs +++ b/test/EFCore.Design.Tests/Migrations/Design/CSharpMigrationOperationGeneratorTest.cs @@ -2453,7 +2453,7 @@ public void InsertDataOperation_required_empty_array() Assert.Single(o.Columns); Assert.Equal(1, o.Values.GetLength(0)); Assert.Equal(1, o.Values.GetLength(1)); - Assert.Equal([], (string[])o.Values[0, 0]); + Assert.Equal([], ((string[])o.Values[0, 0]).AsSpan()); }); [ConditionalFact] @@ -2478,7 +2478,7 @@ public void InsertDataOperation_required_empty_array_composite() Assert.Equal(1, o.Values.GetLength(0)); Assert.Equal(3, o.Values.GetLength(1)); Assert.Null(o.Values[0, 1]); - Assert.Equal([], (string[])o.Values[0, 2]); + Assert.Equal([], ((string[])o.Values[0, 2]).AsSpan()); }); [ConditionalFact] diff --git a/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs index a36bf7e8667..c231704b759 100644 --- a/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs @@ -441,10 +441,10 @@ public virtual async Task Parameter_collection_of_nullable_strings_Contains_stri await AssertQuery( async, - ss => ss.Set().Where(c => strings.Contains(c.String))); + ss => ss.Set().Where(c => strings.AsEnumerable().Contains(c.String))); await AssertQuery( async, - ss => ss.Set().Where(c => !strings.Contains(c.String))); + ss => ss.Set().Where(c => !strings.AsEnumerable().Contains(c.String))); } [ConditionalTheory] @@ -458,7 +458,7 @@ await AssertQuery( ss => ss.Set().Where(c => strings.Contains(c.NullableString))); await AssertQuery( async, - ss => ss.Set().Where(c => !strings.Contains(c.NullableString))); + ss => ss.Set().Where(c => !strings.AsEnumerable().Contains(c.NullableString))); } // See more nullability-related tests in NullSemanticsQueryTestBase diff --git a/test/EFCore.Tests/Metadata/Internal/EntityTypeTest.BaseType.cs b/test/EFCore.Tests/Metadata/Internal/EntityTypeTest.BaseType.cs index 75dd606cd19..a27afe266c3 100644 --- a/test/EFCore.Tests/Metadata/Internal/EntityTypeTest.BaseType.cs +++ b/test/EFCore.Tests/Metadata/Internal/EntityTypeTest.BaseType.cs @@ -677,7 +677,7 @@ public void Navigations_on_base_type_should_be_inherited() var specialCustomerType = model.AddEntityType(typeof(SpecialCustomer)); Assert.Equal(new[] { "Orders" }, customerType.GetNavigations().Select(p => p.Name).ToArray()); - Assert.Equal([], specialCustomerType.GetNavigations().Select(p => p.Name).ToArray()); + Assert.Equal([], specialCustomerType.GetNavigations().Select(p => p.Name)); specialCustomerType.BaseType = customerType; diff --git a/test/EFCore.Tests/Storage/ValueConversion/BytesToStringConverterTest.cs b/test/EFCore.Tests/Storage/ValueConversion/BytesToStringConverterTest.cs index a29063dc5ab..308e5a907a0 100644 --- a/test/EFCore.Tests/Storage/ValueConversion/BytesToStringConverterTest.cs +++ b/test/EFCore.Tests/Storage/ValueConversion/BytesToStringConverterTest.cs @@ -23,7 +23,7 @@ public void Can_convert_bytes_to_strings() var converter = _bytesToStringConverter.ConvertFromProviderExpression.Compile(); Assert.Equal(new byte[] { 83, 112, 196, 177, 110, 204, 136, 97, 108, 32, 84, 97, 112 }, converter("U3DEsW7MiGFsIFRhcA==")); - Assert.Equal([], converter("")); + Assert.Equal([], converter("").AsEnumerable()); } [ConditionalFact] diff --git a/test/EFCore.Tests/Storage/ValueConversion/StringToBytesConverterTest.cs b/test/EFCore.Tests/Storage/ValueConversion/StringToBytesConverterTest.cs index f2aa07d914c..61ada487c83 100644 --- a/test/EFCore.Tests/Storage/ValueConversion/StringToBytesConverterTest.cs +++ b/test/EFCore.Tests/Storage/ValueConversion/StringToBytesConverterTest.cs @@ -13,7 +13,7 @@ public void Can_convert_strings_to_UTF8() var converter = _stringToUtf8Converter.ConvertToProviderExpression.Compile(); Assert.Equal(new byte[] { 83, 112, 196, 177, 110, 204, 136, 97, 108, 32, 84, 97, 112 }, converter("Spın̈al Tap")); - Assert.Equal([], converter("")); + Assert.Equal([], converter("").AsEnumerable()); } [ConditionalFact] From d4d2a638376090fc9e3aeb362340270f947978b5 Mon Sep 17 00:00:00 2001 From: Christopher Jolly Date: Fri, 13 Dec 2024 23:31:08 +0800 Subject: [PATCH 2/3] Some cleanup --- .../CosmosSqlTranslatingExpressionVisitor.cs | 2 +- .../Internal/PrecompiledQueryCodeGenerator.cs | 4 +- ...yExpressionTranslatingExpressionVisitor.cs | 2 +- ...elationalQueryFilterRewritingConvention.cs | 2 +- .../ByteArraySequenceEqualTranslator.cs | 2 +- .../Translators/ContainsTranslator.cs | 2 +- .../RelationalLiftableConstantProcessor.cs | 2 +- ...lationalSqlTranslatingExpressionVisitor.cs | 86 +++++++++---------- .../SqlServerByteArrayMethodTranslator.cs | 2 +- src/EFCore/Query/LiftableConstantProcessor.cs | 2 +- src/Shared/MemoryExtensionsMethods.cs | 8 +- .../Query/GearsOfWarQuerySqlServerTest.cs | 12 +-- .../PrimitiveCollectionsQuerySqlServerTest.cs | 26 +++--- .../Query/TPCGearsOfWarQuerySqlServerTest.cs | 12 +-- .../Query/TPTGearsOfWarQuerySqlServerTest.cs | 12 +-- .../TemporalGearsOfWarQuerySqlServerTest.cs | 12 +-- 16 files changed, 91 insertions(+), 97 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 9663e205dc2..02329b5b826 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -1210,7 +1210,7 @@ private static bool TryEvaluateToConstant(Expression expression, [NotNullWhen(tr { sqlConstantExpression = new SqlConstantExpression( Expression.Lambda>(Expression.Convert(expression, typeof(object))) - .Compile(preferInterpretation: false) + .Compile(preferInterpretation: true) .Invoke(), expression.Type, null); diff --git a/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs index c8309dcdc27..de2957a65d7 100644 --- a/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs +++ b/src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs @@ -179,7 +179,7 @@ public virtual IReadOnlyList GeneratePrecompiledQueries( }; penultimateOperator = Expression.Lambda>(penultimateOperator) - .Compile(preferInterpretation: false)().Expression; + .Compile(preferInterpretation: true)().Expression; // Pass the query through EF's query pipeline; this returns the query's executor function, which can produce an enumerable // that invokes the query. @@ -817,7 +817,7 @@ void GenerateCapturedVariableExtractors( code .Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",") .AppendLine($"Expression.Lambda>(Expression.Convert({variableName}, typeof(object)))") - .AppendLine(".Compile(preferInterpretation: false)") + .AppendLine(".Compile(preferInterpretation: true)") .AppendLine(".Invoke());"); } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index f77d9a63daa..1e09a967659 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -1493,7 +1493,7 @@ when CanEvaluate(memberInitExpression): private static ConstantExpression GetValue(Expression expression) => Expression.Constant( Expression.Lambda>(Expression.Convert(expression, typeof(object))) - .Compile(preferInterpretation: false) + .Compile(preferInterpretation: true) .Invoke(), expression.Type); diff --git a/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs b/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs index d7014bd0573..8453e441377 100644 --- a/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs +++ b/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterRewritingConvention.cs @@ -85,7 +85,7 @@ or nameof(RelationalQueryableExtensions.FromSql)) { var formattableString = Expression.Lambda>( Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))) - .Compile(preferInterpretation: false) + .Compile(preferInterpretation: true) .Invoke(); sql = formattableString.Format; diff --git a/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs b/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs index 3492bb5266b..2b4a6baee57 100644 --- a/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/Translators/ByteArraySequenceEqualTranslator.cs @@ -45,7 +45,7 @@ public ByteArraySequenceEqualTranslator(ISqlExpressionFactory sqlExpressionFacto } if (method.IsGenericMethod - && method.GetGenericMethodDefinition().Equals(MemoryExtensionsMethods.SequenceEqual) + && MemoryExtensionsMethods.SequenceEqual.Contains(method.GetGenericMethodDefinition()) && arguments[0].Type == typeof(byte[])) { return _sqlExpressionFactory.Equal(arguments[0], arguments[1]); diff --git a/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs b/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs index 44d666a6586..fbd580e25d1 100644 --- a/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/Translators/ContainsTranslator.cs @@ -45,7 +45,7 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) // Identify static Enumerable.Contains and instance List.Contains if (method.IsGenericMethod - && (method.GetGenericMethodDefinition() == EnumerableMethods.Contains || method.GetGenericMethodDefinition() == MemoryExtensionsMethods.Contains) + && (method.GetGenericMethodDefinition() == EnumerableMethods.Contains || MemoryExtensionsMethods.Contains.Contains(method.GetGenericMethodDefinition())) && ValidateValues(arguments[0])) { (itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[1]), arguments[0]); diff --git a/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs b/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs index 914b839b008..2d0b9fa8c0c 100644 --- a/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs +++ b/src/EFCore.Relational/Query/RelationalLiftableConstantProcessor.cs @@ -36,7 +36,7 @@ protected override ConstantExpression InlineConstant(LiftableConstantExpression if (liftableConstant.ResolverExpression is Expression> resolverExpression) { - var resolver = resolverExpression.Compile(preferInterpretation: false); + var resolver = resolverExpression.Compile(preferInterpretation: true); var value = resolver(_relationalMaterializerLiftableConstantContext); return Expression.Constant(value, liftableConstant.Type); } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 78bdb292ad7..315c33e1daa 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -855,60 +855,54 @@ JsonScalarExpression jsonScalar .ParameterizedCollectionTranslationMode; var tableAlias = _sqlAliasManager.GenerateTableAlias(sqlParameterExpression.Name.TrimStart('_')); - if (sqlParameterExpression.ShouldBeConstantized - || (primitiveCollectionsBehavior == ParameterizedCollectionTranslationMode.Constantize)) + var elementType = sqlParameterExpression.Type.GetSequenceType(); + var elementTypeMapping = Dependencies.TypeMappingSource.FindMapping(elementType); + if (sqlParameterExpression.ShouldBeConstantized) { - var valuesExpression = new ValuesExpression( - tableAlias, - sqlParameterExpression, - [RelationalQueryableMethodTranslatingExpressionVisitor.ValuesOrderingColumnName, RelationalQueryableMethodTranslatingExpressionVisitor.ValuesValueColumnName]); - return CreateShapedQueryExpressionForValuesExpression( - valuesExpression, - tableAlias, - sqlParameterExpression.TypeMapping!.ElementTypeMapping!.GetType(), - sqlParameterExpression.TypeMapping, - sqlParameterExpression.IsNullable); + return _sqlExpressionFactory.In((SqlExpression)item, sqlParameterExpression); } - - var param = sqlParameterExpression; - if (sqlParameterExpression.Type.IsGenericType && sqlParameterExpression.Type.GetGenericTypeDefinition() == typeof(Span<>)) + else { - var newElement = sqlParameterExpression.Type.GetSequenceType(); - param = new SqlParameterExpression( - sqlParameterExpression.Name, newElement.MakeArrayType(), sqlParameterExpression.IsNullable, - sqlParameterExpression.ShouldBeConstantized, sqlParameterExpression.TypeMapping); - } + var param = sqlParameterExpression; + if (sqlParameterExpression.Type.IsGenericType && sqlParameterExpression.Type.GetGenericTypeDefinition() == typeof(Span<>)) + { + var newElement = sqlParameterExpression.Type.GetSequenceType(); + param = new SqlParameterExpression( + sqlParameterExpression.Name, newElement.MakeArrayType(), sqlParameterExpression.IsNullable, + sqlParameterExpression.ShouldBeConstantized, sqlParameterExpression.TypeMapping); + } - var primitiveresult = TranslatePrimitiveCollection(param, property: null, tableAlias); - var shaperExpression = primitiveresult?.ShaperExpression; - // No need to check ConvertChecked since this is convert node which we may have added during projection - if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression - && unaryExpression.Operand.Type.IsNullableType() - && unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type) - { - shaperExpression = unaryExpression.Operand; - } - if (primitiveresult?.QueryExpression is SelectExpression selectExpression - && shaperExpression is ProjectionBindingExpression projectionBindingExpression - && selectExpression.GetProjection(projectionBindingExpression) is SqlExpression projection) - { - // Translate to IN with a subquery. - // Note that because of null semantics, this may get transformed to an EXISTS subquery in SqlNullabilityProcessor. - var subquery = (SelectExpression)primitiveresult.QueryExpression; - if (subquery.Limit == null - && subquery.Offset == null) + var primitiveresult = TranslatePrimitiveCollection(param, property: null, tableAlias); + var shaperExpression = primitiveresult?.ShaperExpression; + // No need to check ConvertChecked since this is convert node which we may have added during projection + if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand.Type.IsNullableType() + && unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type) { - subquery.ClearOrdering(); + shaperExpression = unaryExpression.Operand; } + if (primitiveresult?.QueryExpression is SelectExpression selectExpression + && shaperExpression is ProjectionBindingExpression projectionBindingExpression + && selectExpression.GetProjection(projectionBindingExpression) is SqlExpression projection) + { + // Translate to IN with a subquery. + // Note that because of null semantics, this may get transformed to an EXISTS subquery in SqlNullabilityProcessor. + var subquery = (SelectExpression)primitiveresult.QueryExpression; + if (subquery.Limit == null + && subquery.Offset == null) + { + subquery.ClearOrdering(); + } - subquery.IsDistinct = false; + subquery.IsDistinct = false; - subquery.ReplaceProjection(new List { projection }); - subquery.ApplyProjection(); + subquery.ReplaceProjection(new List { projection }); + subquery.ApplyProjection(); - var translation1 = _sqlExpressionFactory.In((SqlExpression)item, subquery); - subquery = new SelectExpression(translation1, _sqlAliasManager); - return translation1; + var translation1 = _sqlExpressionFactory.In((SqlExpression)item, subquery); + subquery = new SelectExpression(translation1, _sqlAliasManager); + return translation1; + } } } @@ -1873,7 +1867,7 @@ private static bool TryEvaluateToConstant(Expression expression, [NotNullWhen(tr { sqlConstantExpression = new SqlConstantExpression( Expression.Lambda>(Expression.Convert(expression, typeof(object))) - .Compile(preferInterpretation: false) + .Compile(preferInterpretation: true) .Invoke(), expression.Type, typeMapping: null); diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs index 73ac4638367..56ba1b820af 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs @@ -63,7 +63,7 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac } if (method.IsGenericMethod - && method.GetGenericMethodDefinition() == MemoryExtensionsMethods.Contains + && MemoryExtensionsMethods.Contains.Contains(method.GetGenericMethodDefinition()) && (arguments[0].Type == typeof(Span) || arguments[0].Type == typeof(byte[]))) { var source = arguments[0]; diff --git a/src/EFCore/Query/LiftableConstantProcessor.cs b/src/EFCore/Query/LiftableConstantProcessor.cs index 2a76d78a94c..5e384638cd8 100644 --- a/src/EFCore/Query/LiftableConstantProcessor.cs +++ b/src/EFCore/Query/LiftableConstantProcessor.cs @@ -198,7 +198,7 @@ protected virtual ConstantExpression InlineConstant(LiftableConstantExpression l // Make sure there aren't any problematic un-lifted constants within the resolver expression. _unsupportedConstantChecker.Check(resolverExpression); - var resolver = resolverExpression.Compile(preferInterpretation: false); + var resolver = resolverExpression.Compile(preferInterpretation: true); var value = resolver(_materializerLiftableConstantContext); return Expression.Constant(value, liftableConstant.Type); diff --git a/src/Shared/MemoryExtensionsMethods.cs b/src/Shared/MemoryExtensionsMethods.cs index 275b9751abd..a58bf799ca8 100644 --- a/src/Shared/MemoryExtensionsMethods.cs +++ b/src/Shared/MemoryExtensionsMethods.cs @@ -27,7 +27,7 @@ internal static class MemoryExtensionsMethods public static MethodInfo Concat { get; } - public static MethodInfo Contains { get; } + public static List Contains { get; } //public static MethodInfo ContainsWithComparer { get; } @@ -81,7 +81,7 @@ internal static class MemoryExtensionsMethods //public static MethodInfo SelectManyWithCollectionSelectorOrdinal { get; } - public static MethodInfo SequenceEqual { get; } + public static List SequenceEqual { get; } //public static MethodInfo SequenceEqualWithComparer { get; } @@ -230,7 +230,7 @@ static MemoryExtensionsMethods() types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]); Contains = typeof(MemoryExtensions).GetMethods(BindingFlags.Static | BindingFlags.Public) - .First(m => m is { Name: nameof(MemoryExtensions.Contains), IsGenericMethodDefinition: true }); + .Where(m => m is { Name: nameof(MemoryExtensions.Contains), IsGenericMethodDefinition: true }).ToList(); CountWithoutPredicate = GetMethod( nameof(Enumerable.Count), 1, @@ -311,7 +311,7 @@ static MemoryExtensionsMethods() ]); SequenceEqual = typeof(MemoryExtensions).GetMethods(BindingFlags.Static | BindingFlags.Public) - .First(m => m is { Name: nameof(MemoryExtensions.SequenceEqual), IsGenericMethodDefinition: true }); + .Where(m => m is { Name: nameof(MemoryExtensions.SequenceEqual), IsGenericMethodDefinition: true }).ToList(); SingleWithoutPredicate = GetMethod( nameof(Enumerable.Single), 1, types => [typeof(IEnumerable<>).MakeGenericType(types[0])]); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs index 5de2aa4b63a..e825cb050e8 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs @@ -10275,12 +10275,12 @@ public override async Task Nav_expansion_with_member_pushdown_inside_Contains_ar SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank] FROM [Gears] AS [g] WHERE ( - SELECT TOP(1) [w0].[Name] - FROM [Weapons] AS [w0] - WHERE [g].[FullName] = [w0].[OwnerFullName] - ORDER BY [w0].[Id]) IN ( - SELECT [w].[value] - FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w] + SELECT TOP(1) [w].[Name] + FROM [Weapons] AS [w] + WHERE [g].[FullName] = [w].[OwnerFullName] + ORDER BY [w].[Id]) IN ( + SELECT [w0].[value] + FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w0] ) """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs index 3575a954aa7..eb172433f4e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs @@ -694,24 +694,24 @@ public override async Task Parameter_collection_of_nullable_strings_Contains_str AssertSql( """ -@__strings_0='["10",null]' (Size = 4000) +@__AsEnumerable_0='["10",null]' (Size = 4000) SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] WHERE [p].[String] IN ( - SELECT [s].[value] - FROM OPENJSON(@__strings_0) WITH ([value] nvarchar(max) '$') AS [s] + SELECT [a].[value] + FROM OPENJSON(@__AsEnumerable_0) WITH ([value] nvarchar(max) '$') AS [a] ) """, // """ -@__strings_0_without_nulls='["10"]' (Size = 4000) +@__AsEnumerable_0_without_nulls='["10"]' (Size = 4000) SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] WHERE [p].[String] NOT IN ( - SELECT [s].[value] - FROM OPENJSON(@__strings_0_without_nulls) AS [s] + SELECT [a].[value] + FROM OPENJSON(@__AsEnumerable_0_without_nulls) AS [a] ) """); } @@ -733,13 +733,13 @@ FROM OPENJSON(@__strings_0_without_nulls) AS [s] """, // """ -@__strings_0_without_nulls='["999"]' (Size = 4000) +@__AsEnumerable_0_without_nulls='["999"]' (Size = 4000) SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] WHERE [p].[NullableString] NOT IN ( - SELECT [s].[value] - FROM OPENJSON(@__strings_0_without_nulls) AS [s] + SELECT [a].[value] + FROM OPENJSON(@__AsEnumerable_0_without_nulls) AS [a] ) AND [p].[NullableString] IS NOT NULL """); } @@ -2086,20 +2086,20 @@ public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapp AssertSql( """ -@__ints_0='[1,2,3]' (Size = 4000) -@__strings_1='["one","two","three"]' (Size = 4000) +@__ints_1='[1,2,3]' (Size = 4000) +@__strings_0='["one","two","three"]' (Size = 4000) SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] WHERE CASE WHEN [p].[Int] IN ( SELECT [i].[value] - FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i] + FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i] ) THEN N'one' ELSE N'two' END IN ( SELECT [s].[value] - FROM OPENJSON(@__strings_1) WITH ([value] nvarchar(max) '$') AS [s] + FROM OPENJSON(@__strings_0) WITH ([value] nvarchar(max) '$') AS [s] ) """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TPCGearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TPCGearsOfWarQuerySqlServerTest.cs index 528e501d582..e62b9cd8b76 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TPCGearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TPCGearsOfWarQuerySqlServerTest.cs @@ -13533,12 +13533,12 @@ UNION ALL FROM [Officers] AS [o] ) AS [u] WHERE ( - SELECT TOP(1) [w0].[Name] - FROM [Weapons] AS [w0] - WHERE [u].[FullName] = [w0].[OwnerFullName] - ORDER BY [w0].[Id]) IN ( - SELECT [w].[value] - FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w] + SELECT TOP(1) [w].[Name] + FROM [Weapons] AS [w] + WHERE [u].[FullName] = [w].[OwnerFullName] + ORDER BY [w].[Id]) IN ( + SELECT [w0].[value] + FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w0] ) """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TPTGearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TPTGearsOfWarQuerySqlServerTest.cs index 696079c3a30..def1f9d2f9b 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TPTGearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TPTGearsOfWarQuerySqlServerTest.cs @@ -11558,12 +11558,12 @@ END AS [Discriminator] FROM [Gears] AS [g] LEFT JOIN [Officers] AS [o] ON [g].[Nickname] = [o].[Nickname] AND [g].[SquadId] = [o].[SquadId] WHERE ( - SELECT TOP(1) [w0].[Name] - FROM [Weapons] AS [w0] - WHERE [g].[FullName] = [w0].[OwnerFullName] - ORDER BY [w0].[Id]) IN ( - SELECT [w].[value] - FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w] + SELECT TOP(1) [w].[Name] + FROM [Weapons] AS [w] + WHERE [g].[FullName] = [w].[OwnerFullName] + ORDER BY [w].[Id]) IN ( + SELECT [w0].[value] + FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w0] ) """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TemporalGearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TemporalGearsOfWarQuerySqlServerTest.cs index 10082b2c7e2..6b75444c05b 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TemporalGearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TemporalGearsOfWarQuerySqlServerTest.cs @@ -10180,12 +10180,12 @@ public override async Task Nav_expansion_with_member_pushdown_inside_Contains_ar SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[PeriodEnd], [g].[PeriodStart], [g].[Rank] FROM [Gears] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [g] WHERE ( - SELECT TOP(1) [w0].[Name] - FROM [Weapons] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [w0] - WHERE [g].[FullName] = [w0].[OwnerFullName] - ORDER BY [w0].[Id]) IN ( - SELECT [w].[value] - FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w] + SELECT TOP(1) [w].[Name] + FROM [Weapons] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [w] + WHERE [g].[FullName] = [w].[OwnerFullName] + ORDER BY [w].[Id]) IN ( + SELECT [w0].[value] + FROM OPENJSON(@__weapons_0) WITH ([value] nvarchar(max) '$') AS [w0] ) """); } From 3b8537ea929dda6e417786bda8d157621da74a2f Mon Sep 17 00:00:00 2001 From: Christopher Jolly Date: Sun, 15 Dec 2024 16:20:39 +0800 Subject: [PATCH 3/3] remove unneeded --- ...yExpressionTranslatingExpressionVisitor.cs | 12 +++++----- ...lationalSqlTranslatingExpressionVisitor.cs | 9 ++++---- .../Internal/ExpressionExtensions.cs | 4 +--- .../Internal/ExpressionTreeFuncletizer.cs | 22 ++----------------- src/EFCore/Query/QueryRootProcessor.cs | 13 +---------- .../ShapedQueryCompilingExpressionVisitor.cs | 13 ++++++----- src/Shared/ExpressionExtensions.cs | 4 +--- src/Shared/SharedTypeExtensions.cs | 3 +-- ...orthwindMiscellaneousQuerySqlServerTest.cs | 16 +++++++++----- 9 files changed, 34 insertions(+), 62 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 1e09a967659..660d0c79998 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -196,13 +196,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) : anySubquery); } - static Expression? RemoveConvert(Expression? expression) - => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression - ? RemoveConvert(unaryExpression.Operand) - : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? - RemoveConvert(methodCallExpression.Object) - : expression; -} + static Expression RemoveConvert(Expression e) + => e is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary + ? RemoveConvert(unary.Operand) + : e; + } } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 315c33e1daa..c968e07c828 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -651,17 +651,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result,out var property)) + if (QueryHelpers.IsMemberAccess(methodCallExpression, Dependencies.Model, out _) && TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result,out var property)) { - /*if (property is IProperty { IsPrimitiveCollection: true } regularProperty + if (property is IProperty { IsPrimitiveCollection: true } regularProperty && result is SqlExpression sqlExpression && TranslatePrimitiveCollection( sqlExpression, regularProperty, _sqlAliasManager.GenerateTableAlias(GenerateTableAlias(sqlExpression))) is { } primitiveCollectionTranslation) { return primitiveCollectionTranslation; - }*/ - return result; + } + + return QueryCompilationContext.NotTranslatedExpression; string GenerateTableAlias(SqlExpression sqlExpression) => sqlExpression switch diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index 0624e0a9eea..4eb3c2e723b 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -201,9 +201,7 @@ public static bool IsLogicalNot(this UnaryExpression sqlUnaryExpression) private static Expression? RemoveConvert(Expression? expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression ? RemoveConvert(unaryExpression.Operand) - : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? - RemoveConvert(methodCallExpression.Object) - : expression; + : expression; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs index 97b4b5af67d..b90cc3c4207 100644 --- a/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs +++ b/src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs @@ -1854,19 +1854,7 @@ private static StateType CombineStateTypes(StateType stateType1, StateType state return result; } - object? value = null; - string? parameterName = null; - bool isContextAccessor = false; - if (evaluatableRoot is MethodCallExpression { Method.Name: "op_Implicit" } evaluatableRootMethod) - { - //can we get the arguments state so that we can do notEvaluatableAsRootHandler - value = Evaluate(evaluatableRootMethod.Arguments[0], out parameterName, out isContextAccessor); - } - else - { - value = Evaluate(evaluatableRoot, out parameterName, out isContextAccessor); - } - + var value = Evaluate(evaluatableRoot, out var parameterName, out var isContextAccessor); switch (value) { @@ -2058,9 +2046,7 @@ bool PreserveConvertNode(Expression expression) static Expression? RemoveConvert(Expression? expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression ? RemoveConvert(unaryExpression.Operand) - : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? - RemoveConvert(methodCallExpression.Object) - : expression; + : expression; } switch (expression) @@ -2092,10 +2078,6 @@ bool PreserveConvertNode(Expression expression) case MethodCallExpression methodCallExpression: parameterName = methodCallExpression.Method.Name; - if (parameterName == "op_Implicit") - { - return expression; - } break; case UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression diff --git a/src/EFCore/Query/QueryRootProcessor.cs b/src/EFCore/Query/QueryRootProcessor.cs index ab2dc888424..34a645d35ca 100644 --- a/src/EFCore/Query/QueryRootProcessor.cs +++ b/src/EFCore/Query/QueryRootProcessor.cs @@ -114,18 +114,7 @@ when listInit.Type.TryGetElementType(typeof(IList<>)) is not null && listInit.Initializers.All(x => x.Arguments.Count == 1) && ShouldConvertToInlineQueryRoot(listInit): return new InlineQueryRootExpression(listInit.Initializers.Select(x => x.Arguments[0]).ToList(), elementClrType); - case MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression: - { - if (elementClrType.GetGenericTypeDefinition() == typeof(Span<>)) - { - var baseType = elementClrType.GetGenericArguments()[0]; - var newArgument = VisitQueryRootCandidate(methodCallExpression.Arguments[0], baseType); - return newArgument == methodCallExpression.Arguments[0] - ? methodCallExpression - : Expression.Call(methodCallExpression.Method, newArgument); - } - goto default; - } + default: return Visit(expression); } diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index 69584cc6fb8..fbc596366bf 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -344,11 +344,14 @@ protected override Expression VisitExtension(Expression extensionExpression) : base.VisitExtension(extensionExpression); private static Expression? RemoveConvert(Expression? expression) - => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression - ? RemoveConvert(unaryExpression.Operand) - : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? - RemoveConvert(methodCallExpression.Object) - : expression; + { + while (expression is { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked }) + { + expression = RemoveConvert(((UnaryExpression)expression).Operand); + } + + return expression; + } } private sealed class EntityMaterializerInjectingExpressionVisitor( diff --git a/src/Shared/ExpressionExtensions.cs b/src/Shared/ExpressionExtensions.cs index aaddf4e98dd..775be040606 100644 --- a/src/Shared/ExpressionExtensions.cs +++ b/src/Shared/ExpressionExtensions.cs @@ -43,9 +43,7 @@ public static LambdaExpression UnwrapLambdaFromQuote(this Expression expression) private static Expression? RemoveConvert(Expression? expression) => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression ? RemoveConvert(unaryExpression.Operand) - : expression is MethodCallExpression { Method.Name: "op_Implicit" } methodCallExpression ? - RemoveConvert(methodCallExpression.Object) - : expression; + : expression; public static T GetConstantValue(this Expression expression) => RemoveConvert(expression) switch diff --git a/src/Shared/SharedTypeExtensions.cs b/src/Shared/SharedTypeExtensions.cs index dacbe4b35ba..803525b98c9 100644 --- a/src/Shared/SharedTypeExtensions.cs +++ b/src/Shared/SharedTypeExtensions.cs @@ -154,8 +154,7 @@ public static Type GetSequenceType([DynamicallyAccessedMembers(DynamicallyAccess public static Type? TryGetSequenceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) => type.TryGetElementType(typeof(IEnumerable<>)) - ?? type.TryGetElementType(typeof(IAsyncEnumerable<>)) - ?? type.TryGetElementType(typeof(Span<>)); + ?? type.TryGetElementType(typeof(IAsyncEnumerable<>)); public static Type? TryGetElementType( [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type, diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs index d3b7b4a45bc..29acb99dada 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs @@ -7368,14 +7368,14 @@ public override async Task Contains_over_concatenated_column_and_parameter(bool AssertSql( """ -@__someVariable_0='SomeVariable' (Size = 4000) -@__data_1='["ALFKISomeVariable","ANATRSomeVariable","ALFKIX"]' (Size = 4000) +@__someVariable_1='SomeVariable' (Size = 4000) +@__data_0='["ALFKISomeVariable","ANATRSomeVariable","ALFKIX"]' (Size = 4000) SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE [c].[CustomerID] + @__someVariable_0 IN ( +WHERE [c].[CustomerID] + @__someVariable_1 IN ( SELECT [d].[value] - FROM OPENJSON(@__data_1) WITH ([value] nvarchar(max) '$') AS [d] + FROM OPENJSON(@__data_0) WITH ([value] nvarchar(max) '$') AS [d] ) """); } @@ -7386,11 +7386,15 @@ public override async Task Contains_over_concatenated_parameter_and_constant(boo AssertSql( """ -@__Contains_0='True' +@__p_1='ALFKISomeConstant' (Size = 4000) +@__data_0='["ALFKISomeConstant","ANATRSomeConstant"]' (Size = 4000) SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE @__Contains_0 = CAST(1 AS bit) +WHERE @__p_1 IN ( + SELECT [d].[value] + FROM OPENJSON(@__data_0) WITH ([value] nvarchar(max) '$') AS [d] +) """); }