Additional changes for better min/max inlining

This commit is contained in:
James Terwilliger 2019-05-23 14:08:29 -07:00
Родитель b6a5ecf990
Коммит 26b2f443b4
5 изменённых файлов: 364 добавлений и 24 удалений

Просмотреть файл

@ -5,6 +5,7 @@
using System;
using System.Diagnostics.Contracts;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.StreamProcessing.Internal;
namespace Microsoft.StreamProcessing.Aggregates
@ -12,26 +13,45 @@ namespace Microsoft.StreamProcessing.Aggregates
internal class TumblingMaxAggregate<T> : IAggregate<T, MinMaxState<T>, T>
{
private static readonly long InvalidSyncTime = StreamEvent.MinSyncTime - 1;
private readonly Comparison<T> comparer;
private readonly Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> accumulate;
public TumblingMaxAggregate() : this(ComparerExpression<T>.Default) { }
public TumblingMaxAggregate(IComparerExpression<T> comparer)
{
Contract.Requires(comparer != null);
this.comparer = comparer.GetCompareExpr().Compile();
var stateExpression = Expression.Parameter(typeof(MinMaxState<T>), "state");
var timestampExpression = Expression.Parameter(typeof(long), "timestamp");
var inputExpression = Expression.Parameter(typeof(T), "input");
Expression<Func<MinMaxState<T>>> constructor = () => new MinMaxState<T>();
Expression<Func<MinMaxState<T>, long>> currentTimestamp = (state) => state.currentTimestamp;
Expression<Func<MinMaxState<T>, T>> currentValue = (state) => state.currentValue;
var currentTimestampExpression = currentTimestamp.ReplaceParametersInBody(stateExpression);
var currentValueExpression = currentValue.ReplaceParametersInBody(stateExpression);
var comparerExpression = comparer.GetCompareExpr().ReplaceParametersInBody(inputExpression, currentValueExpression);
var typeInfo = typeof(MinMaxState<T>).GetTypeInfo();
this.accumulate = Expression.Lambda<Func<MinMaxState<T>, long, T, MinMaxState<T>>>(
Expression.MemberInit(
(NewExpression)constructor.Body,
Expression.Bind(typeInfo.GetField("currentTimestamp"), timestampExpression),
Expression.Bind(typeInfo.GetField("currentValue"), Expression.Condition(
Expression.Or(
Expression.Equal(currentTimestampExpression, Expression.Constant(InvalidSyncTime)),
Expression.GreaterThan(comparerExpression, Expression.Constant(0))),
inputExpression,
currentValueExpression))),
stateExpression,
timestampExpression,
inputExpression);
}
public Expression<Func<MinMaxState<T>>> InitialState()
=> () => new MinMaxState<T> { currentTimestamp = InvalidSyncTime };
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Accumulate() =>
(state, timestamp, input) =>
new MinMaxState<T>
{
currentTimestamp = timestamp,
currentValue = (state.currentTimestamp == InvalidSyncTime || this.comparer(input, state.currentValue) > 0) ? input : state.currentValue
};
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Accumulate() => this.accumulate;
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Deaccumulate()
=> (state, timestamp, input) => state; // never invoked, hence not implemented

Просмотреть файл

@ -5,6 +5,7 @@
using System;
using System.Diagnostics.Contracts;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.StreamProcessing.Internal;
namespace Microsoft.StreamProcessing.Aggregates
@ -12,25 +13,45 @@ namespace Microsoft.StreamProcessing.Aggregates
internal class TumblingMinAggregate<T> : IAggregate<T, MinMaxState<T>, T>
{
private static readonly long InvalidSyncTime = StreamEvent.MinSyncTime - 1;
private readonly Comparison<T> comparer;
private readonly Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> accumulate;
public TumblingMinAggregate() : this(ComparerExpression<T>.Default) { }
public TumblingMinAggregate(IComparerExpression<T> comparer)
{
Contract.Requires(comparer != null);
this.comparer = comparer.GetCompareExpr().Compile();
var stateExpression = Expression.Parameter(typeof(MinMaxState<T>), "state");
var timestampExpression = Expression.Parameter(typeof(long), "timestamp");
var inputExpression = Expression.Parameter(typeof(T), "input");
Expression<Func<MinMaxState<T>>> constructor = () => new MinMaxState<T>();
Expression<Func<MinMaxState<T>, long>> currentTimestamp = (state) => state.currentTimestamp;
Expression<Func<MinMaxState<T>, T>> currentValue = (state) => state.currentValue;
var currentTimestampExpression = currentTimestamp.ReplaceParametersInBody(stateExpression);
var currentValueExpression = currentValue.ReplaceParametersInBody(stateExpression);
var comparerExpression = comparer.GetCompareExpr().ReplaceParametersInBody(inputExpression, currentValueExpression);
var typeInfo = typeof(MinMaxState<T>).GetTypeInfo();
this.accumulate = Expression.Lambda<Func<MinMaxState<T>, long, T, MinMaxState<T>>>(
Expression.MemberInit(
(NewExpression)constructor.Body,
Expression.Bind(typeInfo.GetField("currentTimestamp"), timestampExpression),
Expression.Bind(typeInfo.GetField("currentValue"), Expression.Condition(
Expression.Or(
Expression.Equal(currentTimestampExpression, Expression.Constant(InvalidSyncTime)),
Expression.LessThan(comparerExpression, Expression.Constant(0))),
inputExpression,
currentValueExpression))),
stateExpression,
timestampExpression,
inputExpression);
}
public Expression<Func<MinMaxState<T>>> InitialState()
=> () => new MinMaxState<T> { currentTimestamp = InvalidSyncTime };
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Accumulate()
=> (state, timestamp, input) => new MinMaxState<T>
{
currentTimestamp = timestamp,
currentValue = (state.currentTimestamp == InvalidSyncTime || this.comparer(input, state.currentValue) < 0) ? input : state.currentValue
};
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Accumulate() => this.accumulate;
public Expression<Func<MinMaxState<T>, long, T, MinMaxState<T>>> Deaccumulate()
=> (state, timestamp, input) => state; // never invoked, hence not implemented

Просмотреть файл

@ -40,6 +40,9 @@ namespace Microsoft.StreamProcessing
typeComparerCache.Add(typeof(long), new ComparerExpression<long>((x, y) => x < y ? -1 : x == y ? 0 : 1));
typeComparerCache.Add(typeof(ulong), new ComparerExpression<ulong>((x, y) => x < y ? -1 : x == y ? 0 : 1));
typeComparerCache.Add(typeof(string), new ComparerExpression<string>((x, y) => x.CompareTo(y)));
typeComparerCache.Add(typeof(TimeSpan), new ComparerExpression<TimeSpan>((x, y) => x.CompareTo(y)));
typeComparerCache.Add(typeof(DateTime), new ComparerExpression<DateTime>((x, y) => x.CompareTo(y)));
typeComparerCache.Add(typeof(DateTimeOffset), new ComparerExpression<DateTimeOffset>((x, y) => x.CompareTo(y)));
typeComparerCache.Add(typeof(Empty), new ComparerExpression<Empty>((x, y) => 0));
}
@ -121,6 +124,24 @@ namespace Microsoft.StreamProcessing
}
else
{
var genericComparableInterface = type
.GetTypeInfo().GetInterfaces()
.Where(i => i.Namespace.Equals("System.Collections.Generic") && i.Name.Equals("IComparable`1") && i.GetTypeInfo().GetGenericArguments().Length == 1 && i.GetTypeInfo().GetGenericArguments()[0] == type)
.FirstOrDefault();
if (genericComparableInterface != null)
{
// then fall back to using a lambda of the form:
// (x,y) => x.CompareTo(y)
var genericInstanceOfComparerExpressionForGenericIComparable = typeof(ComparerExpressionForGenericIComparable<>).MakeGenericType(type);
var ctorForComparerExpressionForGenericIComparer = genericInstanceOfComparerExpressionForGenericIComparable.GetTypeInfo().GetConstructor(Array.Empty<Type>());
if (ctorForComparerExpressionForGenericIComparer != null)
{
comparer = (IComparerExpression<T>)ctorForComparerExpressionForGenericIComparer.Invoke(Array.Empty<object>());
ComparerExpressionCache.Add(comparer);
return comparer;
}
}
var genericComparerInterface = type
.GetTypeInfo().GetInterfaces()
.Where(i => i.Namespace.Equals("System.Collections.Generic") && i.Name.Equals("IComparer`1") && i.GetTypeInfo().GetGenericArguments().Length == 1 && i.GetTypeInfo().GetGenericArguments()[0] == type)
@ -147,6 +168,7 @@ namespace Microsoft.StreamProcessing
}
}
}
if (type.GetTypeInfo().GetInterface("System.Collections.IComparer") != null)
{
// then fall back to using a lambda of the form:
@ -251,21 +273,23 @@ namespace Microsoft.StreamProcessing
}
}
internal class GenericComparerExpression<T> : ComparerExpression<T>
internal sealed class GenericComparerExpression<T> : ComparerExpression<T>
{
public GenericComparerExpression() : base(compareExpr: (x, y) => Comparer<T>.Default.Compare(x, y)) { }
}
internal class ComparerExpressionForGenericIComparer<T> : ComparerExpression<T> where T : IComparer<T>
internal sealed class ComparerExpressionForGenericIComparable<T> : ComparerExpression<T> where T : IComparable<T>
{
public ComparerExpressionForGenericIComparable() : base(compareExpr: (x, y) => x.CompareTo(y)) { }
}
internal sealed class ComparerExpressionForGenericIComparer<T> : ComparerExpression<T> where T : IComparer<T>
{
public ComparerExpressionForGenericIComparer(T t) : base(compareExpr: (x, y) => t.Compare(x, y)) { }
}
internal class ComparerExpressionForNonGenericIComparer<T> : ComparerExpression<T> where T : IComparer
internal sealed class ComparerExpressionForNonGenericIComparer<T> : ComparerExpression<T> where T : IComparer
{
public ComparerExpressionForNonGenericIComparer(T t)
: base(
compareExpr: (x, y) => t.Compare(x, y))
{ }
public ComparerExpressionForNonGenericIComparer(T t) : base(compareExpr: (x, y) => t.Compare(x, y)) { }
}
}

