Add support for IAsyncEnumerable<T> where T is value type (#17154)

* Add support for IAsyncEnumerable<T> where T is value type

Fixes https://github.com/aspnet/AspNetCore/issues/17139
This commit is contained in:
Pranav K 2019-11-18 08:25:08 -08:00 коммит произвёл GitHub
Родитель eec2ce4a71
Коммит d5fd9fc2fa
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 268 добавлений и 60 удалений

Просмотреть файл

@ -5,7 +5,6 @@ using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Mvc.Core;
@ -17,8 +16,6 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
namespace Microsoft.AspNetCore.Mvc.Infrastructure
#endif
{
using ReaderFunc = Func<IAsyncEnumerable<object>, Task<ICollection>>;
/// <summary>
/// Type that reads an <see cref="IAsyncEnumerable{T}"/> instance into a
/// generic collection instance.
@ -34,8 +31,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
nameof(ReadInternal),
BindingFlags.NonPublic | BindingFlags.Instance);
private readonly ConcurrentDictionary<Type, ReaderFunc> _asyncEnumerableConverters =
new ConcurrentDictionary<Type, ReaderFunc>();
private readonly ConcurrentDictionary<Type, Func<object, Task<ICollection>>> _asyncEnumerableConverters =
new ConcurrentDictionary<Type, Func<object, Task<ICollection>>>();
private readonly MvcOptions _mvcOptions;
/// <summary>
@ -48,37 +45,39 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
}
/// <summary>
/// Reads a <see cref="IAsyncEnumerable{T}"/> into an <see cref="ICollection{T}"/>.
/// Attempts to produces a delagate that reads a <see cref="IAsyncEnumerable{T}"/> into an <see cref="ICollection{T}"/>.
/// </summary>
/// <param name="value">The <see cref="IAsyncEnumerable{T}"/> to read.</param>
/// <returns>The <see cref="ICollection"/>.</returns>
public Task<ICollection> ReadAsync(IAsyncEnumerable<object> value)
/// <param name="type">The type to read.</param>
/// <param name="reader">A delegate that when awaited reads the <see cref="IAsyncEnumerable{T}"/>.</param>
/// <returns><see langword="true" /> when <paramref name="type"/> is an instance of <see cref="IAsyncEnumerable{T}"/>, othwerise <see langword="false"/>.</returns>
public bool TryGetReader(Type type, out Func<object, Task<ICollection>> reader)
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}
var type = value.GetType();
if (!_asyncEnumerableConverters.TryGetValue(type, out var result))
if (!_asyncEnumerableConverters.TryGetValue(type, out reader))
{
var enumerableType = ClosedGenericMatcher.ExtractGenericInterface(type, typeof(IAsyncEnumerable<>));
Debug.Assert(enumerableType != null);
if (enumerableType is null)
{
// Not an IAsyncEnumerable<T>. Cache this result so we avoid reflection the next time we see this type.
reader = null;
_asyncEnumerableConverters.TryAdd(type, reader);
}
else
{
var enumeratedObjectType = enumerableType.GetGenericArguments()[0];
var enumeratedObjectType = enumerableType.GetGenericArguments()[0];
var converter = (Func<object, Task<ICollection>>)Converter
.MakeGenericMethod(enumeratedObjectType)
.CreateDelegate(typeof(Func<object, Task<ICollection>>), this);
var converter = (ReaderFunc)Converter
.MakeGenericMethod(enumeratedObjectType)
.CreateDelegate(typeof(ReaderFunc), this);
_asyncEnumerableConverters.TryAdd(type, converter);
result = converter;
reader = converter;
_asyncEnumerableConverters.TryAdd(type, reader);
}
}
return result(value);
return reader != null;
}
private async Task<ICollection> ReadInternal<T>(IAsyncEnumerable<object> value)
private async Task<ICollection> ReadInternal<T>(object value)
{
var asyncEnumerable = (IAsyncEnumerable<T>)value;
var result = new List<T>();

Просмотреть файл

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
/// </summary>
public class ObjectResultExecutor : IActionResultExecutor<ObjectResult>
{
private readonly AsyncEnumerableReader _asyncEnumerableReader;
private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory;
/// <summary>
/// Creates a new <see cref="ObjectResultExecutor"/>.
@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
WriterFactory = writerFactory.CreateWriter;
Logger = loggerFactory.CreateLogger<ObjectResultExecutor>();
var options = mvcOptions?.Value ?? throw new ArgumentNullException(nameof(mvcOptions));
_asyncEnumerableReader = new AsyncEnumerableReader(options);
_asyncEnumerableReaderFactory = new AsyncEnumerableReader(options);
}
/// <summary>
@ -117,19 +118,19 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
var value = result.Value;
if (value is IAsyncEnumerable<object> asyncEnumerable)
if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader))
{
return ExecuteAsyncEnumerable(context, result, asyncEnumerable);
return ExecuteAsyncEnumerable(context, result, value, reader);
}
return ExecuteAsyncCore(context, result, objectType, value);
}
private async Task ExecuteAsyncEnumerable(ActionContext context, ObjectResult result, IAsyncEnumerable<object> asyncEnumerable)
private async Task ExecuteAsyncEnumerable(ActionContext context, ObjectResult result, object asyncEnumerable, Func<object, Task<ICollection>> reader)
{
Log.BufferingAsyncEnumerable(Logger, asyncEnumerable);
var enumerated = await _asyncEnumerableReader.ReadAsync(asyncEnumerable);
var enumerated = await reader(asyncEnumerable);
await ExecuteAsyncCore(context, result, enumerated.GetType(), enumerated);
}
@ -194,7 +195,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
"Buffering IAsyncEnumerable instance of type '{Type}'.");
}
public static void BufferingAsyncEnumerable(ILogger logger, IAsyncEnumerable<object> asyncEnumerable)
public static void BufferingAsyncEnumerable(ILogger logger, object asyncEnumerable)
=> _bufferingAsyncEnumerable(logger, asyncEnumerable.GetType().FullName, null);
}
}

