Add ThrowForMissingOrderBy feature (#749)

This commit is contained in:
Simon Cropp 2024-08-14 14:44:02 +10:00 коммит произвёл GitHub
Родитель c59e5c9933
Коммит 3aecd03840
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
15 изменённых файлов: 294 добавлений и 32 удалений

Просмотреть файл

@ -77,7 +77,7 @@ builder.UseSqlServer(connection);
builder.EnableRecording();
var data = new SampleDbContext(builder.Options);
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L302-L309' title='Snippet source file'>snippet source</a> | <a href='#snippet-EnableRecording' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L349-L356' title='Snippet source file'>snippet source</a> | <a href='#snippet-EnableRecording' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
`EnableRecording` should only be called in the test context.
@ -106,7 +106,7 @@ await data
await Verify();
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L401-L419' title='Snippet source file'>snippet source</a> | <a href='#snippet-Recording' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L448-L466' title='Snippet source file'>snippet source</a> | <a href='#snippet-Recording' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Will result in the following verified file:
@ -157,7 +157,7 @@ await Verify(
entries
});
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L557-L582' title='Snippet source file'>snippet source</a> | <a href='#snippet-RecordingSpecific' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L604-L629' title='Snippet source file'>snippet source</a> | <a href='#snippet-RecordingSpecific' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
@ -189,7 +189,7 @@ await data2
await Verify();
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L369-L392' title='Snippet source file'>snippet source</a> | <a href='#snippet-MultiDbContexts' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L416-L439' title='Snippet source file'>snippet source</a> | <a href='#snippet-MultiDbContexts' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
<!-- snippet: CoreTests.MultiDbContexts.verified.txt -->
@ -251,7 +251,7 @@ await data
await Verify();
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L428-L451' title='Snippet source file'>snippet source</a> | <a href='#snippet-RecordingDisableForInstance' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L475-L498' title='Snippet source file'>snippet source</a> | <a href='#snippet-RecordingDisableForInstance' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
<!-- snippet: CoreTests.RecordingDisabledTest.verified.txt -->
@ -298,7 +298,7 @@ public async Task Added()
await Verify(data.ChangeTracker);
}
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L5-L21' title='Snippet source file'>snippet source</a> | <a href='#snippet-Added' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L52-L68' title='Snippet source file'>snippet source</a> | <a href='#snippet-Added' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Will result in the following verified file:
@ -343,7 +343,7 @@ public async Task Deleted()
await Verify(data.ChangeTracker);
}
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L23-L42' title='Snippet source file'>snippet source</a> | <a href='#snippet-Deleted' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L70-L89' title='Snippet source file'>snippet source</a> | <a href='#snippet-Deleted' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Will result in the following verified file:
@ -388,7 +388,7 @@ public async Task Modified()
await Verify(data.ChangeTracker);
}
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L44-L64' title='Snippet source file'>snippet source</a> | <a href='#snippet-Modified' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L91-L111' title='Snippet source file'>snippet source</a> | <a href='#snippet-Modified' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Will result in the following verified file:
@ -423,7 +423,7 @@ var queryable = data.Companies
.Where(_ => _.Content == "value");
await Verify(queryable);
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L259-L265' title='Snippet source file'>snippet source</a> | <a href='#snippet-Queryable' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L306-L312' title='Snippet source file'>snippet source</a> | <a href='#snippet-Queryable' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Will result in the following verified file:
@ -481,7 +481,7 @@ await Verify(data.AllData())
serializer =>
serializer.TypeNameHandling = TypeNameHandling.Objects);
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L238-L245' title='Snippet source file'>snippet source</a> | <a href='#snippet-AllData' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L285-L292' title='Snippet source file'>snippet source</a> | <a href='#snippet-AllData' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Will result in the following verified file with all data in the database:
@ -564,7 +564,7 @@ public async Task IgnoreNavigationProperties()
.IgnoreNavigationProperties();
}
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L66-L88' title='Snippet source file'>snippet source</a> | <a href='#snippet-IgnoreNavigationProperties' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L113-L135' title='Snippet source file'>snippet source</a> | <a href='#snippet-IgnoreNavigationProperties' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
@ -577,7 +577,7 @@ var options = DbContextOptions();
using var data = new SampleDbContext(options);
VerifyEntityFramework.IgnoreNavigationProperties();
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L116-L122' title='Snippet source file'>snippet source</a> | <a href='#snippet-IgnoreNavigationPropertiesGlobal' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L163-L169' title='Snippet source file'>snippet source</a> | <a href='#snippet-IgnoreNavigationPropertiesGlobal' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
@ -598,7 +598,7 @@ protected override void ConfigureWebHost(IWebHostBuilder webBuilder)
_ => dataBuilder.Options));
}
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L511-L523' title='Snippet source file'>snippet source</a> | <a href='#snippet-EnableRecordingWithIdentifier' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L558-L570' title='Snippet source file'>snippet source</a> | <a href='#snippet-EnableRecordingWithIdentifier' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
Then use the same identifier for recording:
@ -614,7 +614,7 @@ var companies = await httpClient.GetFromJsonAsync<Company[]>("/companies");
var entries = Recording.Stop(testName);
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L484-L494' title='Snippet source file'>snippet source</a> | <a href='#snippet-RecordWithIdentifier' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L531-L541' title='Snippet source file'>snippet source</a> | <a href='#snippet-RecordWithIdentifier' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->
The results will not be automatically included in verified file so it will have to be verified manually:
@ -629,7 +629,7 @@ await Verify(
sql = entries
});
```
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L496-L505' title='Snippet source file'>snippet source</a> | <a href='#snippet-VerifyRecordedCommandsWithIdentifier' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/Verify.EntityFramework.Tests/CoreTests.cs#L543-L552' title='Snippet source file'>snippet source</a> | <a href='#snippet-VerifyRecordedCommandsWithIdentifier' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->

Просмотреть файл

@ -10,7 +10,7 @@
<ResolveAssemblyReferencesSilent>true</ResolveAssemblyReferencesSilent>
<NuGetAuditMode>all</NuGetAuditMode>
<NuGetAuditLevel>low</NuGetAuditLevel>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<TreatWarningsAsErrors>false</TreatWarningsAsErrors>
<EnforceCodeStyleInBuild>true</EnforceCodeStyleInBuild>
</PropertyGroup>
</Project>

Просмотреть файл

@ -13,6 +13,7 @@
<PackageVersion Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.8" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="8.0.8" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="8.0.8" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.8" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageVersion Include="NUnit" Version="4.1.0" />
<PackageVersion Include="NUnit3TestAdapter" Version="4.6.0" />

Просмотреть файл

@ -0,0 +1,8 @@
{
Type: Exception,
Message:
SelectExpression must have at least one ordering.
Expression:
SELECT c.Id, c.Content
FROM Companies AS c
}

Просмотреть файл

@ -0,0 +1,7 @@
{
Type: Exception,
Message:
TableExpression must have at least one ordering.
Expression:
Employees AS e
}

Просмотреть файл

@ -0,0 +1,40 @@
[
{
Id: 1,
Content: Company1,
Employees: [
{
Id: 2,
CompanyId: 1,
Content: Employee1,
Age: 25
},
{
Id: 3,
CompanyId: 1,
Content: Employee2,
Age: 31
}
]
},
{
Id: 4,
Content: Company2,
Employees: [
{
Id: 5,
CompanyId: 4,
Content: Employee4,
Age: 34
}
]
},
{
Id: 6,
Content: Company3
},
{
Id: 7,
Content: Company4
}
]

Просмотреть файл

@ -0,0 +1,18 @@
[
{
Id: 1,
Content: Company1
},
{
Id: 4,
Content: Company2
},
{
Id: 6,
Content: Company3
},
{
Id: 7,
Content: Company4
}
]

Просмотреть файл

@ -2,6 +2,53 @@
[Parallelizable(ParallelScope.All)]
public class CoreTests
{
[Test]
public async Task MissingOrderBy()
{
await using var database = await DbContextBuilder.GetOrderRequiredDatabase();
var data = database.Context;
await ThrowsTask(
() => data.Companies
.ToListAsync())
.IgnoreStackTrace();
}
[Test]
public async Task NestedMissingOrderBy()
{
await using var database = await DbContextBuilder.GetOrderRequiredDatabase();
var data = database.Context;
await ThrowsTask(
() => data.Companies
.Include(_ => _.Employees)
.OrderBy(_ => _.Content)
.ToListAsync())
.IgnoreStackTrace();
}
[Test]
public async Task WithOrderBy()
{
await using var database = await DbContextBuilder.GetOrderRequiredDatabase();
var data = database.Context;
await Verify(
data.Companies
.OrderBy(_ => _.Content)
.ToListAsync());
}
[Test]
public async Task WithNestedOrderBy()
{
await using var database = await DbContextBuilder.GetOrderRequiredDatabase();
var data = database.Context;
await Verify(
data.Companies
.Include(_ => _.Employees.OrderBy(_ => _.Age))
.OrderBy(_ => _.Content)
.ToListAsync());
}
#region Added
[Test]
@ -232,7 +279,7 @@ public class CoreTests
[Test]
public async Task AllData()
{
var database = await DbContextBuilder.GetDatabase("AllData");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
#region AllData
@ -248,7 +295,7 @@ public class CoreTests
[Test]
public async Task Queryable()
{
var database = await DbContextBuilder.GetDatabase("Queryable");
var database = await DbContextBuilder.GetDatabase();
await database.AddData(
new Company
{
@ -268,7 +315,7 @@ public class CoreTests
[Test]
public async Task SetSelect()
{
var database = await DbContextBuilder.GetDatabase("SetSelect");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
var query = data
@ -280,7 +327,7 @@ public class CoreTests
[Test]
public async Task NestedQueryable()
{
var database = await DbContextBuilder.GetDatabase("NestedQueryable");
var database = await DbContextBuilder.GetDatabase();
await database.AddData(
new Company
{
@ -312,7 +359,7 @@ public class CoreTests
[Test]
public async Task Parameters()
{
var database = await DbContextBuilder.GetDatabase("Parameters");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
data.Add(
new Company
@ -330,7 +377,7 @@ public class CoreTests
[Test]
public async Task MultiRecording()
{
var database = await DbContextBuilder.GetDatabase("MultiRecording");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
Recording.Start();
var company = new Company
@ -363,7 +410,7 @@ public class CoreTests
[Test]
public async Task MultiDbContexts()
{
var database = await DbContextBuilder.GetDatabase("MultiDbContexts");
var database = await DbContextBuilder.GetDatabase();
var connectionString = database.ConnectionString;
#region MultiDbContexts
@ -395,7 +442,7 @@ public class CoreTests
[Test]
public async Task RecordingTest()
{
var database = await DbContextBuilder.GetDatabase("Recording");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
#region Recording
@ -422,7 +469,7 @@ public class CoreTests
[Test]
public async Task RecordingDisabledTest()
{
var database = await DbContextBuilder.GetDatabase("RecordingDisabledTest");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
#region RecordingDisableForInstance
@ -551,7 +598,7 @@ public class CoreTests
[Test]
public async Task RecordingSpecific()
{
var database = await DbContextBuilder.GetDatabase("RecordingSpecific");
var database = await DbContextBuilder.GetDatabase();
var data = database.Context;
#region RecordingSpecific

Просмотреть файл

@ -3,7 +3,8 @@
public static class DbContextBuilder
{
static DbContextBuilder() =>
static DbContextBuilder()
{
sqlInstance = new(
buildTemplate: CreateDb,
constructInstance: builder =>
@ -11,8 +12,19 @@ public static class DbContextBuilder
builder.EnableRecording();
return new(builder.Options);
});
orderRequiredSqlInstance = new(
buildTemplate: CreateDb,
storage: Storage.FromSuffix<SampleDbContext>("ThrowForMissingOrderBy"),
constructInstance: builder =>
{
builder.EnableRecording();
builder.ThrowForMissingOrderBy();
return new(builder.Options);
});
}
static SqlInstance<SampleDbContext> sqlInstance;
static SqlInstance<SampleDbContext> orderRequiredSqlInstance;
static async Task CreateDb(SampleDbContext data)
{
@ -63,6 +75,9 @@ public static class DbContextBuilder
await data.SaveChangesAsync();
}
public static Task<SqlDatabase<SampleDbContext>> GetDatabase(string suffix)
public static Task<SqlDatabase<SampleDbContext>> GetDatabase([CallerMemberName] string suffix = "")
=> sqlInstance.Build(suffix);
public static Task<SqlDatabase<SampleDbContext>> GetOrderRequiredDatabase([CallerMemberName] string suffix = "")
=> orderRequiredSqlInstance.Build(suffix);
}

Просмотреть файл

@ -1,10 +1,13 @@
global using System.Data;
global using System.Data.Common;
global using System.Diagnostics.CodeAnalysis;
global using System.Linq.Expressions;
global using Argon;
global using Microsoft.EntityFrameworkCore;
global using Microsoft.EntityFrameworkCore.ChangeTracking;
global using Microsoft.EntityFrameworkCore.Diagnostics;
global using Microsoft.EntityFrameworkCore.Metadata;
global using Microsoft.EntityFrameworkCore.Query;
global using Microsoft.EntityFrameworkCore.Query.Internal;
global using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
global using VerifyTests.EntityFramework;

Просмотреть файл

@ -0,0 +1,88 @@
sealed class MissingOrderByVisitor : ExpressionVisitor
{
List<OrderingExpression> orderedExpressions = [];
[return: NotNullIfNotNull(nameof(expression))]
public override Expression? Visit(Expression? expression)
{
if (expression is null)
{
return null;
}
switch (expression)
{
case ShapedQueryExpression shapedQueryExpression:
Visit(shapedQueryExpression.QueryExpression);
return shapedQueryExpression;
case RelationalSplitCollectionShaperExpression splitExpression:
foreach (var table in splitExpression.SelectExpression.Tables)
{
Visit(table);
}
Visit(splitExpression.InnerShaper);
return splitExpression;
case TableExpression tableExpression:
{
foreach (var orderedExpression in orderedExpressions)
{
if (orderedExpression.Expression is ColumnExpression columnExpression)
{
if (columnExpression.Table == tableExpression)
{
return base.Visit(expression);
}
if (columnExpression.Table is PredicateJoinExpressionBase joinExpression)
{
if (joinExpression.Table == tableExpression)
{
return base.Visit(expression);
}
}
}
}
throw new(
$"""
TableExpression must have at least one ordering.
Expression:
{ExpressionPrinter.Print(tableExpression)}
""");
}
case SelectExpression selectExpression:
{
var orderings = selectExpression.Orderings;
if (orderings.Count == 0)
{
throw new(
$"""
SelectExpression must have at least one ordering.
Expression:
{PrintShortSql(selectExpression)}
""");
}
foreach (var ordering in orderings)
{
orderedExpressions.Add(ordering);
}
return base.Visit(expression);
}
case NonQueryExpression nonQueryExpression:
return nonQueryExpression;
default:
return base.Visit(expression);
}
}
[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "PrintShortSql")]
static extern string PrintShortSql(SelectExpression expression);
}

Просмотреть файл

@ -0,0 +1,13 @@
class RelationalFactory : RelationalShapedQueryCompilingExpressionVisitorFactory
{
public RelationalFactory(ShapedQueryCompilingExpressionVisitorDependencies dependencies, RelationalShapedQueryCompilingExpressionVisitorDependencies relationalDependencies) :
base(dependencies, relationalDependencies)
{
}
public override ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext context)
=> new RelationalVisitor(
Dependencies,
RelationalDependencies,
context);
}

Просмотреть файл

@ -0,0 +1,15 @@
class RelationalVisitor :
RelationalShapedQueryCompilingExpressionVisitor
{
public RelationalVisitor(ShapedQueryCompilingExpressionVisitorDependencies dependencies, RelationalShapedQueryCompilingExpressionVisitorDependencies relationalDependencies, QueryCompilationContext context) :
base(dependencies, relationalDependencies, context)
{
}
[return: NotNullIfNotNull("node")]
public override Expression? Visit(Expression? node)
{
new MissingOrderByVisitor().Visit(node);
return base.Visit(node);
}
}

Просмотреть файл

@ -4,6 +4,9 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="System.Text.Json" />
<PackageReference Include="Verify" />
<PackageReference Include="Microsoft.EntityFrameworkCore" />
<PackageReference Include="ProjectDefaults" PrivateAssets="all" />

Просмотреть файл

@ -131,10 +131,18 @@ public static class VerifyEntityFramework
return new(result, [new("sql", sql)]);
}
public static DbContextOptionsBuilder<TContext> ThrowForMissingOrderBy<TContext>(this DbContextOptionsBuilder<TContext> builder)
where TContext : DbContext =>
builder.ReplaceService<IShapedQueryCompilingExpressionVisitorFactory, RelationalFactory>();
public static DbContextOptionsBuilder<TContext> EnableRecording<TContext>(this DbContextOptionsBuilder<TContext> builder)
where TContext : DbContext
=> builder.EnableRecording(null);
public static DbContextOptionsBuilder<TContext> EnableRecording<TContext>(this DbContextOptionsBuilder<TContext> builder, string? identifier)
where TContext : DbContext =>
builder.AddInterceptors(new LogCommandInterceptor(identifier));
static ConcurrentBag<Guid> recordingDisabledContextIds = [];
public static void DisableRecording<TContext>(this TContext context)
@ -144,8 +152,4 @@ public static class VerifyEntityFramework
internal static bool IsRecordingDisabled<TContext>(this TContext context)
where TContext : DbContext =>
recordingDisabledContextIds.Contains(context.ContextId.InstanceId);
public static DbContextOptionsBuilder<TContext> EnableRecording<TContext>(this DbContextOptionsBuilder<TContext> builder, string? identifier)
where TContext : DbContext =>
builder.AddInterceptors(new LogCommandInterceptor(identifier));
}