Просмотреть файл

@ -127,6 +127,33 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void TumblingSnapshot5Row() // like 4, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(10, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot1Row()
{
@ -237,6 +264,34 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot5Row() // like 5, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
StreamEvent.CreateInterval(40, 50, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(20, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void SessionSnapshot1Row()
{
@ -810,6 +865,33 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void TumblingSnapshot5RowSmallBatch() // like 4, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(10, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot1RowSmallBatch()
{
@ -920,6 +1002,34 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot5RowSmallBatch() // like 5, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
StreamEvent.CreateInterval(40, 50, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(20, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void SessionSnapshot1RowSmallBatch()
{
@ -1492,6 +1602,33 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void TumblingSnapshot5Columnar() // like 4, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(10, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot1Columnar()
{
@ -1602,6 +1739,34 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot5Columnar() // like 5, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
StreamEvent.CreateInterval(40, 50, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(20, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void SessionSnapshot1Columnar()
{
@ -2175,6 +2340,33 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void TumblingSnapshot5ColumnarSmallBatch() // like 4, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(10, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot1ColumnarSmallBatch()
{
@ -2285,6 +2477,34 @@ namespace SimpleTesting
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot5ColumnarSmallBatch() // like 5, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
StreamEvent.CreateInterval(40, 50, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(20, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void SessionSnapshot1ColumnarSmallBatch()
{

Просмотреть файл

@ -151,6 +151,33 @@ foreach (var batch in new [] { string.Empty, "SmallBatch" })
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void TumblingSnapshot5<#= suffix #>() // like 4, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(10, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot1<#= suffix #>()
{
@ -261,6 +288,34 @@ foreach (var batch in new [] { string.Empty, "SmallBatch" })
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void HoppingSnapshot5<#= suffix #>() // like 5, but with max
{
var input = new StreamEvent<MyData>[]
{
StreamEvent.CreatePoint(11, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(12, new MyData { field1 = 1, field2 = "A" }),
StreamEvent.CreatePoint(21, new MyData { field1 = 2, field2 = "A" }),
StreamEvent.CreatePoint(25, new MyData { field1 = 2, field2 = "D" })
};
var expected = new StreamEvent<int>[]
{
StreamEvent.CreateInterval(20, 30, 1),
StreamEvent.CreateInterval(30, 40, 2),
StreamEvent.CreateInterval(40, 50, 2),
};
var inputStream = input.ToObservable().ToStreamable();
var query = inputStream.HoppingWindowLifetime(20, 10)
.Max(x => x.field1)
;
var result = query.ToStreamEventObservable(ReshapingPolicy.CoalesceEndEdges).Where(e => e.IsData).ToEnumerable().ToArray();
Assert.IsTrue(result.SequenceEqual(expected));
}
[TestMethod, TestCategory("Gated")]
public void SessionSnapshot1<#= suffix #>()
{