Просмотреть файл

@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
private readonly JsonOptions _options;
private readonly ILogger<SystemTextJsonResultExecutor> _logger;
private readonly AsyncEnumerableReader _asyncEnumerableReader;
private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory;
public SystemTextJsonResultExecutor(
IOptions<JsonOptions> options,
@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
{
_options = options.Value;
_logger = logger;
_asyncEnumerableReader = new AsyncEnumerableReader(mvcOptions.Value);
_asyncEnumerableReaderFactory = new AsyncEnumerableReader(mvcOptions.Value);
}
public async Task ExecuteAsync(ActionContext context, JsonResult result)
@ -76,10 +76,10 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
try
{
var value = result.Value;
if (value is IAsyncEnumerable<object> asyncEnumerable)
if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader))
{
Log.BufferingAsyncEnumerable(_logger, asyncEnumerable);
value = await _asyncEnumerableReader.ReadAsync(asyncEnumerable);
Log.BufferingAsyncEnumerable(_logger, value);
value = await reader(value);
}
var type = value?.GetType() ?? typeof(object);
@ -154,7 +154,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
_jsonResultExecuting(logger, type, null);
}
public static void BufferingAsyncEnumerable(ILogger logger, IAsyncEnumerable<object> asyncEnumerable)
public static void BufferingAsyncEnumerable(ILogger logger, object asyncEnumerable)
=> _bufferingAsyncEnumerable(logger, asyncEnumerable.GetType().FullName, null);
}
}

Просмотреть файл

@ -5,46 +5,173 @@ using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Options;
using Xunit;
namespace Microsoft.AspNetCore.Mvc.Infrastructure
{
public class AsyncEnumerableReaderTest
{
[Fact]
public async Task ReadAsync_ReadsIAsyncEnumerable()
[Theory]
[InlineData(typeof(Range))]
[InlineData(typeof(IEnumerable<string>))]
[InlineData(typeof(List<int>))]
public void TryGetReader_ReturnsFalse_IfTypeIsNotIAsyncEnumerable(Type type)
{
// Arrange
var options = new MvcOptions();
var reader = new AsyncEnumerableReader(options);
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable = TestEnumerable();
// Act
var result = await reader.ReadAsync(TestEnumerable());
var result = readerFactory.TryGetReader(type, out var reader);
// Assert
var collection = Assert.IsAssignableFrom<ICollection<string>>(result);
Assert.False(result);
}
[Fact]
public async Task TryGetReader_ReturnsReaderForIAsyncEnumerable()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable = TestEnumerable();
// Act
var result = readerFactory.TryGetReader(asyncEnumerable.GetType(), out var reader);
// Assert
Assert.True(result);
var readCollection = await reader(asyncEnumerable);
var collection = Assert.IsAssignableFrom<ICollection<string>>(readCollection);
Assert.Equal(new[] { "0", "1", "2", }, collection);
}
[Fact]
public async Task ReadAsync_ReadsIAsyncEnumerable_ImplementingMultipleAsyncEnumerableInterfaces()
public async Task TryGetReader_ReturnsReaderForIAsyncEnumerableOfValueType()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable = PrimitiveEnumerable();
// Act
var result = readerFactory.TryGetReader(asyncEnumerable.GetType(), out var reader);
// Assert
Assert.True(result);
var readCollection = await reader(asyncEnumerable);
var collection = Assert.IsAssignableFrom<ICollection<int>>(readCollection);
Assert.Equal(new[] { 0, 1, 2, }, collection);
}
[Fact]
public void TryGetReader_ReturnsCachedDelegate()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable1 = TestEnumerable();
var asyncEnumerable2 = TestEnumerable();
// Act
Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader1));
Assert.True(readerFactory.TryGetReader(asyncEnumerable2.GetType(), out var reader2));
// Assert
Assert.Same(reader1, reader2);
}
[Fact]
public void TryGetReader_ReturnsCachedDelegate_WhenTypeImplementsMultipleIAsyncEnumerableContracts()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable1 = new MultiAsyncEnumerable();
var asyncEnumerable2 = new MultiAsyncEnumerable();
// Act
Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader1));
Assert.True(readerFactory.TryGetReader(asyncEnumerable2.GetType(), out var reader2));
// Assert
Assert.Same(reader1, reader2);
}
[Fact]
public async Task CachedDelegate_CanReadEnumerableInstanceMultipleTimes()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable1 = TestEnumerable();
var asyncEnumerable2 = TestEnumerable();
var expected = new[] { "0", "1", "2" };
// Act
Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader));
// Assert
Assert.Equal(expected, await reader(asyncEnumerable1));
Assert.Equal(expected, await reader(asyncEnumerable2));
}
[Fact]
public async Task CachedDelegate_CanReadEnumerableInstanceMultipleTimes_ThatProduceDifferentResults()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable1 = TestEnumerable();
var asyncEnumerable2 = TestEnumerable(4);
// Act
Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader));
// Assert
Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1));
Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2));
}
[Fact]
public void TryGetReader_ReturnsDifferentInstancesForDifferentEnumerables()
{
// Arrange
var options = new MvcOptions();
var readerFactory = new AsyncEnumerableReader(options);
var enumerable1 = TestEnumerable();
var enumerable2 = TestEnumerable2();
// Act
Assert.True(readerFactory.TryGetReader(enumerable1.GetType(), out var reader1));
Assert.True(readerFactory.TryGetReader(enumerable2.GetType(), out var reader2));
// Assert
Assert.NotSame(reader1, reader2);
}
[Fact]
public async Task Reader_ReadsIAsyncEnumerable_ImplementingMultipleAsyncEnumerableInterfaces()
{
// This test ensures the reader does not fail if you have a type that implements IAsyncEnumerable for multiple Ts
// Arrange
var options = new MvcOptions();
var reader = new AsyncEnumerableReader(options);
var readerFactory = new AsyncEnumerableReader(options);
var asyncEnumerable = new MultiAsyncEnumerable();
// Act
var result = await reader.ReadAsync(new MultiAsyncEnumerable());
var result = readerFactory.TryGetReader(asyncEnumerable.GetType(), out var reader);
// Assert
var collection = Assert.IsAssignableFrom<ICollection<object>>(result);
Assert.True(result);
var readCollection = await reader(asyncEnumerable);
var collection = Assert.IsAssignableFrom<ICollection<object>>(readCollection);
Assert.Equal(new[] { "0", "1", "2", }, collection);
}
[Fact]
public async Task ReadAsync_ThrowsIfBufferimitIsReached()
[Fact]
public async Task Reader_ThrowsIfBufferLimitIsReached()
{
// Arrange
var enumerable = TestEnumerable(11);
@ -52,10 +179,11 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
"This limit is in place to prevent infinite streams of 'IAsyncEnumerable<>' from continuing indefinitely. If this is not a programming mistake, " +
$"consider ways to reduce the collection size, or consider manually converting '{enumerable.GetType()}' into a list rather than increasing the limit.";
var options = new MvcOptions { MaxIAsyncEnumerableBufferLimit = 10 };
var reader = new AsyncEnumerableReader(options);
var readerFactory = new AsyncEnumerableReader(options);
// Act
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader.ReadAsync(enumerable));
Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader));
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable));
// Assert
Assert.Equal(expected, ex.Message);
@ -70,6 +198,22 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
}
}
public static async IAsyncEnumerable<string> TestEnumerable2()
{
await Task.Yield();
yield return "Hello";
yield return "world";
}
public static async IAsyncEnumerable<int> PrimitiveEnumerable(int count = 3)
{
await Task.Yield();
for (var i = 0; i < count; i++)
{
yield return i;
}
}
public class MultiAsyncEnumerable : IAsyncEnumerable<object>, IAsyncEnumerable<string>
{
public IAsyncEnumerator<string> GetAsyncEnumerator(CancellationToken cancellationToken = default)

Просмотреть файл

@ -311,6 +311,24 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
Assert.StartsWith("Property 'JsonResult.SerializerSettings' must be an instance of type", ex.Message);
}
[Fact]
public async Task ExecuteAsync_WithNullValue()
{
// Arrange
var expected = Encoding.UTF8.GetBytes("null");
var context = GetActionContext();
var result = new JsonResult(value: null);
var executor = CreateExecutor();
// Act
await executor.ExecuteAsync(context, result);
// Assert
var written = GetWrittenBytes(context.HttpContext);
Assert.Equal(expected, written);
}
[Fact]
public async Task ExecuteAsync_SerializesAsyncEnumerables()
{
@ -329,6 +347,24 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
Assert.Equal(expected, written);
}
[Fact]
public async Task ExecuteAsync_SerializesAsyncEnumerablesOfPrimtives()
{
// Arrange
var expected = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new[] { 1, 2 }));
var context = GetActionContext();
var result = new JsonResult(TestAsyncPrimitiveEnumerable());
var executor = CreateExecutor();
// Act
await executor.ExecuteAsync(context, result);
// Assert
var written = GetWrittenBytes(context.HttpContext);
Assert.Equal(expected, written);
}
protected IActionResultExecutor<JsonResult> CreateExecutor() => CreateExecutor(NullLoggerFactory.Instance);
protected abstract IActionResultExecutor<JsonResult> CreateExecutor(ILoggerFactory loggerFactory);
@ -380,5 +416,12 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
yield return "Hello";
yield return "world";
}
private async IAsyncEnumerable<int> TestAsyncPrimitiveEnumerable()
{
await Task.Yield();
yield return 1;
yield return 2;
}
}
}

Просмотреть файл

@ -361,6 +361,28 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
MediaTypeAssert.Equal(expectedContentType, responseContentType);
}
[Fact]
public async Task ObjectResult_NullValue()
{
// Arrange
var executor = CreateExecutor();
var result = new ObjectResult(value: null);
var formatter = new TestJsonOutputFormatter();
result.Formatters.Add(formatter);
var actionContext = new ActionContext()
{
HttpContext = GetHttpContext(),
};
// Act
await executor.ExecuteAsync(actionContext, result);
// Assert
var formatterContext = formatter.LastOutputFormatterContext;
Assert.Null(formatterContext.Object);
}
[Fact]
public async Task ObjectResult_ReadsAsyncEnumerables()
{

Просмотреть файл

@ -3,7 +3,6 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Mvc.Formatters;
@ -31,7 +30,7 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
private readonly MvcOptions _mvcOptions;
private readonly MvcNewtonsoftJsonOptions _jsonOptions;
private readonly IArrayPool<char> _charPool;
private readonly AsyncEnumerableReader _asyncEnumerableReader;
private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory;
/// <summary>
/// Creates a new <see cref="NewtonsoftJsonResultExecutor"/>.
@ -73,7 +72,7 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
_mvcOptions = mvcOptions?.Value ?? throw new ArgumentNullException(nameof(mvcOptions));
_jsonOptions = jsonOptions.Value;
_charPool = new JsonArrayPool<char>(charPool);
_asyncEnumerableReader = new AsyncEnumerableReader(_mvcOptions);
_asyncEnumerableReaderFactory = new AsyncEnumerableReader(_mvcOptions);
}
/// <summary>
@ -133,10 +132,10 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
var jsonSerializer = JsonSerializer.Create(jsonSerializerSettings);
var value = result.Value;
if (result.Value is IAsyncEnumerable<object> asyncEnumerable)
if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader))
{
Log.BufferingAsyncEnumerable(_logger, asyncEnumerable);
value = await _asyncEnumerableReader.ReadAsync(asyncEnumerable);
Log.BufferingAsyncEnumerable(_logger, value);
value = await reader(value);
}
jsonSerializer.Serialize(jsonWriter, value);
@ -201,7 +200,7 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
_jsonResultExecuting(logger, type, null);
}
public static void BufferingAsyncEnumerable(ILogger logger, IAsyncEnumerable<object> asyncEnumerable)
public static void BufferingAsyncEnumerable(ILogger logger, object asyncEnumerable)
=> _bufferingAsyncEnumerable(logger, asyncEnumerable.GetType().FullName, null);
}
}