diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 3153a14..dddd331 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -44,7 +44,7 @@ jobs: working-directory: src/FlatSharp.Compiler run: dotnet build -c Debug /p:SignAssembly=false - - name: Run FlatSharp.Compiler + - name: Run FlatSharp.Compiler (E2E Tests) # You may pin to the exact commit or the version. # uses: Amadevus/pwsh-script@97a8b211a5922816aa8a69ced41fa32f23477186 uses: Amadevus/pwsh-script@v2.0.3 @@ -58,7 +58,25 @@ jobs: --use-source-link ` --format opencover ` --target "dotnet" ` + --output e2etests.coverage.xml ` --targetargs "src\FlatSharp.Compiler\bin\Debug\net7.0\FlatSharp.Compiler.dll --nullable-warnings false --normalize-field-names true --input `"$fbs`" -o src/tests/FlatSharpEndToEndTests" + + - name: Run FlatSharp.Compiler (Pooling Tests) + # You may pin to the exact commit or the version. + # uses: Amadevus/pwsh-script@97a8b211a5922816aa8a69ced41fa32f23477186 + uses: Amadevus/pwsh-script@v2.0.3 + with: + # PowerShell script to execute in Actions-hydrated context + script: | + $fbs = (gci -r src/tests/FlatSharpPoolableEndToEndTests/*.fbs) -join ";" + coverlet ` + .\src\FlatSharp.Compiler\bin\Debug\net7.0 ` + --skipautoprops ` + --use-source-link ` + --format opencover ` + --target "dotnet" ` + --output pooling.coverage.xml ` + --targetargs "src\FlatSharp.Compiler\bin\Debug\net7.0\FlatSharp.Compiler.dll --nullable-warnings false --normalize-field-names true --gen-poolable true --input `"$fbs`" -o src/tests/FlatSharpPoolableEndToEndTests" - name: Test working-directory: src diff --git a/.github/workflows/dotnet.yml b/.github/workflows/dotnet.yml index 369c1e8..d76b257 100644 --- a/.github/workflows/dotnet.yml +++ b/.github/workflows/dotnet.yml @@ -33,7 +33,7 @@ jobs: working-directory: src/FlatSharp.Compiler run: dotnet build -c Release /p:SignAssembly=true - - name: Run FlatSharp.Compiler + - name: Run FlatSharp.Compiler (E2E Tests) # You may pin to the exact commit or the version. # uses: Amadevus/pwsh-script@97a8b211a5922816aa8a69ced41fa32f23477186 uses: Amadevus/pwsh-script@v2.0.3 @@ -42,6 +42,17 @@ jobs: script: | $fbs = (gci -r src/tests/FlatsharpEndToEndTests/*.fbs) -join ";" dotnet src/FlatSharp.Compiler/bin/Release/net7.0/FlatSharp.Compiler.dll --nullable-warnings false --normalize-field-names true --input `"$fbs`" -o src/tests/FlatSharpEndToEndTests + + + - name: Run FlatSharp.Compiler (Pooling Tests) + # You may pin to the exact commit or the version. + # uses: Amadevus/pwsh-script@97a8b211a5922816aa8a69ced41fa32f23477186 + uses: Amadevus/pwsh-script@v2.0.3 + with: + # PowerShell script to execute in Actions-hydrated context + script: | + $fbs = (gci -r src/tests/FlatSharpPoolableEndToEndTests/*.fbs) -join ";" + dotnet src/FlatSharp.Compiler/bin/Release/net7.0/FlatSharp.Compiler.dll --nullable-warnings false --normalize-field-names true --gen-poolable true --input `"$fbs`" -o src/tests/FlatSharpPoolableEndToEndTests - name: Build working-directory: src diff --git a/src/Benchmarks/MicroBench.Current/Constants.cs b/src/Benchmarks/MicroBench.Current/Constants.cs index e72203e..2a542f0 100644 --- a/src/Benchmarks/MicroBench.Current/Constants.cs +++ b/src/Benchmarks/MicroBench.Current/Constants.cs @@ -21,6 +21,7 @@ namespace Microbench using System.Diagnostics; using System.Linq; using FlatSharp; + using FlatSharp.Internal; public static class Constants { @@ -30,6 +31,9 @@ namespace Microbench static Constants() { Process.GetCurrentProcess().PriorityClass = ProcessPriorityClass.RealTime; +#if POOLABLE + ObjectPool.MaxToRetain = VectorLength; +#endif } public static class StringTables @@ -63,8 +67,8 @@ namespace Microbench { public static StructsTable SingleRef = new StructsTable { SingleRef = new RefStruct { Value = 1 }, SingleValue = default, }; public static StructsTable SingleValue = new StructsTable { SingleValue = new ValueStruct { Value = 1 } }; - public static StructsTable VectorRef = new StructsTable { VecRef = Enumerable.Range(1, 30).Select(x => new RefStruct { Value = x }).ToList(), SingleValue = default, }; - public static StructsTable VectorValue = new StructsTable { VecValue = Enumerable.Range(1, 30).Select(x => new ValueStruct { Value = x }).ToList(), SingleValue = default, }; + public static StructsTable VectorRef = new StructsTable { VecRef = Enumerable.Range(1, VectorLength).Select(x => new RefStruct { Value = x }).ToList(), SingleValue = default, }; + public static StructsTable VectorValue = new StructsTable { VecValue = Enumerable.Range(1, VectorLength).Select(x => new ValueStruct { Value = x }).ToList(), SingleValue = default, }; } public static class SortedVectorTables diff --git a/src/Benchmarks/MicroBench.Current/Microbench.Current.csproj b/src/Benchmarks/MicroBench.Current/Microbench.Current.csproj index 8f2bb6e..7a9cb06 100644 --- a/src/Benchmarks/MicroBench.Current/Microbench.Current.csproj +++ b/src/Benchmarks/MicroBench.Current/Microbench.Current.csproj @@ -10,6 +10,12 @@ $(DefineConstants);PUBLIC_IVTABLE true $([System.IO.Path]::GetFullPath('$(MSBuildThisFileDirectory)\..\..\FlatSharp.Compiler\bin\$(Configuration)\net7.0\FlatSharp.Compiler.dll')) + false + false + + + + $(DefineConstants);POOLABLE diff --git a/src/Benchmarks/MicroBench.Current/ParseBenchmarks.cs b/src/Benchmarks/MicroBench.Current/ParseBenchmarks.cs index e64c54b..b46be0c 100644 --- a/src/Benchmarks/MicroBench.Current/ParseBenchmarks.cs +++ b/src/Benchmarks/MicroBench.Current/ParseBenchmarks.cs @@ -17,9 +17,7 @@ namespace Microbench { using BenchmarkDotNet.Attributes; - using System; using FlatSharp; - using System.Linq; using System.Collections.Generic; using System.Runtime.CompilerServices; @@ -43,6 +41,8 @@ namespace Microbench length += table.SingleString!.Length; } + table.TryReturnToPool(); + return length; } @@ -74,21 +74,36 @@ namespace Microbench public int Parse_StructTable_SingleRef() { var st = StructsTable.Serializer.Parse(Constants.Buffers.StructTable_SingleRef); - return st.SingleRef!.Value; + var singleRef = st.SingleRef!; + + int result = singleRef.Value; + + //singleRef.TryReturnToPool(); + st.TryReturnToPool(); + + return result; } [Benchmark] public void Parse_StructTable_SingleRef_WriteThrough() { var st = StructsTable.Serializer.Parse(Constants.Buffers.StructTable_SingleRef); - st.SingleRef!.Value = 3; + var singleRef = st.SingleRef!; + singleRef.Value = 3; + + //singleRef.TryReturnToPool(); + st.TryReturnToPool(); } [Benchmark] public int Parse_StructTable_SingleValue() { var st = StructsTable.Serializer.Parse(Constants.Buffers.StructTable_SingleValue); - return st.SingleValue.Value; + var value = st.SingleValue.Value; + + st.TryReturnToPool(); + + return value; } [Benchmark] @@ -96,6 +111,8 @@ namespace Microbench { var st = StructsTable.Serializer.Parse(Constants.Buffers.StructTable_SingleValue); st.SingleValue = new ValueStruct { Value = 3 }; + + st.TryReturnToPool(); } [Benchmark] @@ -109,9 +126,14 @@ namespace Microbench for (int i = 0; i < count; ++i) { - sum += vecRef[i].Value; + var item = vecRef[i]; + sum += item.Value; + //item.TryReturnToPool(); } + //vecRef.TryReturnToPool(); + st.TryReturnToPool(); + return sum; } @@ -125,8 +147,13 @@ namespace Microbench for (int i = 0; i < count; ++i) { - vecRef[i].Value++; + var item = vecRef[i]; + item.Value++; + //item.TryReturnToPool(); } + + //vecRef.TryReturnToPool(); + st.TryReturnToPool(); } [Benchmark] @@ -143,6 +170,9 @@ namespace Microbench sum += vecValue[i].Value; } + //vecValue.TryReturnToPool(); + st.TryReturnToPool(); + return sum; } @@ -160,6 +190,9 @@ namespace Microbench item.Value++; vecValue[i] = item; } + + //vecValue.TryReturnToPool(); + st.TryReturnToPool(); } [Benchmark] @@ -176,6 +209,9 @@ namespace Microbench sum += vector[i].Accept(visitor); } + //vector.TryReturnToPool(); + st.TryReturnToPool(); + return sum; } @@ -193,6 +229,9 @@ namespace Microbench sum += vector[i].Accept(visitor); } + //vector.TryReturnToPool(); + st.TryReturnToPool(); + return sum; } @@ -210,6 +249,9 @@ namespace Microbench sum += vector[i].Accept(visitor); } + //vector.TryReturnToPool(); + st.TryReturnToPool(); + return sum; } @@ -252,6 +294,7 @@ namespace Microbench sum += (int)table.ULong; sum += table.UShort; + table.TryReturnToPool(); return sum; } @@ -275,9 +318,29 @@ namespace Microbench { length += vec[i].Length; } + + //vec.TryReturnToPool(); } + table.TryReturnToPool(); + return length; } } + + public static class Extensions + { +#if POOLABLE + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void TryReturnToPool(this T obj) where T : IPoolableObject + { + obj.ReturnToPool(); + } +#else + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void TryReturnToPool(this T obj) + { + } +#endif + } } diff --git a/src/Benchmarks/MicroBench.Current/Program.cs b/src/Benchmarks/MicroBench.Current/Program.cs index 01364b6..0f18e51 100644 --- a/src/Benchmarks/MicroBench.Current/Program.cs +++ b/src/Benchmarks/MicroBench.Current/Program.cs @@ -55,10 +55,10 @@ namespace Microbench //.AddHardwareCounters(HardwareCounter.BranchInstructions, HardwareCounter.BranchMispredictions) .AddJob(job.DontEnforcePowerPlan()); - summaries.Add(BenchmarkRunner.Run(typeof(SerializeBenchmarks), config)); + //summaries.Add(BenchmarkRunner.Run(typeof(SerializeBenchmarks), config)); summaries.Add(BenchmarkRunner.Run(typeof(ParseBenchmarks), config)); - summaries.Add(BenchmarkRunner.Run(typeof(SortedVectorBenchmarks), config)); - summaries.Add(BenchmarkRunner.Run(typeof(VTableBenchmarks), config)); + //summaries.Add(BenchmarkRunner.Run(typeof(SortedVectorBenchmarks), config)); + //summaries.Add(BenchmarkRunner.Run(typeof(VTableBenchmarks), config)); foreach (var item in summaries) { diff --git a/src/Benchmarks/Microbench.fbs b/src/Benchmarks/Microbench.fbs index 79ac4f8..d846112 100644 --- a/src/Benchmarks/Microbench.fbs +++ b/src/Benchmarks/Microbench.fbs @@ -12,12 +12,12 @@ attribute "fs_unsafeUnion"; namespace Microbench; // Tests reading and writing a string. -table StringTable (fs_serializer:"Lazy") { +table StringTable (fs_serializer:"Progressive") { SingleString : string; Vector : [string]; } -table PrimitivesTable (fs_serializer:"Lazy") { +table PrimitivesTable (fs_serializer:"Progressive") { Bool : bool; Byte : ubyte; SByte : byte; @@ -34,7 +34,7 @@ table PrimitivesTable (fs_serializer:"Lazy") { struct RefStruct (fs_writeThrough) { Value : int; } struct ValueStruct (fs_valueStruct) { Value : int; } -table StructsTable (fs_serializer:"Lazy") +table StructsTable (fs_serializer:"Progressive") { SingleRef : RefStruct; SingleValue : ValueStruct (fs_writeThrough, required); @@ -45,7 +45,7 @@ table StructsTable (fs_serializer:"Lazy") table StringKey { Key : string (key); } table IntKey { Key : int (key); } -table SortedTable (fs_serializer:"Lazy") +table SortedTable (fs_serializer:"Progressive") { Strings : [StringKey] (fs_vector:"IIndexedVector"); Ints : [IntKey] (fs_vector:"IIndexedVector"); @@ -60,7 +60,7 @@ union UnsafeUnion (fs_unsafeUnion) { ValueStructA, ValueStructB, ValueStructC } union SafeUnion { ValueStructA, ValueStructB, ValueStructC } union MixedUnion { ValueStructA, ValueStructB, ValueStructC, Something : string } -table UnionTable (fs_serializer:"Lazy") +table UnionTable (fs_serializer:"Progressive") { Unsafe : [ UnsafeUnion ]; Safe : [ SafeUnion ]; diff --git a/src/FlatSharp.Compiler/CodeWriter.cs b/src/FlatSharp.Compiler/CodeWriter.cs index 50c625f..cb86fbb 100644 --- a/src/FlatSharp.Compiler/CodeWriter.cs +++ b/src/FlatSharp.Compiler/CodeWriter.cs @@ -38,6 +38,11 @@ public class CodeWriter this.builder.AppendLine(line); } + public void AppendInheritDoc() + { + this.AppendLine("/// "); + } + public void AppendSummaryComment(params string[] summaryParts) { this.AppendSummaryComment((IEnumerable)summaryParts); diff --git a/src/FlatSharp.Compiler/CompilerOptions.cs b/src/FlatSharp.Compiler/CompilerOptions.cs index 66aadf6..82ddbe0 100644 --- a/src/FlatSharp.Compiler/CompilerOptions.cs +++ b/src/FlatSharp.Compiler/CompilerOptions.cs @@ -35,6 +35,9 @@ public record CompilerOptions [Option("nullable-warnings", Default = false, HelpText = "Emit full nullable annotations and enable warnings.")] public bool? NullableWarnings { get; set; } + [Option("gen-poolable", Hidden = false, Default = false, HelpText = "EXPERIMENTAL: Generate extra code to enable object pooling for allocation reductions.")] + public bool? GeneratePoolableObjects { get; set; } + [Option("flatc-path", Hidden = true)] public string? FlatcPath { get; set; } diff --git a/src/FlatSharp.Compiler/FlatSharp.Compiler.targets b/src/FlatSharp.Compiler/FlatSharp.Compiler.targets index f945858..ea430c7 100644 --- a/src/FlatSharp.Compiler/FlatSharp.Compiler.targets +++ b/src/FlatSharp.Compiler/FlatSharp.Compiler.targets @@ -95,6 +95,11 @@ true + + false + + + true @@ -112,11 +117,11 @@ diff --git a/src/FlatSharp.Compiler/FlatSharpCompiler.cs b/src/FlatSharp.Compiler/FlatSharpCompiler.cs index 4826fe0..92aca1b 100644 --- a/src/FlatSharp.Compiler/FlatSharpCompiler.cs +++ b/src/FlatSharp.Compiler/FlatSharpCompiler.cs @@ -523,7 +523,7 @@ public class FlatSharpCompiler foreach (var s in bfbs) { - rootModel.UnionWith(ParseSchema(s, options, postProcessTransforms, mutators).ToRootModel()); + rootModel.UnionWith(ParseSchema(s, options, postProcessTransforms, mutators).ToRootModel(options)); } ErrorContext.Current.ThrowIfHasErrors(); diff --git a/src/FlatSharp.Compiler/Schema/Schema.cs b/src/FlatSharp.Compiler/Schema/Schema.cs index 1014c8a..a58d511 100644 --- a/src/FlatSharp.Compiler/Schema/Schema.cs +++ b/src/FlatSharp.Compiler/Schema/Schema.cs @@ -63,7 +63,7 @@ public class Schema [FlatBufferItem(7)] public virtual IIndexedVector? FbsFiles { get; set; } - public RootModel ToRootModel() + public RootModel ToRootModel(CompilerOptions options) { RootModel model = new RootModel(this.AdvancedFeatures); @@ -73,7 +73,7 @@ public class Schema { model.AddElement(enumModel); } - else if (UnionSchemaModel.TryCreate(this, @enum, out var unionModel)) + else if (ValueUnionSchemaModel.TryCreate(this, @enum, options, out var unionModel)) { model.AddElement(unionModel); } diff --git a/src/FlatSharp.Compiler/SchemaModel/BaseReferenceTypeSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/BaseReferenceTypeSchemaModel.cs index 0ad57cf..9db916f 100644 --- a/src/FlatSharp.Compiler/SchemaModel/BaseReferenceTypeSchemaModel.cs +++ b/src/FlatSharp.Compiler/SchemaModel/BaseReferenceTypeSchemaModel.cs @@ -74,13 +74,17 @@ public abstract class BaseReferenceTypeSchemaModel : BaseSchemaModel this.EmitDefaultConstrutor(writer, context); this.EmitDeserializationConstructor(writer); this.EmitCopyConstructor(writer, context); + this.EmitPoolableObject(writer, context); writer.AppendLine("static partial void OnStaticInitialize();"); writer.AppendLine("partial void OnInitialized(FlatBufferDeserializationContext? context);"); - writer.AppendLine($"protected void {TableTypeModel.OnDeserializedMethodName}({nameof(FlatBufferDeserializationContext)}? context) => this.OnInitialized(context);"); - writer.AppendLine(); + writer.AppendLine($"protected void {TableTypeModel.OnDeserializedMethodName}({nameof(FlatBufferDeserializationContext)} context)"); + using (writer.WithBlock()) + { + writer.AppendLine("this.OnInitialized(context);"); + } foreach (var property in this.properties.OrderBy(x => x.Key)) { @@ -167,6 +171,18 @@ public abstract class BaseReferenceTypeSchemaModel : BaseSchemaModel writer.AppendLine("#pragma warning restore CS8618"); // nullable } + private void EmitPoolableObject(CodeWriter writer, CompileContext context) + { + if (context.Options.GeneratePoolableObjects == true) + { + writer.AppendLine("/// "); + writer.AppendLine("public virtual void ReturnToPool(bool unsafeForce = false)"); + using (writer.WithBlock()) + { + } + } + } + protected virtual void EmitStaticConstructor(CodeWriter writer, CompileContext context) { writer.AppendLine($"static {this.Name}()"); diff --git a/src/FlatSharp.Compiler/SchemaModel/FlatBufferSchemaElementType.cs b/src/FlatSharp.Compiler/SchemaModel/FlatBufferSchemaElementType.cs index 2f77e87..7c65d5e 100644 --- a/src/FlatSharp.Compiler/SchemaModel/FlatBufferSchemaElementType.cs +++ b/src/FlatSharp.Compiler/SchemaModel/FlatBufferSchemaElementType.cs @@ -31,4 +31,5 @@ public enum FlatBufferSchemaElementType StructVector = 10, ValueStructVector = 11, RpcCall = 12, + PoolableUnion = 13, } diff --git a/src/FlatSharp.Compiler/SchemaModel/ReferenceStructSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ReferenceStructSchemaModel.cs index 6ec21ac..9c8fdce 100644 --- a/src/FlatSharp.Compiler/SchemaModel/ReferenceStructSchemaModel.cs +++ b/src/FlatSharp.Compiler/SchemaModel/ReferenceStructSchemaModel.cs @@ -76,5 +76,11 @@ public class ReferenceStructSchemaModel : BaseReferenceTypeSchemaModel writer.AppendLine(attribute); writer.AppendLine("[System.Runtime.CompilerServices.CompilerGenerated]"); writer.AppendLine($"public partial class {this.Name}"); + writer.AppendLine($" : object"); + + if (context.Options.GeneratePoolableObjects == true) + { + writer.AppendLine($" , IPoolableObject"); + } } } diff --git a/src/FlatSharp.Compiler/SchemaModel/ReferenceUnionSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ReferenceUnionSchemaModel.cs new file mode 100644 index 0000000..56254f2 --- /dev/null +++ b/src/FlatSharp.Compiler/SchemaModel/ReferenceUnionSchemaModel.cs @@ -0,0 +1,222 @@ +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FlatSharp.Compiler.Schema; +using FlatSharp.CodeGen; +using System; + +namespace FlatSharp.Compiler.SchemaModel; + +public class ReferenceUnionSchemaModel : BaseSchemaModel +{ + private readonly FlatBufferEnum union; + + internal ReferenceUnionSchemaModel(Schema.Schema schema, FlatBufferEnum union) : base(schema, union.Name, new FlatSharpAttributes(union.Attributes)) + { + FlatSharpInternal.Assert(union.UnderlyingType.BaseType == BaseType.UType, "Expecting utype"); + + this.DeclaringFile = union.DeclarationFile; + this.union = union; + } + + public override FlatBufferSchemaElementType ElementType => FlatBufferSchemaElementType.PoolableUnion; + + public override string DeclaringFile { get; } + + protected override void OnWriteCode(CodeWriter writer, CompileContext context) + { + List<(string resolvedType, EnumVal value, Type? propertyType)> innerTypes = new(); + foreach (var inner in this.union.Values.Select(x => x.Value)) + { + // Skip "none". + if (inner.Value == 0) + { + FlatSharpInternal.Assert(inner.Key == "NONE", "Expecting discriminator 0 to be 'None'"); + continue; + } + + FlatSharpInternal.Assert(inner.UnionType is not null, "Union type was null"); + + long discriminator = inner.Value; + string typeName = inner.UnionType.ResolveTypeOrElementTypeName(this.Schema, this.Attributes); + + Type? propertyClrType = null; + if (context.CompilePass > CodeWritingPass.Initialization) + { + Type? previousType = context.PreviousAssembly?.GetType(this.FullName); + FlatSharpInternal.Assert(previousType is not null, "PreviousType was null"); + + propertyClrType = previousType + .GetProperty($"Item{inner.Value}", BindingFlags.Public | BindingFlags.Instance)? + .PropertyType; + + FlatSharpInternal.Assert(propertyClrType is not null, "Couldn't find property"); + } + + innerTypes.Add((typeName, inner, propertyClrType)); + } + + string interfaceName = $"IFlatBufferUnion<{string.Join(", ", innerTypes.Select(x => x.resolvedType))}>"; + + writer.AppendSummaryComment(this.union.Documentation); + writer.AppendLine("[System.Runtime.CompilerServices.CompilerGenerated]"); + writer.AppendLine($"public partial class {this.Name} : object, {interfaceName}, IPoolableObject"); + using (writer.WithBlock()) + { + // Generate an internal type enum. + writer.AppendLine("public enum ItemKind : byte"); + using (writer.WithBlock()) + { + foreach (var item in this.union.Values) + { + writer.AppendLine($"{item.Value.Key} = {item.Value.Value},"); + } + } + + writer.AppendLine(); + writer.AppendLine("protected int discriminator;"); + + writer.AppendLine(); + writer.AppendLine("public ItemKind Kind => (ItemKind)this.Discriminator;"); + + writer.AppendLine(); + writer.AppendLine("public byte Discriminator => (byte)this.discriminator;"); + + foreach (var item in innerTypes) + { + this.WriteConstructor(writer, item.resolvedType, item.value, item.propertyType); + this.AddUnionMember(writer, item.resolvedType, item.value, item.propertyType, context); + } + + this.WriteDefaultConstructor(writer); + this.WriteReturnToPool(writer); + this.WriteAcceptMethod(writer, innerTypes); + } + } + + private void AddUnionMember(CodeWriter writer, string resolvedType, EnumVal value, Type? propertyClrType, CompileContext context) + { + writer.AppendLine(); + writer.AppendLine($"private {resolvedType}{(propertyClrType?.IsValueType == false ? "?" : string.Empty)} value_{value.Value};"); + + writer.AppendLine(); + writer.AppendLine($"public {resolvedType} Item{value.Value}"); + using (writer.WithBlock()) + { + writer.AppendLine("get"); + using (writer.WithBlock()) + { + writer.AppendLine($"if (this.Discriminator != {value.Value})"); + using (writer.WithBlock()) + { + writer.AppendLine("throw new InvalidOperationException();"); + } + + writer.AppendLine($"return this.value_{value.Value}!;"); + } + + writer.AppendLine("protected set"); + using (writer.WithBlock()) + { + writer.AppendLine($"this.value_{value.Value} = value;"); + } + } + + string notNullWhen = string.Empty; + if (context.Options.NullableWarnings == true) + { + notNullWhen = $"[global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] "; + } + + writer.AppendLine(); + writer.AppendLine($"public bool TryGet({notNullWhen} out {resolvedType} value)"); + using (writer.WithBlock()) + { + writer.AppendLine($"if (this.Discriminator != {value.Value})"); + using (writer.WithBlock()) + { + writer.AppendLine("value = default;"); + writer.AppendLine("return false;"); + } + + writer.AppendLine($"value = this.value_{value.Value}!;"); + writer.AppendLine("return true;"); + } + } + + private void WriteReturnToPool(CodeWriter writer) + { + writer.AppendInheritDoc(); + writer.AppendLine($"public virtual void ReturnToPool(bool unsafeForce = false) {{ }}"); + } + + private void WriteAcceptMethod( + CodeWriter writer, + List<(string resolvedType, EnumVal value, Type? propertyType)> components) + { + string visitorBaseType = $"IFlatBufferUnionVisitor x.resolvedType))}>"; + + writer.AppendSummaryComment("A convenience interface for implementing a visitor."); + writer.AppendLine($"public interface Visitor : {visitorBaseType} {{ }}"); + + writer.AppendSummaryComment("Accepts a visitor into this FlatBufferUnion."); + writer.AppendLine($"public TReturn Accept(TVisitor visitor)"); + writer.AppendLine($" where TVisitor : {visitorBaseType}"); + using (writer.WithBlock()) + { + writer.AppendLine("var disc = this.Discriminator;"); + writer.AppendLine("switch (disc)"); + using (writer.WithBlock()) + { + foreach (var item in components) + { + long index = item.value.Value; + writer.AppendLine($"case {index}: return visitor.Visit(this.value_{item.value.Value});"); + } + + writer.AppendLine($"default: throw new {typeof(InvalidOperationException).GetCompilableTypeName()}(\"Unexpected discriminator: \" + disc);"); + } + } + } + + private void WriteConstructor(CodeWriter writer, string resolvedType, EnumVal unionValue, Type? propertyType) + { + writer.AppendLine($"public {this.Name}({resolvedType} value)"); + using (writer.WithBlock()) + { + if (propertyType?.IsValueType == false) + { + writer.AppendLine("if (value is null)"); + using (writer.WithBlock()) + { + writer.AppendLine("throw new ArgumentNullException(nameof(value));"); + } + } + + writer.AppendLine($"this.discriminator = {unionValue.Value};"); + writer.AppendLine($"this.Item{unionValue.Value} = value;"); + } + } + + private void WriteDefaultConstructor(CodeWriter writer) + { + writer.AppendLine($"protected {this.Name}()"); + using (writer.WithBlock()) + { + } + } +} diff --git a/src/FlatSharp.Compiler/SchemaModel/TableSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/TableSchemaModel.cs index 824eb3a..b5dcbda 100644 --- a/src/FlatSharp.Compiler/SchemaModel/TableSchemaModel.cs +++ b/src/FlatSharp.Compiler/SchemaModel/TableSchemaModel.cs @@ -99,6 +99,12 @@ public class TableSchemaModel : BaseReferenceTypeSchemaModel using (writer.IncreaseIndent()) { writer.AppendLine(": object"); + + if (context.Options.GeneratePoolableObjects == true) + { + writer.AppendLine(", IPoolableObject"); + } + if (this.Attributes.DeserializationOption is not null && context.CompilePass >= CodeWritingPass.SerializerAndRpcGeneration) { writer.AppendLine($", {nameof(IFlatBufferSerializable)}<{this.FullName}>"); diff --git a/src/FlatSharp.Compiler/SchemaModel/UnionSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs similarity index 94% rename from src/FlatSharp.Compiler/SchemaModel/UnionSchemaModel.cs rename to src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs index dff76e9..207a380 100644 --- a/src/FlatSharp.Compiler/SchemaModel/UnionSchemaModel.cs +++ b/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs @@ -19,14 +19,15 @@ using FlatSharp.Compiler.Schema; using FlatSharp.CodeGen; using System.Runtime.InteropServices; using System.IO; +using System.Reflection.Metadata; namespace FlatSharp.Compiler.SchemaModel; -public class UnionSchemaModel : BaseSchemaModel +public class ValueUnionSchemaModel : BaseSchemaModel { private readonly FlatBufferEnum union; - private UnionSchemaModel(Schema.Schema schema, FlatBufferEnum union) : base(schema, union.Name, new FlatSharpAttributes(union.Attributes)) + private ValueUnionSchemaModel(Schema.Schema schema, FlatBufferEnum union) : base(schema, union.Name, new FlatSharpAttributes(union.Attributes)) { FlatSharpInternal.Assert(union.UnderlyingType.BaseType == BaseType.UType, "Expecting utype"); @@ -36,7 +37,11 @@ public class UnionSchemaModel : BaseSchemaModel this.AttributeValidator.UnsafeUnionValidator = b => AttributeValidationResult.Valid; } - public static bool TryCreate(Schema.Schema schema, FlatBufferEnum union, [NotNullWhen(true)] out UnionSchemaModel? model) + public static bool TryCreate( + Schema.Schema schema, + FlatBufferEnum union, + CompilerOptions context, + [NotNullWhen(true)] out BaseSchemaModel? model) { if (union.UnderlyingType.BaseType != BaseType.UType) { @@ -44,7 +49,15 @@ public class UnionSchemaModel : BaseSchemaModel return false; } - model = new UnionSchemaModel(schema, union); + if (context.GeneratePoolableObjects == true) + { + model = new ReferenceUnionSchemaModel(schema, union); + } + else + { + model = new ValueUnionSchemaModel(schema, union); + } + return true; } diff --git a/src/FlatSharp.Runtime/FlatBufferDeserializationContext.cs b/src/FlatSharp.Runtime/FlatBufferDeserializationContext.cs index 61c64db..7474cfc 100644 --- a/src/FlatSharp.Runtime/FlatBufferDeserializationContext.cs +++ b/src/FlatSharp.Runtime/FlatBufferDeserializationContext.cs @@ -20,7 +20,7 @@ namespace FlatSharp; /// A context that FlatSharp-deserialized classes will pass to their parent /// object on construction, if the parent object defines a constructor that accepts this object. /// -public class FlatBufferDeserializationContext +public struct FlatBufferDeserializationContext { /// /// Initializes a new FlatSharpConstructorContext with the given deserialization option. diff --git a/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs b/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs index ac7092a..a6a67e3 100644 --- a/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs +++ b/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs @@ -114,22 +114,36 @@ internal class GeneratedSerializerWrapper : ISerializer, ISerializer where var parseArgs = new GeneratedSerializerParseArguments(0, this.remainingDepthLimit); var inner = this.innerSerializer; + T item; + switch (option ?? this.option) { case FlatBufferDeserializationOption.Lazy: - return buffer.InvokeLazyParse(inner, in parseArgs); + item = buffer.InvokeLazyParse(inner, in parseArgs); + break; case FlatBufferDeserializationOption.Greedy: - return buffer.InvokeGreedyParse(inner, in parseArgs); + item = buffer.InvokeGreedyParse(inner, in parseArgs); + break; case FlatBufferDeserializationOption.GreedyMutable: - return buffer.InvokeGreedyMutableParse(inner, in parseArgs); + item = buffer.InvokeGreedyMutableParse(inner, in parseArgs); + break; case FlatBufferDeserializationOption.Progressive: - return buffer.InvokeProgressiveParse(inner, in parseArgs); + item = buffer.InvokeProgressiveParse(inner, in parseArgs); + break; + + default: + throw new InvalidOperationException("Unexpected deserialization mode: " + this.option); } - throw new InvalidOperationException("Unexpected deserialization mode: " + this.option); + if (item is IPoolableObjectDebug deserializedObject) + { + deserializedObject.IsRoot = true; + } + + return item; } object ISerializer.Parse(TInputBuffer buffer, FlatBufferDeserializationOption? option) => this.Parse(buffer, option); diff --git a/src/FlatSharp.Runtime/IObjectPool.cs b/src/FlatSharp.Runtime/IObjectPool.cs new file mode 100644 index 0000000..277ed2f --- /dev/null +++ b/src/FlatSharp.Runtime/IObjectPool.cs @@ -0,0 +1,67 @@ +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace FlatSharp +{ +#if DEBUG + /// + /// Defines an object pool that FlatSharp may use to reduce allocations. + /// + public interface IObjectPool + { + /// + /// Attempts to get an item from the pool. + /// + /// The type of item. + /// The value, as an output parameter. + /// True if the value was returned. False otherwise. + bool TryGet([NotNullWhen(true)] out T? value); + + /// + /// Returns an item to the pool. FlatSharp users should never use this method directly. + /// + /// The type of item. + /// The item to return. + void Return(T item); + } +#endif + + /// + /// A FlatSharp poolable object. + /// + public interface IPoolableObject + { + /// + /// Attempts to return this object to the pool. + /// + /// Force this back to the pool, regardless of internal consistency rules. + void ReturnToPool(bool unsafeForce = false); + } +} + +namespace FlatSharp.Internal +{ + /// + /// Debug information for poolable objects. + /// + public interface IPoolableObjectDebug + { + /// + /// Indicates if this object is the root of the parse tree. + /// + bool IsRoot { get; set; } + } +} \ No newline at end of file diff --git a/src/FlatSharp.Runtime/ObjectPool.cs b/src/FlatSharp.Runtime/ObjectPool.cs new file mode 100644 index 0000000..ed3f589 --- /dev/null +++ b/src/FlatSharp.Runtime/ObjectPool.cs @@ -0,0 +1,155 @@ +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Concurrent; +using System.Threading; + +namespace FlatSharp.Internal +{ + /// + /// Internal disposal state. + /// + public static class ObjectPoolDisposalState + { + /// + /// Extension method to indicate if we should always dispose or not. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool ShouldReturnToPool(this FlatBufferDeserializationOption option, bool force) + { + return // ObjectPool.Instance is not null && + (force || option == FlatBufferDeserializationOption.Lazy); + } + } +} + +namespace FlatSharp +{ + /// + /// A static singleton that defines the FlatSharp object pool to use. + /// + public static class ObjectPool + { + /// + /// Gets or sets the soft limit on the maximum number of objects FlatSharp should retain. Note that this + /// limit applies per-type and is a soft limit, meaning it may be exceeded temporarily. + /// + public static int MaxToRetain { get; set; } + +#if DEBUG + public static IObjectPool? Instance + { + get; + set; + } + + /// + /// Attempts to get an instance of from the pool. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryGet([NotNullWhen(true)] out T? value) + { + value = default; + return Instance?.TryGet(out value) == true; + } + + /// + /// Returns the given item to the pool. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Return(T item) + { + Instance?.Return(item); + } +#else + + /// + /// Attempts to get an instance of from the pool. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryGet([NotNullWhen(true)] out T? value) + { + return DefaultObjectPool.TryGet(out value); + } + + /// + /// Returns the given item to the pool. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Return(T item) + { + DefaultObjectPool.Return(item); + } +#endif + } + + /// + /// A default implementation of the interface. + /// + [ExcludeFromCodeCoverage] + public static class DefaultObjectPool + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Return(T item) + { + Pool.Return(item, ObjectPool.MaxToRetain); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryGet([NotNullWhen(true)] out T? value) + { + return Pool.TryGet(out value); + } + + private static class Pool + { + private static readonly ConcurrentQueue pool = new(); + + /// + /// ConcurrentQueue's Count property is quite slow. FastCount just uses interlocked operations. + /// + private static int FastCount = 0; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryGet([NotNullWhen(true)] out T? item) + { + item = default; + + if (FastCount > 0) + { + if (pool.TryDequeue(out item)) + { + Interlocked.Decrement(ref FastCount); + Debug.Assert(item is not null); + return true; + } + } + + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Return(T item, int limit) + { + if (FastCount < limit) + { + pool.Enqueue(item); + Interlocked.Increment(ref FastCount); + } + } + } + } +} \ No newline at end of file diff --git a/src/FlatSharp.Runtime/Properties/AssemblyInfo.cs b/src/FlatSharp.Runtime/Properties/AssemblyInfo.cs index ba508bb..ee458b9 100644 --- a/src/FlatSharp.Runtime/Properties/AssemblyInfo.cs +++ b/src/FlatSharp.Runtime/Properties/AssemblyInfo.cs @@ -21,6 +21,7 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("FlatSharp, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] [assembly: InternalsVisibleTo("FlatSharpTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] [assembly: InternalsVisibleTo("FlatSharpEndToEndTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] +[assembly: InternalsVisibleTo("PoolingTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] [assembly: InternalsVisibleTo("FlatSharpCompilerTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] [assembly: InternalsVisibleTo("FlatSharp.Compiler, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] [assembly: InternalsVisibleTo("ExperimentalBenchmark, PublicKey=0024000004800000940000000602000000240000525341310004000001000100898185ce69dca04430ab296e094cd7eb6c66f5a3cfb0631ef64586fa183f0cb5ca64c47539a3a3c6351a9cf8d976a8d94350af430d5adc10536b3904cc1d6ecaaf3d0cb708aa318c559625f05d3b2d89da1c2bb323bb40e36dcf9245f21c3a4b6793c56ffface5e6e18290afb13c7eac1ea9c7a0c22f289c622bfa7b247d81a2")] @@ -32,6 +33,7 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("FlatSharpTests")] [assembly: InternalsVisibleTo("FlatSharpCompilerTests")] [assembly: InternalsVisibleTo("FlatSharpEndToEndTests")] +[assembly: InternalsVisibleTo("PoolingTests")] [assembly: InternalsVisibleTo("FlatSharp.Compiler")] [assembly: InternalsVisibleTo("ExperimentalBenchmark")] [assembly: InternalsVisibleTo("Benchmark")] diff --git a/src/FlatSharp.Runtime/VectorCloneHelpers.cs b/src/FlatSharp.Runtime/VectorCloneHelpers.cs index c105969..2a2f520 100644 --- a/src/FlatSharp.Runtime/VectorCloneHelpers.cs +++ b/src/FlatSharp.Runtime/VectorCloneHelpers.cs @@ -46,7 +46,7 @@ public static class VectorCloneHelpers [return: NotNullIfNotNull("source")] public static IList? CloneVectorOfUnion(IList? source, CloneCallback cloneItem) - where T : struct, IFlatBufferUnion + where T : IFlatBufferUnion { if (source is null) { @@ -86,7 +86,7 @@ public static class VectorCloneHelpers [return: NotNullIfNotNull("source")] public static IReadOnlyList? CloneVectorOfUnion(IReadOnlyList? source, CloneCallback cloneItem) - where T : struct, IFlatBufferUnion + where T : IFlatBufferUnion { if (source is null) { diff --git a/src/FlatSharp.Runtime/Vectors/FlatBufferIndexedVector.cs b/src/FlatSharp.Runtime/Vectors/FlatBufferIndexedVector.cs index faed718..2956cb3 100644 --- a/src/FlatSharp.Runtime/Vectors/FlatBufferIndexedVector.cs +++ b/src/FlatSharp.Runtime/Vectors/FlatBufferIndexedVector.cs @@ -14,25 +14,49 @@ * limitations under the License. */ +using System.Threading; + namespace FlatSharp.Internal; /// /// An implementation of IIndexedVector for use after deserializing an object. This class is not intended to be used /// directly -- only from code generated by FlatSharp. /// -public class FlatBufferIndexedVector : IIndexedVector +public class FlatBufferIndexedVector + : IIndexedVector + where TValue : class, ISortableTable where TKey : notnull where TInputBuffer : IInputBuffer where TVectorItemAccessor : IVectorItemAccessor { - private readonly FlatBufferVectorBase vector; + private FlatBufferDeserializationOption deserializationOption; + private FlatBufferVectorBase vector; - public FlatBufferIndexedVector(FlatBufferVectorBase vector) +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private FlatBufferIndexedVector() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. { + } + + private void Initialize(FlatBufferVectorBase vector) + { + this.deserializationOption = vector.DeserializationOption; this.vector = vector; } + public static FlatBufferIndexedVector GetOrCreate( + FlatBufferVectorBase vector) + { + if (!ObjectPool.TryGet>(out var item)) + { + item = new(); + } + + item.Initialize(vector); + return item; + } + public TValue this[TKey key] { get @@ -103,4 +127,18 @@ public class FlatBufferIndexedVector /// An implementation that loads data progressively. /// -public sealed class FlatBufferProgressiveIndexedVector : IIndexedVector +public sealed class FlatBufferProgressiveIndexedVector + : IIndexedVector + where TValue : class, ISortableTable where TKey : notnull where TInputBuffer : IInputBuffer where TVectorItemAccessor : IVectorItemAccessor { - private readonly Dictionary backingDictionary; - private readonly FlatBufferProgressiveVector backingVector; + private readonly Dictionary backingDictionary = new(); + private FlatBufferProgressiveVector backingVector; - public FlatBufferProgressiveIndexedVector(FlatBufferVectorBase items) + private FlatBufferDeserializationOption deserializationOption; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private FlatBufferProgressiveIndexedVector() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. { - this.backingDictionary = new Dictionary(); - this.backingVector = new FlatBufferProgressiveVector(items); + } + + private void Initialize(FlatBufferVectorBase items) + { + this.deserializationOption = items.DeserializationOption; + this.backingVector = FlatBufferProgressiveVector.GetOrCreate(items); + } + + public static FlatBufferProgressiveIndexedVector GetOrCreate(FlatBufferVectorBase items) + { + if (!ObjectPool.TryGet>(out var item)) + { + item = new(); + } + + item.Initialize(items); + return item; } /// @@ -135,4 +158,19 @@ public sealed class FlatBufferProgressiveIndexedVector /// A vector implementation that is filled on demand. Optimized /// for data locality, random access, and reasonably low memory overhead. /// -public sealed class FlatBufferProgressiveVector : IList, IReadOnlyList +public sealed class FlatBufferProgressiveVector + : IList + , IReadOnlyList + , IPoolableObject + where T : notnull where TInputBuffer : IInputBuffer where TVectorItemAccessor : IVectorItemAccessor @@ -31,15 +38,34 @@ public sealed class FlatBufferProgressiveVector innerVector; + private T?[]?[] items; + private FlatBufferVectorBase innerVector; - public FlatBufferProgressiveVector( - FlatBufferVectorBase innerVector) +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private FlatBufferProgressiveVector() + { + } +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + private void Initialize(FlatBufferVectorBase innerVector) { this.Count = innerVector.Count; this.innerVector = innerVector; - this.items = new T[(innerVector.Count / ChunkSize) + 1][]; + this.DeserializationOption = innerVector.DeserializationOption; + int minLength = (int)(innerVector.Count / ChunkSize + 1); + this.items = ArrayPool.Shared.Rent(minLength); + } + + public static FlatBufferProgressiveVector GetOrCreate(FlatBufferVectorBase innerVector) + { + if (!ObjectPool.TryGet>(out var item)) + { + item = new FlatBufferProgressiveVector(); + } + + item.Initialize(innerVector); + + return item; } /// @@ -86,10 +112,12 @@ public sealed class FlatBufferProgressiveVector true; + internal FlatBufferDeserializationOption DeserializationOption { get; private set; } + public void Add(T item) { throw new NotMutableException("FlatBufferVector does not allow adding items."); @@ -182,7 +210,7 @@ public sealed class FlatBufferProgressiveVector.Shared.Rent((int)ChunkSize); items[rowIndex] = row; // For value types -- we can't rely on null to tell @@ -199,4 +227,48 @@ public sealed class FlatBufferProgressiveVector.Shared.Return(block, true); + items[i] = null; + } + } + + ArrayPool.Shared.Return(items); + + this.Count = 0; + this.innerVector.ReturnToPool(true); + this.innerVector = default!; + + ObjectPool.Return(this); + } + } } diff --git a/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs b/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs index d558593..a088029 100644 --- a/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs +++ b/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs @@ -14,31 +14,62 @@ * limitations under the License. */ +using System.Threading; + namespace FlatSharp.Internal; /// /// A base flat buffer vector, common to standard vectors and unions. /// public sealed class FlatBufferVectorBase - : IList, IReadOnlyList, IFlatBufferDeserializedVector + : IList + , IReadOnlyList + , IFlatBufferDeserializedVector + , IPoolableObject + where TInputBuffer : IInputBuffer where TItemAccessor : IVectorItemAccessor { - private readonly TInputBuffer memory; - private readonly TableFieldContext fieldContext; - private readonly short remainingDepth; - private readonly TItemAccessor itemAccessor; + private TInputBuffer memory; + private TableFieldContext fieldContext; + private short remainingDepth; + private TItemAccessor itemAccessor; - public FlatBufferVectorBase( +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private FlatBufferVectorBase() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + { + } + + private void Initialize( TInputBuffer memory, TItemAccessor itemAccessor, short remainingDepth, - TableFieldContext fieldContext) + TableFieldContext fieldContext, + FlatBufferDeserializationOption option) { this.memory = memory; this.remainingDepth = remainingDepth; - this.fieldContext = fieldContext; this.itemAccessor = itemAccessor; + this.DeserializationOption = option; + this.fieldContext = fieldContext; + } + + public static FlatBufferVectorBase GetOrCreate( + TInputBuffer memory, + TItemAccessor itemAccessor, + short remainingDepth, + TableFieldContext fieldContext, + FlatBufferDeserializationOption option) + { + if (!ObjectPool.TryGet>(out var item)) + { + item = new FlatBufferVectorBase(); + } + + item.Initialize(memory, itemAccessor, remainingDepth, fieldContext, option); + + return item; } /// @@ -59,6 +90,8 @@ public sealed class FlatBufferVectorBase } } + public FlatBufferDeserializationOption DeserializationOption { get; private set; } + public int Count => this.itemAccessor.Count; public bool IsReadOnly => true; @@ -153,42 +186,6 @@ public sealed class FlatBufferVectorBase return -1; } - public List FlatBufferVectorToList() - { - int count = this.Count; - var list = new List(count); - - var context = this.fieldContext; - var remainingDepth = this.remainingDepth; - var buffer = this.memory; - - for (int i = 0; i < count; ++i) - { - this.itemAccessor.ParseItem(i, buffer, remainingDepth, context, out T item); - list.Add(item); - } - - return list; - } - - public ImmutableList ToImmutableList() - { - int count = this.Count; - var list = new T[count]; - - var context = this.fieldContext; - var remainingDepth = this.remainingDepth; - var buffer = this.memory; - - for (int i = 0; i < count; ++i) - { - this.itemAccessor.ParseItem(i, buffer, remainingDepth, context, out T item); - list[i] = item; - } - - return new ImmutableList(list); - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] private void CheckIndex(int index) { @@ -232,4 +229,19 @@ public sealed class FlatBufferVectorBase int IFlatBufferDeserializedVector.OffsetOf(int index) => this.itemAccessor.OffsetOf(index); object IFlatBufferDeserializedVector.ItemAt(int index) => this[index]!; + + public void ReturnToPool(bool force = false) + { + if (this.DeserializationOption.ShouldReturnToPool(force)) + { + var context = Interlocked.Exchange(ref this.fieldContext!, null); + if (context is not null) + { + this.memory = default!; + this.remainingDepth = default; + this.itemAccessor = default!; + ObjectPool.Return(this); + } + } + } } diff --git a/src/FlatSharp.Runtime/Vectors/GreedyIndexedVector.cs b/src/FlatSharp.Runtime/Vectors/GreedyIndexedVector.cs new file mode 100644 index 0000000..7527ab9 --- /dev/null +++ b/src/FlatSharp.Runtime/Vectors/GreedyIndexedVector.cs @@ -0,0 +1,208 @@ +/* + * Copyright 2020 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Threading; + +namespace FlatSharp.Internal; + +public sealed class GreedyIndexedVector : IIndexedVector + where TValue : class, ISortableTable + where TKey : notnull +{ + private int alive; + private readonly Dictionary backingDictionary; + private bool mutable; + + private GreedyIndexedVector() + { + this.backingDictionary = new Dictionary(); + this.mutable = true; + } + + public static GreedyIndexedVector GetOrCreate( + FlatBufferVectorBase backing, + bool mutable) + where TInputBuffer : IInputBuffer + where TItemAccessor : IVectorItemAccessor + { + if (!ObjectPool.TryGet(out GreedyIndexedVector? vector)) + { + vector = new(); + } + + vector.mutable = mutable; + vector.alive = 1; + + var dict = vector.backingDictionary; + +#if !NETSTANDARD2_0 + dict.EnsureCapacity(backing.Count); +#endif + + foreach (TValue value in backing) + { + TKey key = SortedVectorHelpers.KeyLookup.KeyGetter(value); + if (dict.TryGetValue(key, out var existingValue)) + { + (existingValue as IPoolableObject)?.ReturnToPool(); + } + + dict[key] = value; + } + + // we don't need "backing" any longer + backing.ReturnToPool(true); + + return vector; + } + + /// + /// Gets the key from the given value. + /// + public static TKey GetKey(TValue value) => SortedVectorHelpers.KeyLookup.KeyGetter(value); + + /// + /// An indexer for getting values by their keys. + /// + public TValue this[TKey key] + { + get => this.backingDictionary[key]; + } + + /// + /// Indicates if this IndexedVector is read only. + /// + public bool IsReadOnly => !this.mutable; + + /// + /// Gets the count of items. + /// + public int Count => this.backingDictionary.Count; + + /// + /// Freezes an Indexed vector, preventing further modifications. + /// + public void Freeze() + { + this.mutable = false; + } + + /// + /// Returns true if the vector contains the given key. + /// + public bool ContainsKey(TKey key) => this.backingDictionary.ContainsKey(key); + + /// + /// Tries to get the given value from the backing dictionary. + /// + public bool TryGetValue(TKey key, [NotNullWhen(true)] out TValue? value) + { + return this.backingDictionary.TryGetValue(key, out value); + } + + /// + /// Gets the dictionary's enumerator. + /// + public IEnumerator> GetEnumerator() => this.backingDictionary.GetEnumerator(); + + /// + /// Gets a non-generic enumerator. + /// + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + + /// + /// Adds or replaces the item with the given key to the indexed vector. + /// + public void AddOrReplace(TValue value) + { + if (!this.mutable) + { + throw new NotMutableException(); + } + + this.backingDictionary[GetKey(value)] = value; + } + + /// + /// Attempts to add the value to the indexed vector, if a key does not already exist. + /// + public bool Add(TValue value) + { + if (!this.mutable) + { + throw new NotMutableException(); + } + + TKey key = GetKey(value); + +#if NETSTANDARD2_0 + var dictionary = this.backingDictionary; + if (dictionary.ContainsKey(key)) + { + return false; + } + + dictionary[key] = value; + return true; +#else + return this.backingDictionary.TryAdd(key, value); +#endif + } + + public void Clear() + { + if (!this.mutable) + { + throw new NotMutableException(); + } + + this.backingDictionary.Clear(); + } + + public bool Remove(TKey key) + { + if (!this.mutable) + { + throw new NotMutableException(); + } + + return this.backingDictionary.Remove(key); + } + + public void ReturnToPool(bool unsafeForce = false) + { + if (FlatBufferDeserializationOption.Greedy.ShouldReturnToPool(unsafeForce)) + { + if (Interlocked.Exchange(ref this.alive, 0) != 0) + { + var dict = this.backingDictionary; + + foreach (var item in dict) + { + if (item.Value is IPoolableObject obj) + { + obj.ReturnToPool(true); + } + } + + dict.Clear(); + this.mutable = false; + + ObjectPool.Return(this); + } + } + } +} diff --git a/src/FlatSharp.Runtime/Vectors/IIndexedVector.cs b/src/FlatSharp.Runtime/Vectors/IIndexedVector.cs index 9174396..d28d669 100644 --- a/src/FlatSharp.Runtime/Vectors/IIndexedVector.cs +++ b/src/FlatSharp.Runtime/Vectors/IIndexedVector.cs @@ -19,7 +19,7 @@ namespace FlatSharp; /// /// An indexed vector -- suitable for accessing values by their keys. /// -public interface IIndexedVector : IEnumerable> +public interface IIndexedVector : IEnumerable>, IPoolableObject where TValue : class where TKey : notnull { diff --git a/src/FlatSharp.Runtime/Vectors/ImmutableList.cs b/src/FlatSharp.Runtime/Vectors/ImmutableList.cs index 09f7ed8..25a1908 100644 --- a/src/FlatSharp.Runtime/Vectors/ImmutableList.cs +++ b/src/FlatSharp.Runtime/Vectors/ImmutableList.cs @@ -14,6 +14,10 @@ * limitations under the License. */ +using System.Buffers; +using System.Linq; +using System.Threading; + namespace FlatSharp.Internal; /// @@ -26,13 +30,49 @@ namespace FlatSharp.Internal; /// - Second, it does not reference internally, which means it is able to skip a level of virtual indirection /// by using directly. /// -public sealed class ImmutableList : IList, IReadOnlyList +public sealed class ImmutableList : IList, IReadOnlyList, IPoolableObject { - private readonly T[] list; + private List list; + private int isAlive; - public ImmutableList(T[] list) + public ImmutableList(IList template) { - this.list = list; + this.list = template.ToList(); + } + + public ImmutableList(int capacity) + { + this.list = new List(capacity); + } + + public static ImmutableList GetOrCreate(FlatBufferVectorBase vector) + where TInputBuffer : IInputBuffer + where TItemAccessor : IVectorItemAccessor + { + int count = vector.Count; + + if (ObjectPool.TryGet(out ImmutableList? list)) + { +#if NET6_0_OR_GREATER + list.list.EnsureCapacity(count); +#endif + } + else + { + list = new(count); + } + + list.isAlive = 1; + + for (int i = 0; i < count; ++i) + { + list.list.Add(vector[i]); + } + + // We've copied our stuff -- send the base vector back to where it came from! + vector.ReturnToPool(true); + + return list; } public T this[int index] @@ -41,7 +81,7 @@ public sealed class ImmutableList : IList, IReadOnlyList set => throw new NotMutableException(); } - public int Count => this.list.Length; + public int Count => this.list.Count; public bool IsReadOnly => true; @@ -55,25 +95,16 @@ public sealed class ImmutableList : IList, IReadOnlyList throw new NotMutableException(); } - public bool Contains(T item) - { - return Array.IndexOf(this.list, item) >= 0; - } + public bool Contains(T item) => this.list.Contains(item); - public void CopyTo(T[] array, int arrayIndex) - { - this.list.CopyTo(array, arrayIndex); - } + public void CopyTo(T[] array, int arrayIndex) => this.list.CopyTo(array, arrayIndex); public IEnumerator GetEnumerator() { - return ((IList)this.list).GetEnumerator(); + return this.list.GetEnumerator(); } - public int IndexOf(T item) - { - return Array.IndexOf(this.list, item); - } + public int IndexOf(T item) => this.list.IndexOf(item); public void Insert(int index, T item) { @@ -90,6 +121,29 @@ public sealed class ImmutableList : IList, IReadOnlyList throw new NotMutableException(); } + public void ReturnToPool(bool force) + { + if (force) + { + if (Interlocked.Exchange(ref this.isAlive, 0) != 0) + { + if (!typeof(T).IsValueType) + { + foreach (var item in this.list) + { + if (item is IPoolableObject poolable) + { + poolable.ReturnToPool(true); + } + } + } + + this.list.Clear(); + ObjectPool.Return(this); + } + } + } + IEnumerator IEnumerable.GetEnumerator() { return this.GetEnumerator(); diff --git a/src/FlatSharp.Runtime/Vectors/IndexedVector.cs b/src/FlatSharp.Runtime/Vectors/IndexedVector.cs index cb9a469..64db326 100644 --- a/src/FlatSharp.Runtime/Vectors/IndexedVector.cs +++ b/src/FlatSharp.Runtime/Vectors/IndexedVector.cs @@ -166,4 +166,9 @@ public sealed class IndexedVector : IIndexedVector return this.backingDictionary.Remove(key); } + + [ExcludeFromCodeCoverage] + public void ReturnToPool(bool unsafeForce = false) + { + } } diff --git a/src/FlatSharp.Runtime/Vectors/PoolableList.cs b/src/FlatSharp.Runtime/Vectors/PoolableList.cs new file mode 100644 index 0000000..fc85c83 --- /dev/null +++ b/src/FlatSharp.Runtime/Vectors/PoolableList.cs @@ -0,0 +1,126 @@ +/* + * Copyright 2022 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using System.Threading; + +namespace FlatSharp.Internal; + +/// +/// A wrapper around a List{T}. +/// +public sealed class PoolableList : IList, IReadOnlyList, IPoolableObject +{ + private List list; + private int isAlive; + + public PoolableList(IList template) + { + this.list = template.ToList(); + } + + public PoolableList(int capacity) + { + this.list = new(capacity); + } + + public static PoolableList GetOrCreate(FlatBufferVectorBase item) + where TInputBuffer : IInputBuffer + where TItemAccessor : IVectorItemAccessor + { + int count = item.Count; + + if (ObjectPool.TryGet(out PoolableList? list)) + { +#if NET6_0_OR_GREATER + list.list.EnsureCapacity(count); +#endif + } + else + { + list = new(count); + } + + list.isAlive = 1; + + for (int i = 0; i < count; ++i) + { + list.Add(item[i]); + } + + // We've copied our stuff -- send the base vector back to where it came from! + item.ReturnToPool(true); + + return list; + } + + public T this[int index] + { + get => this.list[index]; + set => this.list[index] = value; + } + + public int Count => this.list.Count; + + public bool IsReadOnly => false; + + public void Add(T item) => this.list.Add(item); + + public void Clear() => this.list.Clear(); + + public bool Contains(T item) => this.list.Contains(item); + + public void CopyTo(T[] array, int arrayIndex) => this.list.CopyTo(array, arrayIndex); + + public IEnumerator GetEnumerator() => this.list.GetEnumerator(); + + public int IndexOf(T item) => this.list.IndexOf(item); + + public void Insert(int index, T item) => this.list.Insert(index, item); + + public bool Remove(T item) => this.list.Remove(item); + + public void RemoveAt(int index) => this.list.RemoveAt(index); + + public void ReturnToPool(bool force) + { + if (force) + { + var isAlive = Interlocked.Exchange(ref this.isAlive, 0); + if (isAlive != 0) + { + if (!typeof(T).IsValueType) + { + foreach (var item in this.list) + { + if (item is IPoolableObject poolable) + { + poolable.ReturnToPool(true); + } + } + } + + this.list.Clear(); + ObjectPool.Return(this); + } + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } +} \ No newline at end of file diff --git a/src/FlatSharp.sln b/src/FlatSharp.sln index 3fcbffc..c7cd807 100644 --- a/src/FlatSharp.sln +++ b/src/FlatSharp.sln @@ -33,6 +33,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microbench.Current", "Bench EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microbench.6.3.3", "Benchmarks\MicroBench.6.3.3\Microbench.6.3.3.csproj", "{407BD242-0902-4850-9F53-17DF48F3DD6F}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FlatSharpPoolableEndToEndTests", "Tests\FlatSharpPoolableEndToEndTests\FlatSharpPoolableEndToEndTests.csproj", "{7E55A248-4BBD-48AD-B27D-A7E0E705DC89}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -87,6 +89,10 @@ Global {407BD242-0902-4850-9F53-17DF48F3DD6F}.Debug|Any CPU.Build.0 = Debug|Any CPU {407BD242-0902-4850-9F53-17DF48F3DD6F}.Release|Any CPU.ActiveCfg = Release|Any CPU {407BD242-0902-4850-9F53-17DF48F3DD6F}.Release|Any CPU.Build.0 = Release|Any CPU + {7E55A248-4BBD-48AD-B27D-A7E0E705DC89}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7E55A248-4BBD-48AD-B27D-A7E0E705DC89}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7E55A248-4BBD-48AD-B27D-A7E0E705DC89}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7E55A248-4BBD-48AD-B27D-A7E0E705DC89}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -101,6 +107,7 @@ Global {927A2D53-B127-4297-B17A-4E4EEE7ACD0D} = {83478353-8C5A-41C2-84C2-F79488B43CB0} {5A173CC8-53B4-445B-9AA2-68A61A4BE449} = {927A2D53-B127-4297-B17A-4E4EEE7ACD0D} {407BD242-0902-4850-9F53-17DF48F3DD6F} = {927A2D53-B127-4297-B17A-4E4EEE7ACD0D} + {7E55A248-4BBD-48AD-B27D-A7E0E705DC89} = {D1E90BAE-FC51-44DB-8215-1D9BB6059886} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {726A4C0E-5760-4B35-8C73-9AC21B2E4C8B} diff --git a/src/FlatSharp/Serialization/DeserializeClassDefinition.cs b/src/FlatSharp/Serialization/DeserializeClassDefinition.cs index 74a8663..eeac2c6 100644 --- a/src/FlatSharp/Serialization/DeserializeClassDefinition.cs +++ b/src/FlatSharp/Serialization/DeserializeClassDefinition.cs @@ -15,6 +15,7 @@ */ using FlatSharp.TypeModel; +using System.Linq; namespace FlatSharp.CodeGen; @@ -24,6 +25,9 @@ internal class DeserializeClassDefinition protected const string OffsetVariableName = "__offset"; protected const string VTableVariableName = "__vtable"; protected const string RemainingDepthVariableName = "__remainingDepth"; + protected const string IsAliveVariableName = "__alive"; + protected const string IsRootVariableName = "__isRoot"; + protected const string CtorContextVariableName = "__CtorContext"; protected readonly ITypeModel typeModel; protected readonly FlatBufferSerializerOptions options; @@ -31,6 +35,7 @@ internal class DeserializeClassDefinition protected readonly List propertyOverrides = new(); protected readonly List initializeStatements = new(); protected readonly List readMethods = new(); + protected readonly List itemModels = new(); // Maps field name -> field initializer. protected readonly Dictionary instanceFieldDefinitions = new(); @@ -54,6 +59,13 @@ internal class DeserializeClassDefinition this.onDeserializeMethod = onDeserializeMethod; this.vtableAccessor = "default"; + this.instanceFieldDefinitions[IsRootVariableName] = $"private bool {IsRootVariableName};"; + + if (typeof(IPoolableObject).IsAssignableFrom(this.typeModel.ClrType)) + { + this.instanceFieldDefinitions[IsAliveVariableName] = $"private int {IsAliveVariableName};"; + this.initializeStatements.Add($"this.{IsAliveVariableName} = 1;"); + } if (this.options.GreedyDeserialize) { @@ -107,6 +119,8 @@ internal class DeserializeClassDefinition ItemMemberModel itemModel, ParserCodeGenContext context) { + this.itemModels.Add(itemModel); + this.AddFieldDefinitions(itemModel); this.AddPropertyDefinitions(itemModel); this.AddCtorStatements(itemModel); @@ -260,13 +274,13 @@ internal class DeserializeClassDefinition string onDeserializedStatement = string.Empty; if (this.onDeserializeMethod is not null) { - onDeserializedStatement = $"base.{this.onDeserializeMethod.Name}(__CtorContext);"; + onDeserializedStatement = $"base.{this.onDeserializeMethod.Name}({CtorContextVariableName});"; } string baseParams = string.Empty; if (ctor.GetParameters().Length != 0) { - baseParams = "__CtorContext"; + baseParams = CtorContextVariableName; } string interfaceGlobalName = typeof(IFlatBufferDeserializedObject).GetGlobalCompilableTypeName(); @@ -276,9 +290,12 @@ internal class DeserializeClassDefinition private sealed class {this.ClassName} : {typeModel.GetGlobalCompilableTypeName()} , {interfaceGlobalName} + , {typeof(IPoolableObject).GetGlobalCompilableTypeName()} + , {typeof(IPoolableObjectDebug).GetGlobalCompilableTypeName()} + where TInputBuffer : IInputBuffer {{ - private static readonly {typeof(FlatBufferDeserializationContext).GetGlobalCompilableTypeName()} __CtorContext + private static readonly {typeof(FlatBufferDeserializationContext).GetGlobalCompilableTypeName()} {CtorContextVariableName} = new {typeof(FlatBufferDeserializationContext).GetGlobalCompilableTypeName()}({typeof(FlatBufferDeserializationOption).GetGlobalCompilableTypeName()}.{options.DeserializationOption}); {string.Join("\r\n", this.staticFieldDefinitions.Values)} @@ -293,11 +310,20 @@ internal class DeserializeClassDefinition {this.GetCtorMethodDefinition(onDeserializedStatement, baseParams)} + {this.GetDisposeMethodBody()} + {typeof(Type).GetGlobalCompilableTypeName()} {interfaceGlobalName}.{nameof(IFlatBufferDeserializedObject.TableOrStructType)} => typeof({typeModel.GetCompilableTypeName()}); {typeof(FlatBufferDeserializationContext).GetGlobalCompilableTypeName()} {interfaceGlobalName}.{nameof(IFlatBufferDeserializedObject.DeserializationContext)} => __CtorContext; {typeof(IInputBuffer).GetGlobalCompilableTypeName()}? {interfaceGlobalName}.{nameof(IFlatBufferDeserializedObject.InputBuffer)} => {this.GetBufferReference()}; + bool {interfaceGlobalName}.{nameof(IFlatBufferDeserializedObject.CanSerializeWithMemoryCopy)} => {this.options.CanSerializeWithMemoryCopy.ToString().ToLowerInvariant()}; + bool {typeof(IPoolableObjectDebug).GetGlobalCompilableTypeName()}.IsRoot + {{ + get => this.{IsRootVariableName}; + set => this.{IsRootVariableName} = value; + }} + {string.Join("\r\n", this.propertyOverrides)} {string.Join("\r\n", this.readMethods)} }} @@ -368,10 +394,28 @@ internal class DeserializeClassDefinition protected virtual string GetGetOrCreateMethodBody() { - return $@" - var item = new {this.ClassName}(buffer, offset, remainingDepth); - return item; - "; + if (typeof(IPoolableObject).IsAssignableFrom(this.typeModel.ClrType)) + { + return $@" + {this.ClassName}? item; + + if (!{typeof(ObjectPool).GetGlobalCompilableTypeName()}.TryGet<{this.ClassName}>(out item)) + {{ + item = new {this.ClassName}(); + }} + + item.Initialize(buffer, offset, remainingDepth); + return item; + "; + } + else + { + return $@" + {this.ClassName}? item = new(); + item.Initialize(buffer, offset, remainingDepth); + return item; + "; + } } protected virtual string GetCtorMethodDefinition(string onDeserializedStatement, string baseCtorParams) @@ -382,12 +426,113 @@ internal class DeserializeClassDefinition [System.Diagnostics.CodeAnalysis.SetsRequiredMembers] #endif [{typeof(MethodImplAttribute).GetGlobalCompilableTypeName()}({typeof(MethodImplOptions).GetGlobalCompilableTypeName()}.AggressiveInlining)] - private {this.ClassName}(TInputBuffer buffer, int offset, short remainingDepth) : base({baseCtorParams}) - {{ + private {this.ClassName}() : base({baseCtorParams}) + {{ + }} +#pragma warning restore CS8618 + + [{typeof(MethodImplAttribute).GetGlobalCompilableTypeName()}({typeof(MethodImplOptions).GetGlobalCompilableTypeName()}.AggressiveInlining)] + private void Initialize(TInputBuffer buffer, int offset, short remainingDepth) + {{ {string.Join("\r\n", this.initializeStatements)} {onDeserializedStatement} }} -#pragma warning restore CS8618 + "; + } + + private string GetDisposeMethodBody() + { + if (!typeof(IPoolableObject).IsAssignableFrom(this.typeModel.ClrType)) + { + // not disposable. + return $@"public void ReturnToPool(bool unsafeForce = false) {{ }}"; + } + + // Lazy doesn't have to deal with this stuff. + IEnumerable disposeFields = Array.Empty(); + if (!this.options.Lazy) + { + disposeFields = this.itemModels + .Where( + x => !x.ItemTypeModel.ClrType.IsValueType) + .Select(x => (model: x, isPoolable: typeof(IPoolableObject).IsAssignableFrom(x.ItemTypeModel.ClrType), isInterface: x.ItemTypeModel.ClrType.IsInterface)) + .Where(x => x.isPoolable || x.isInterface) + .Select(x => + { + if (x.isPoolable) + { + return $@" + {{ + var item = this.{GetFieldName(x.model)}; + this.{GetFieldName(x.model)} = null!; + item?.ReturnToPool(true); + }} + "; + } + else + { + return $@" + {{ + var item = this.{GetFieldName(x.model)}; + this.{GetFieldName(x.model)} = null!; + (item as IPoolableObject)?.ReturnToPool(true); + }} + "; + } + }); + + // reset all fields to default. + disposeFields = disposeFields.Concat(this.itemModels + .Select(x => $"this.{GetFieldName(x)} = default({x.ItemTypeModel.GetGlobalCompilableTypeName()})!;")); + + // Reset all masks to default as well. + if (!this.options.GreedyDeserialize) + { + HashSet masks = new HashSet(this.itemModels.Select(GetHasValueFieldName)); + + disposeFields = disposeFields.Concat(masks + .Select(x => $"this.{x} = 0;")); + } + } + + if (!this.options.GreedyDeserialize) + { + disposeFields = disposeFields.Concat(new[] + { + $"this.{InputBufferVariableName} = default(TInputBuffer)!;", + $"this.{RemainingDepthVariableName} = -1;", + $"this.{OffsetVariableName} = -1;", + }); + + if (this.typeModel.SchemaType == FlatBufferSchemaType.Table) + { + disposeFields = disposeFields.Concat(new[] { $"this.{VTableVariableName} = default({this.vtableTypeName});" }); + } + } + + string fromRootCondition = $"if (unsafeForce || this.{IsRootVariableName})"; + if (this.options.Lazy) + { + fromRootCondition = string.Empty; + } + + return $@" + + public override void ReturnToPool(bool unsafeForce = false) + {{ + {{ + {fromRootCondition} + {{ + if (System.Threading.Interlocked.Exchange(ref this.{IsAliveVariableName}, 0) != 0) + {{ + {string.Join("\r\n", disposeFields)} + + this.{IsRootVariableName} = false; + {typeof(ObjectPool).GetGlobalCompilableTypeName()}.Return(this); + }} + }} + }} + }} "; } diff --git a/src/FlatSharp/TypeModel/ITypeModel.cs b/src/FlatSharp/TypeModel/ITypeModel.cs index 58c5753..e6004e9 100644 --- a/src/FlatSharp/TypeModel/ITypeModel.cs +++ b/src/FlatSharp/TypeModel/ITypeModel.cs @@ -114,7 +114,7 @@ public interface ITypeModel /// /// Indicates the constructor that subclasses should use. This constructor must have either 0 parameters or 1 parameter - /// that accepts an instance of . + /// that accepts an instance of . /// ConstructorInfo? PreferredSubclassConstructor { get; } diff --git a/src/FlatSharp/TypeModel/UnionTypeModel.cs b/src/FlatSharp/TypeModel/UnionTypeModel.cs index bd1edae..e04ab9b 100644 --- a/src/FlatSharp/TypeModel/UnionTypeModel.cs +++ b/src/FlatSharp/TypeModel/UnionTypeModel.cs @@ -16,6 +16,7 @@ using System.Collections.Immutable; using System.Linq; +using System.Threading; namespace FlatSharp.TypeModel; @@ -139,6 +140,9 @@ $@" public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext context) { List switchCases = new List(); + + (string? extraClass, string createNew) = GetUnionHelperClass(context); + for (int i = 0; i < this.UnionElementTypeModel.Length; ++i) { var unionMember = this.UnionElementTypeModel[i]; @@ -160,7 +164,7 @@ $@" $@" case {unionIndex}: {inlineAdjustment} - return new {this.GetGlobalCompilableTypeName()}({itemContext.GetParseInvocation(unionMember.ClrType)}); + return {createNew}({itemContext.GetParseInvocation(unionMember.ClrType)}); "; switchCases.Add(@case); } @@ -179,7 +183,94 @@ $@" }} "; - return new CodeGeneratedMethod(body); + return new CodeGeneratedMethod(body) { ClassDefinition = extraClass }; + } + + private (string? classDef, string createNewUnion) GetUnionHelperClass(ParserCodeGenContext context) + { + if (this.ClrType.IsValueType || !typeof(IPoolableObject).IsAssignableFrom(this.ClrType)) + { + // Nothing special for value-type or non-poolable unions. + return (null, $"new {this.GetGlobalCompilableTypeName()}"); + } + + string className = "unionReader_" + Guid.NewGuid().ToString("n"); + + List getOrCreates = new(); + List returnToPoolCases = new(); + + for (int i = 0; i < this.UnionElementTypeModel.Length; ++i) + { + int unionIndex = i + 1; // unions start at 1. + string itemType = this.UnionElementTypeModel[i].GetGlobalCompilableTypeName(); + + getOrCreates.Add($@" + public static {className} GetOrCreate({itemType} value) + {{ + if (!{typeof(ObjectPool).GetGlobalCompilableTypeName()}.{nameof(ObjectPool.TryGet)}<{className}>(out var union)) + {{ + union = new {className}(); + }} + + union.discriminator = {unionIndex}; + union.Item{unionIndex} = value; + union.isAlive = 1; + + return union; + }} + "); + + string recursiveReturn = string.Empty; + if (typeof(IPoolableObject).IsAssignableFrom(this.UnionElementTypeModel[i].ClrType)) + { + recursiveReturn = $"this.Item{unionIndex}?.ReturnToPool(true);"; + } + + returnToPoolCases.Add($@" + case {unionIndex}: + {{ + {recursiveReturn} + this.Item{unionIndex} = default({itemType})!; + }} + break; + "); + } + + string returnCondition = string.Empty; + if (!context.Options.Lazy) + { + returnCondition = "if (unsafeForce)"; + } + + // Reference type unions are much more special! + string extraClass = $@" + private sealed class {className} : {this.ClrType.GetGlobalCompilableTypeName()} + {{ + private int isAlive; + + {string.Join("\r\n", getOrCreates)} + + public override void ReturnToPool(bool unsafeForce = false) + {{ + {returnCondition} + {{ + int alive = {typeof(Interlocked).GetGlobalCompilableTypeName()}.Exchange(ref this.isAlive, 0); + if (alive > 0) + {{ + switch (base.discriminator) + {{ + {string.Join("\r\n", returnToPoolCases)} + }} + + base.discriminator = -1; + {typeof(ObjectPool).GetGlobalCompilableTypeName()}.{nameof(ObjectPool.Return)}(this); + }} + }} + }} + }} + "; + + return (extraClass, $"{className}.GetOrCreate"); } public override CodeGeneratedMethod CreateSerializeMethodBody(SerializationCodeGenContext context) diff --git a/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs b/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs index 09bf635..cf05d2f 100644 --- a/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs +++ b/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs @@ -111,29 +111,30 @@ public class IndexedVectorTypeModel : BaseVectorTypeModel string accessorClassName = $"{vectorClassName}<{context.InputBufferTypeName}>"; string createFlatBufferVector = - $@"new FlatBufferVectorBase<{this.ItemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {accessorClassName}> ( + $@"FlatBufferVectorBase<{this.ItemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {accessorClassName}>.GetOrCreate( {context.InputBufferVariableName}, new {accessorClassName}( {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}), {context.InputBufferVariableName}), {context.RemainingDepthVariableName}, - {context.TableFieldContextVariableName})"; + {context.TableFieldContextVariableName}, + {typeof(FlatBufferDeserializationOption).GetGlobalCompilableTypeName()}.{context.Options.DeserializationOption})"; string mutable = context.Options.GenerateMutableObjects.ToString().ToLowerInvariant(); if (context.Options.GreedyDeserialize) { // Eager indexed vector. - body = $@"return new IndexedVector<{keyTypeName}, {valueTypeName}>({createFlatBufferVector}, {mutable});"; + body = $@"return GreedyIndexedVector<{keyTypeName}, {valueTypeName}>.GetOrCreate<{context.InputBufferTypeName}, {accessorClassName}>({createFlatBufferVector}, {mutable});"; } else if (context.Options.Lazy) { // Lazy indexed vector. - body = $@"return new FlatBufferIndexedVector<{keyTypeName}, {valueTypeName}, {context.InputBufferTypeName}, {accessorClassName}>({createFlatBufferVector});"; + body = $@"return FlatBufferIndexedVector<{keyTypeName}, {valueTypeName}, {context.InputBufferTypeName}, {accessorClassName}>.GetOrCreate({createFlatBufferVector});"; } else { FlatSharpInternal.Assert(context.Options.Progressive, "expecting progressive"); - body = $@"return new FlatBufferProgressiveIndexedVector<{keyTypeName}, {valueTypeName}, {context.InputBufferTypeName}, {accessorClassName}>({createFlatBufferVector});"; + body = $@"return FlatBufferProgressiveIndexedVector<{keyTypeName}, {valueTypeName}, {context.InputBufferTypeName}, {accessorClassName}>.GetOrCreate({createFlatBufferVector});"; } return new CodeGeneratedMethod(body) { IsMethodInline = true, ClassDefinition = vectorClassDef }; diff --git a/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs b/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs index 466be6f..4b291e8 100644 --- a/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs +++ b/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs @@ -113,13 +113,14 @@ public class ListVectorTypeModel : BaseVectorTypeModel string accessorClassName = $"{vectorClassName}<{context.InputBufferTypeName}>"; string createFlatBufferVector = - $@"new FlatBufferVectorBase<{this.ItemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {accessorClassName}> ( + $@"FlatBufferVectorBase<{this.ItemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {accessorClassName}>.GetOrCreate( {context.InputBufferVariableName}, new {accessorClassName}( {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}), {context.InputBufferVariableName}), {context.RemainingDepthVariableName}, - {context.TableFieldContextVariableName})"; + {context.TableFieldContextVariableName}, + {typeof(FlatBufferDeserializationOption).GetGlobalCompilableTypeName()}.{context.Options.DeserializationOption})"; return new CodeGeneratedMethod(CreateParseBody(this.ItemTypeModel, createFlatBufferVector, accessorClassName, context, isEverWriteThrough)) { ClassDefinition = vectorClassDef }; } @@ -136,15 +137,16 @@ public class ListVectorTypeModel : BaseVectorTypeModel if (context.Options.DeserializationOption == FlatBufferDeserializationOption.GreedyMutable && isEverWriteThrough) { string body = $$""" + var result = {{createFlatBufferVector}}; if ({{context.TableFieldContextVariableName}}.{{nameof(TableFieldContext.WriteThrough)}}) { // WriteThrough vectors are not mutable in greedymutable mode. - return result.ToImmutableList(); + return ImmutableList<{{itemTypeModel.ClrType.GetGlobalCompilableTypeName()}}>.GetOrCreate(result); } else { - return result.FlatBufferVectorToList(); + return PoolableList<{{itemTypeModel.ClrType.GetGlobalCompilableTypeName()}}>.GetOrCreate(result); } """; @@ -152,13 +154,13 @@ public class ListVectorTypeModel : BaseVectorTypeModel } else if (context.Options.GreedyDeserialize) { - string transform = "ToImmutableList()"; + string transform = "ImmutableList"; if (context.Options.GenerateMutableObjects) { - transform = "FlatBufferVectorToList()"; + transform = "PoolableList"; } - return $"return ({createFlatBufferVector}).{transform};"; + return $"return {transform}<{itemTypeModel.ClrType.GetGlobalCompilableTypeName()}>.GetOrCreate({createFlatBufferVector});"; } else if (context.Options.Lazy) { @@ -167,7 +169,7 @@ public class ListVectorTypeModel : BaseVectorTypeModel else { FlatSharpInternal.Assert(context.Options.Progressive, "expecting progressive"); - return $"return new FlatBufferProgressiveVector<{itemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {itemAccessorTypeName}>({createFlatBufferVector});"; + return $"return FlatBufferProgressiveVector<{itemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {itemAccessorTypeName}>.GetOrCreate({createFlatBufferVector});"; } } } diff --git a/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs b/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs index 22f697a..84f8183 100644 --- a/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs +++ b/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs @@ -37,14 +37,15 @@ public class ListVectorOfUnionTypeModel : BaseVectorOfUnionTypeModel string itemAccessorTypeName = $"{className}<{context.InputBufferTypeName}>"; string createFlatBufferVector = - $@"new FlatBufferVectorBase<{this.ItemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {itemAccessorTypeName}> ( + $@"FlatBufferVectorBase<{this.ItemTypeModel.GetGlobalCompilableTypeName()}, {context.InputBufferTypeName}, {itemAccessorTypeName}>.GetOrCreate( {context.InputBufferVariableName}, new {itemAccessorTypeName}( {context.InputBufferVariableName}, {context.OffsetVariableName}.offset0 + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}.offset0), {context.OffsetVariableName}.offset1 + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}.offset1)), {context.RemainingDepthVariableName}, - {context.TableFieldContextVariableName})"; + {context.TableFieldContextVariableName}, + {typeof(FlatBufferDeserializationOption).GetGlobalCompilableTypeName()}.{context.Options.DeserializationOption})"; return new CodeGeneratedMethod(ListVectorTypeModel.CreateParseBody( this.ItemTypeModel, diff --git a/src/Tests/FlatSharpEndToEndTests/ClassLib/ImmutableListTests.cs b/src/Tests/FlatSharpEndToEndTests/ClassLib/ImmutableListTests.cs index fda0153..25928b1 100644 --- a/src/Tests/FlatSharpEndToEndTests/ClassLib/ImmutableListTests.cs +++ b/src/Tests/FlatSharpEndToEndTests/ClassLib/ImmutableListTests.cs @@ -16,7 +16,7 @@ using FlatSharp.Internal; -namespace FlatSharpTests; +namespace FlatSharpEndToEndTests.ImmutableList; /// /// Tests for the FlatBufferVector class that implements IList. diff --git a/src/Tests/FlatSharpEndToEndTests/ClassLib/IndexedVectorTests.cs b/src/Tests/FlatSharpEndToEndTests/ClassLib/IndexedVectorTests.cs index 8d70c38..eebfa46 100644 --- a/src/Tests/FlatSharpEndToEndTests/ClassLib/IndexedVectorTests.cs +++ b/src/Tests/FlatSharpEndToEndTests/ClassLib/IndexedVectorTests.cs @@ -24,12 +24,7 @@ public class IndexedVectorTests private List stringKeys; private Container stringVectorSource; - private Container stringVectorLazy; - private Container stringVectorProgressive; - private Container intVectorSource; - private Container intVectorParsed; - private Container intVectorProgressive; public IndexedVectorTests() { @@ -48,60 +43,73 @@ public class IndexedVectorTests this.stringVectorSource.StringVector.Freeze(); this.intVectorSource.IntVector.Freeze(); - - this.stringVectorLazy = stringVectorSource.SerializeAndParse(FlatBufferDeserializationOption.Lazy); - this.stringVectorProgressive = stringVectorSource.SerializeAndParse(FlatBufferDeserializationOption.Progressive); - - this.intVectorParsed = intVectorSource.SerializeAndParse(FlatBufferDeserializationOption.Lazy); - this.intVectorProgressive = intVectorSource.SerializeAndParse(FlatBufferDeserializationOption.Progressive); } - [Fact] - public void IndexedVector_KeyNotFound() + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void IndexedVector_KeyNotFound(FlatBufferDeserializationOption option) { + Container stringValue = this.stringVectorSource.SerializeAndParse(option); + Container intValue = this.intVectorSource.SerializeAndParse(option); + Assert.Throws(() => this.stringVectorSource.StringVector[null]); Assert.Throws(() => this.stringVectorSource.StringVector[string.Empty]); - Assert.Throws(() => this.stringVectorLazy.StringVector[null]); - Assert.Throws(() => this.stringVectorLazy.StringVector[string.Empty]); - Assert.Throws(() => this.stringVectorProgressive.StringVector[null]); - Assert.Throws(() => this.stringVectorProgressive.StringVector[string.Empty]); + Assert.Throws(() => stringValue.StringVector[null]); + Assert.Throws(() => stringValue.StringVector[string.Empty]); Assert.Throws(() => this.intVectorSource.IntVector[int.MinValue]); Assert.Throws(() => this.intVectorSource.IntVector[int.MaxValue]); - Assert.Throws(() => this.intVectorParsed.IntVector[int.MinValue]); - Assert.Throws(() => this.intVectorParsed.IntVector[int.MaxValue]); - Assert.Throws(() => this.intVectorProgressive.IntVector[int.MinValue]); - Assert.Throws(() => this.intVectorProgressive.IntVector[int.MaxValue]); + Assert.Throws(() => intValue.IntVector[int.MinValue]); + Assert.Throws(() => intValue.IntVector[int.MaxValue]); } - [Fact] - public void IndexedVector_NotMutable() + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void IndexedVector_NotMutable(FlatBufferDeserializationOption option) { - Assert.True(this.stringVectorLazy.StringVector.IsReadOnly); + Container stringValue = this.stringVectorSource.SerializeAndParse(option); + + Assert.Equal( + FlatBufferDeserializationOption.GreedyMutable != option, + stringValue.StringVector.IsReadOnly); + Assert.True(this.stringVectorSource.StringVector.IsReadOnly); + // root is frozen. Assert.Throws(() => this.stringVectorSource.StringVector.AddOrReplace(null)); Assert.Throws(() => this.stringVectorSource.StringVector.Clear()); Assert.Throws(() => this.stringVectorSource.StringVector.Remove(null)); Assert.Throws(() => this.stringVectorSource.StringVector.Add(null)); - Assert.Throws(() => this.stringVectorLazy.StringVector.AddOrReplace(null)); - Assert.Throws(() => this.stringVectorLazy.StringVector.Clear()); - Assert.Throws(() => this.stringVectorLazy.StringVector.Remove(null)); - Assert.Throws(() => this.stringVectorLazy.StringVector.Add(null)); + Action[] actions = new Action[] + { + () => stringValue.StringVector.AddOrReplace(new StringKey { Key = "foo" }), + () => stringValue.StringVector.Clear(), + () => stringValue.StringVector.Remove("foo"), + () => stringValue.StringVector.Add(new StringKey { Key = "foo" }), + }; - Assert.Throws(() => this.stringVectorProgressive.StringVector.AddOrReplace(null)); - Assert.Throws(() => this.stringVectorProgressive.StringVector.Clear()); - Assert.Throws(() => this.stringVectorProgressive.StringVector.Remove(null)); - Assert.Throws(() => this.stringVectorProgressive.StringVector.Add(null)); + foreach (var item in actions) + { + if (option == FlatBufferDeserializationOption.GreedyMutable) + { + item(); + } + else + { + Assert.Throws(item); + } + } } - [Fact] - public void IndexedVector_GetEnumerator() + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void IndexedVector_GetEnumerator(FlatBufferDeserializationOption option) { - EnumeratorTest(this.stringVectorLazy.StringVector); + Container stringValue = this.stringVectorSource.SerializeAndParse(option); + + EnumeratorTest(stringValue.StringVector); EnumeratorTest(this.stringVectorSource.StringVector); - EnumeratorTest(this.stringVectorProgressive.StringVector); } private void EnumeratorTest(IIndexedVector vector) @@ -133,37 +141,36 @@ public class IndexedVectorTests Assert.Empty(keys); } - [Fact] - public void IndexedVector_ContainsKey() + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void IndexedVector_ContainsKey(FlatBufferDeserializationOption option) { + Container stringValue = this.stringVectorSource.SerializeAndParse(option); + Assert.True(this.stringVectorSource.StringVector.ContainsKey("1")); Assert.True(this.stringVectorSource.StringVector.ContainsKey("5")); Assert.False(this.stringVectorSource.StringVector.ContainsKey("20")); Assert.Throws(() => this.stringVectorSource.StringVector.ContainsKey(null)); - Assert.True(this.stringVectorLazy.StringVector.ContainsKey("1")); - Assert.True(this.stringVectorLazy.StringVector.ContainsKey("5")); - Assert.False(this.stringVectorLazy.StringVector.ContainsKey("20")); - Assert.Throws(() => this.stringVectorLazy.StringVector.ContainsKey(null)); - - Assert.True(this.stringVectorProgressive.StringVector.ContainsKey("1")); - Assert.True(this.stringVectorProgressive.StringVector.ContainsKey("5")); - Assert.False(this.stringVectorProgressive.StringVector.ContainsKey("20")); - Assert.Throws(() => this.stringVectorProgressive.StringVector.ContainsKey(null)); + Assert.True(stringValue.StringVector.ContainsKey("1")); + Assert.True(stringValue.StringVector.ContainsKey("5")); + Assert.False(stringValue.StringVector.ContainsKey("20")); + Assert.Throws(() => stringValue.StringVector.ContainsKey(null)); } - [Fact] - public void IndexedVector_Caching() + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void IndexedVector_Caching(FlatBufferDeserializationOption option) { + Container stringValue = this.stringVectorSource.SerializeAndParse(option); + foreach (var key in this.stringKeys) { - Assert.True(this.stringVectorLazy.StringVector.TryGetValue(key, out var value)); - Assert.True(this.stringVectorLazy.StringVector.TryGetValue(key, out var value2)); - Assert.NotSame(value, value2); - - Assert.True(this.stringVectorProgressive.StringVector.TryGetValue(key, out value)); - Assert.True(this.stringVectorProgressive.StringVector.TryGetValue(key, out value2)); - Assert.Same(value, value2); + Assert.True(stringValue.StringVector.TryGetValue(key, out var value)); + Assert.True(stringValue.StringVector.TryGetValue(key, out var value2)); + Assert.Equal( + option != FlatBufferDeserializationOption.Lazy, + object.ReferenceEquals(value, value2)); } } diff --git a/src/Tests/FlatSharpEndToEndTests/ClassLib/PoolableListTests.cs b/src/Tests/FlatSharpEndToEndTests/ClassLib/PoolableListTests.cs new file mode 100644 index 0000000..826883b --- /dev/null +++ b/src/Tests/FlatSharpEndToEndTests/ClassLib/PoolableListTests.cs @@ -0,0 +1,148 @@ +/* + * Copyright 2022 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using FlatSharp.Internal; + +namespace FlatSharpEndToEndTests.PoolableList; + +/// +/// Tests for the FlatBufferVector class that implements IList. +/// + +public class PoolableListTests +{ + private PoolableList Items => new PoolableList(new[] { 0, 1, 2, 3, 4, 5, }); + + [Fact] + public void Clear() + { + var items = this.Items; + Assert.NotEmpty(items); + items.Clear(); + Assert.Empty(items); + } + + [Fact] + public void RemoveAt() + { + var items = this.Items; + Assert.Equal(2, items[2]); + Assert.Equal(6, items.Count); + items.RemoveAt(2); + Assert.Equal(3, items[2]); + Assert.Equal(5, items.Count); + } + + [Fact] + public void Remove() + { + var items = this.Items; + Assert.Equal(6, items.Count); + items.Remove(4); + Assert.Equal(5, items.Count); + Assert.Equal(5, items[4]); + } + + [Fact] + public void Setter() + { + var items = this.Items; + Assert.Equal(2, items[2]); + items[2] = 10; + Assert.Equal(10, items[2]); + } + + [Fact] + public void Add() + { + var items = this.Items; + Assert.Equal(6, items.Count); + items.Add(6); + Assert.Equal(7, items.Count); + Assert.Equal(6, items[6]); + } + + [Fact] + public void Insert() + { + var items = this.Items; + Assert.Equal(6, items.Count); + items.Insert(1, 10); + Assert.Equal(7, items.Count); + Assert.Equal(10, items[1]); + Assert.Equal(1, items[2]); + } + + [Fact] + public void Get() + { + var items = this.Items; + for (int i = 0; i < items.Count; ++i) + { + Assert.Equal(i, items[i]); + } + } + + [Fact] + public void Contains() + { + var items = this.Items; + for (int i = 0; i < items.Count; ++i) + { +#pragma warning disable xUnit2017 // Do not use Contains() to check if a value exists in a collection + // Justification: want to ensure correct method is invoked. + + Assert.True(items.Contains(i)); + +#pragma warning restore xUnit2017 // Do not use Contains() to check if a value exists in a collection + + Assert.Equal(i, items.IndexOf(i)); + } + } + + [Fact] + public void GetEnumerator() + { + var items = this.Items; + int i = 0; + foreach (var item in items) + { + Assert.Equal(i++, item); + } + + Assert.Equal(i, items.Count); + } + + [Fact] + public void CopyTo() + { + var items = this.Items; + int[] temp = new int[20]; + items.CopyTo(temp, 10); + + for (int i = 0; i < items.Count; ++i) + { + Assert.Equal(temp[i + 10], items[i]); + } + } + + [Fact] + public void ReadOnly() + { + var items = this.Items; + Assert.False(items.IsReadOnly); + } +} diff --git a/src/Tests/FlatSharpEndToEndTests/ClassLib/VTableTests.cs b/src/Tests/FlatSharpEndToEndTests/ClassLib/VTableTests.cs index 7043009..a67495a 100644 --- a/src/Tests/FlatSharpEndToEndTests/ClassLib/VTableTests.cs +++ b/src/Tests/FlatSharpEndToEndTests/ClassLib/VTableTests.cs @@ -16,7 +16,7 @@ using FlatSharp.Internal; -namespace FlatSharpTests; +namespace FlatSharpEndToEndTests.VTables; public class VTableTests { diff --git a/src/Tests/FlatSharpEndToEndTests/CopyConstructors/CopyConstructorTests.cs b/src/Tests/FlatSharpEndToEndTests/CopyConstructors/CopyConstructorTests.cs index 08834b4..892f57c 100644 --- a/src/Tests/FlatSharpEndToEndTests/CopyConstructors/CopyConstructorTests.cs +++ b/src/Tests/FlatSharpEndToEndTests/CopyConstructors/CopyConstructorTests.cs @@ -19,10 +19,7 @@ namespace FlatSharpEndToEndTests.CopyConstructors; public class CopyConstructorTests { [Theory] - [InlineData(FlatBufferDeserializationOption.Greedy)] - [InlineData(FlatBufferDeserializationOption.GreedyMutable)] - [InlineData(FlatBufferDeserializationOption.Lazy)] - [InlineData(FlatBufferDeserializationOption.Progressive)] + [ClassData(typeof(DeserializationOptionClassData))] public void CopyConstructorsTest(FlatBufferDeserializationOption option) { OuterTable original = new OuterTable diff --git a/src/Tests/FlatSharpEndToEndTests/DeserializationOptionClassData.cs b/src/Tests/FlatSharpEndToEndTests/DeserializationOptionClassData.cs new file mode 100644 index 0000000..c89e9a1 --- /dev/null +++ b/src/Tests/FlatSharpEndToEndTests/DeserializationOptionClassData.cs @@ -0,0 +1,34 @@ +/* + * Copyright 2018 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace FlatSharpEndToEndTests; + +public class DeserializationOptionClassData : IEnumerable +{ + public IEnumerator GetEnumerator() + { + yield return new object[] { FlatBufferDeserializationOption.Lazy }; + yield return new object[] { FlatBufferDeserializationOption.Progressive }; + yield return new object[] { FlatBufferDeserializationOption.Greedy }; + yield return new object[] { FlatBufferDeserializationOption.GreedyMutable }; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } +} + diff --git a/src/Tests/FlatSharpEndToEndTests/IO/InputBufferTests.cs b/src/Tests/FlatSharpEndToEndTests/IO/InputBufferTests.cs index db83ecc..021902a 100644 --- a/src/Tests/FlatSharpEndToEndTests/IO/InputBufferTests.cs +++ b/src/Tests/FlatSharpEndToEndTests/IO/InputBufferTests.cs @@ -46,10 +46,7 @@ public class InputBufferTests } [Theory] - [InlineData(FlatBufferDeserializationOption.Greedy)] - [InlineData(FlatBufferDeserializationOption.GreedyMutable)] - [InlineData(FlatBufferDeserializationOption.Lazy)] - [InlineData(FlatBufferDeserializationOption.Progressive)] + [ClassData(typeof(DeserializationOptionClassData))] public void SerializationInvocations(FlatBufferDeserializationOption option) { var serializer = PrimitiveTypesTable.Serializer.WithSettings(settings => settings.UseDeserializationMode(option)); diff --git a/src/Tests/FlatSharpEndToEndTests/Oracle/OracleDeserializeTests.cs b/src/Tests/FlatSharpEndToEndTests/Oracle/OracleDeserializeTests.cs index 3749a9d..4fe4d40 100644 --- a/src/Tests/FlatSharpEndToEndTests/Oracle/OracleDeserializeTests.cs +++ b/src/Tests/FlatSharpEndToEndTests/Oracle/OracleDeserializeTests.cs @@ -384,8 +384,7 @@ public partial class OracleDeserializeTests } [Theory] - [InlineData(FlatBufferDeserializationOption.Greedy)] - [InlineData(FlatBufferDeserializationOption.Lazy)] + [ClassData(typeof(DeserializationOptionClassData))] public void SortedVectors(FlatBufferDeserializationOption option) { var builder = new FlatBufferBuilder(1024 * 1024); diff --git a/src/Tests/FlatSharpPoolableEndToEndTests/FlatSharpPoolableEndToEndTests.csproj b/src/Tests/FlatSharpPoolableEndToEndTests/FlatSharpPoolableEndToEndTests.csproj new file mode 100644 index 0000000..b97f35a --- /dev/null +++ b/src/Tests/FlatSharpPoolableEndToEndTests/FlatSharpPoolableEndToEndTests.csproj @@ -0,0 +1,52 @@ + + + + + false + true + net7.0;net6.0 + net472;net6.0;net7.0 + net7.0 + false + PoolingTests + FlatSharpTests + true + annotations + CS1591 + $([System.IO.Path]::GetFullPath('$(MSBuildThisFileDirectory)\..\..\FlatSharp.Compiler\bin\$(Configuration)\net7.0\FlatSharp.Compiler.dll')) + true + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + + + + diff --git a/src/Tests/FlatSharpPoolableEndToEndTests/GlobalUsings.cs b/src/Tests/FlatSharpPoolableEndToEndTests/GlobalUsings.cs new file mode 100644 index 0000000..3ab1c59 --- /dev/null +++ b/src/Tests/FlatSharpPoolableEndToEndTests/GlobalUsings.cs @@ -0,0 +1,31 @@ + +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +global using System; +global using System.Buffers; +global using System.Buffers.Binary; +global using System.Collections; +global using System.Collections.Generic; +global using System.ComponentModel; +global using System.Diagnostics; +global using System.Diagnostics.CodeAnalysis; +global using System.Linq; +global using System.Reflection; +global using System.Runtime.CompilerServices; +global using FlatSharp; +global using FlatSharp.Attributes; +global using Xunit; diff --git a/src/Tests/FlatSharpPoolableEndToEndTests/PoolingTests.cs b/src/Tests/FlatSharpPoolableEndToEndTests/PoolingTests.cs new file mode 100644 index 0000000..16fcd2a --- /dev/null +++ b/src/Tests/FlatSharpPoolableEndToEndTests/PoolingTests.cs @@ -0,0 +1,405 @@ +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace FlatSharpEndToEndTests.PoolingTests; + +public class PoolingTests +{ + [Theory] + [InlineData(FlatBufferDeserializationOption.Progressive)] + [InlineData(FlatBufferDeserializationOption.Greedy)] + [InlineData(FlatBufferDeserializationOption.GreedyMutable)] + public void NotLazy_Pools(FlatBufferDeserializationOption option) + { +#if DEBUG + var testPool = new TestObjectPool(); + ObjectPool.Instance = testPool; +#else + ObjectPool.MaxToRetain = 100; +#endif + + byte[] buffer = CreateRoot(1); + + Root parsed = Root.Serializer.Parse(buffer, option); + VerifyRoot(parsed, 1); + + HashSet seenObjects = new(); + + seenObjects.Add(parsed.InnerTable); + seenObjects.Add(parsed.RefStruct); + seenObjects.Add(parsed.VectorOfRefStruct); + seenObjects.Add(parsed.VectorOfTable); + seenObjects.Add(parsed.VectorOfRefStruct); + seenObjects.Add(parsed.VectorOfValueStruct); + seenObjects.Add(parsed.VectorOfUnion); + seenObjects.Add(parsed.IndexedVector); + + foreach (var item in parsed.VectorOfTable) + { + seenObjects.Add(item); + } + + foreach (var item in parsed.VectorOfRefStruct) + { + seenObjects.Add(item); + } + + foreach (var item in parsed.VectorOfUnion) + { + seenObjects.Add(item); + seenObjects.Add(item.Accept(new Visitor())); + } + + foreach (var item in parsed.IndexedVector.Select(x => x.Value)) + { + seenObjects.Add(item); + } + + // Release all our stuff. + parsed.ReturnToPool(); + +#if DEBUG + foreach (var item in seenObjects) + { + Assert.True(testPool.IsInPool(item)); + + int count = testPool.Count(item); + Assert.True(count > 0); + + // return again -- verify that we don't double return. + ((IPoolableObject)item).ReturnToPool(true); + Assert.Equal(count, testPool.Count(item)); + } +#endif + + buffer = CreateRoot(2); + + // Parse again, and ensure that all the things are in the hash set. + parsed = Root.Serializer.Parse(buffer, option); + VerifyRoot(parsed, 2); + + Assert.Contains(parsed.InnerTable, seenObjects); + Assert.Contains(parsed.RefStruct, seenObjects); + Assert.Contains(parsed.VectorOfRefStruct, seenObjects); + Assert.Contains(parsed.VectorOfTable, seenObjects); + Assert.Contains(parsed.VectorOfValueStruct, seenObjects); + Assert.Contains(parsed.VectorOfUnion, seenObjects); + Assert.Contains(parsed.IndexedVector, seenObjects); + + foreach (var item in parsed.VectorOfTable) + { + Assert.Contains(item, seenObjects); + } + + foreach (var item in parsed.VectorOfRefStruct) + { + Assert.Contains(item, seenObjects); + } + + foreach (var item in parsed.VectorOfUnion) + { + Assert.Contains(item, seenObjects); + Assert.Contains(item.Accept(new Visitor()), seenObjects); + } + + foreach (var item in parsed.IndexedVector.Select(x => x.Value)) + { + Assert.Contains(item, seenObjects); + } + +#if DEBUG + foreach (var item in seenObjects) + { + Assert.False(testPool.IsInPool(item)); + Assert.Equal(0, testPool.Count(item)); + } +#endif + } + +#if DEBUG + [Theory] + [InlineData(FlatBufferDeserializationOption.Progressive)] + [InlineData(FlatBufferDeserializationOption.Greedy)] + [InlineData(FlatBufferDeserializationOption.GreedyMutable)] + public void NotLazy_NonRoot_NoOp(FlatBufferDeserializationOption option) + { + var testPool = new TestObjectPool(); + ObjectPool.Instance = testPool; + + void AssertNotInPool(object item) + { + Assert.False(testPool.IsInPool(item)); + + IPoolableObject? obj = item as IPoolableObject; + + Assert.NotNull(obj); + + // Verify call is ignored for non-root objects. + obj.ReturnToPool(); + Assert.False(testPool.IsInPool(item)); + } + + byte[] buffer = CreateRoot(1); + + Root parsed = Root.Serializer.Parse(buffer, option); + VerifyRoot(parsed, 1); + + AssertNotInPool(parsed.InnerTable); + AssertNotInPool(parsed.RefStruct); + AssertNotInPool(parsed.VectorOfRefStruct); + AssertNotInPool(parsed.VectorOfValueStruct); + AssertNotInPool(parsed.VectorOfTable); + AssertNotInPool(parsed.VectorOfUnion); + AssertNotInPool(parsed.IndexedVector); + + foreach (var item in parsed.VectorOfRefStruct) + { + AssertNotInPool(item); + } + + foreach (var item in parsed.VectorOfTable) + { + AssertNotInPool(item); + } + + foreach (var item in parsed.VectorOfUnion) + { + AssertNotInPool(item); + AssertNotInPool(item.Accept(default)); + } + + foreach (var value in parsed.IndexedVector.Select(x => x.Value)) + { + AssertNotInPool(value); + } + } + + [Fact] + public void Lazy_MultipleReturn() + { + var testPool = new TestObjectPool(); + ObjectPool.Instance = testPool; + + byte[] buffer = CreateRoot(1); + byte[] buffer2 = CreateRoot(2); + + Root parsedOriginal = Root.Serializer.Parse(buffer, FlatBufferDeserializationOption.Lazy); + VerifyRoot(parsedOriginal, 1); + + Assert.False(testPool.IsInPool(parsedOriginal)); + Assert.Equal(0, testPool.Count(parsedOriginal)); + + parsedOriginal.ReturnToPool(); + + Assert.True(testPool.IsInPool(parsedOriginal)); + Assert.Equal(1, testPool.Count(parsedOriginal)); + + Root parsed2 = Root.Serializer.Parse(buffer2, FlatBufferDeserializationOption.Lazy); + + Assert.Same(parsedOriginal, parsed2); + VerifyRoot(parsed2, 2); + VerifyRoot(parsedOriginal, 2); + + Root parsed3 = Root.Serializer.Parse(buffer, FlatBufferDeserializationOption.Lazy); + VerifyRoot(parsed3, 1); + + Assert.NotSame(parsedOriginal, parsed3); + + Assert.Equal(0, testPool.Count(parsedOriginal)); + Assert.False(testPool.IsInPool(parsedOriginal)); + Assert.False(testPool.IsInPool(parsed2)); + Assert.False(testPool.IsInPool(parsed3)); + + // Return works. + parsedOriginal.ReturnToPool(); + Assert.Equal(1, testPool.Count(parsedOriginal)); + Assert.True(testPool.IsInPool(parsedOriginal)); + Assert.True(testPool.IsInPool(parsed2)); + Assert.False(testPool.IsInPool(parsed3)); + + // Won't have any effect. + parsed2.ReturnToPool(); + Assert.Equal(1, testPool.Count(parsedOriginal)); + Assert.True(testPool.IsInPool(parsedOriginal)); + Assert.True(testPool.IsInPool(parsed2)); + Assert.False(testPool.IsInPool(parsed3)); + + parsed3.ReturnToPool(); + Assert.Equal(2, testPool.Count(parsedOriginal)); + Assert.True(testPool.IsInPool(parsedOriginal)); + Assert.True(testPool.IsInPool(parsed2)); + Assert.True(testPool.IsInPool(parsed3)); + } + + [Fact] + public void Lazy_Vectors() + { + var testPool = new TestObjectPool(); + ObjectPool.Instance = testPool; + + byte[] buffer = CreateRoot(1); + + Root parsedOriginal = Root.Serializer.Parse(buffer, FlatBufferDeserializationOption.Lazy); + VerifyRoot(parsedOriginal, 1); + + { + var structVector = parsedOriginal.VectorOfRefStruct; + + HashSet seenInstances = new(); + + foreach (RefStruct refStruct in structVector) + { + seenInstances.Add(refStruct); + Assert.False(testPool.IsInPool(refStruct)); + + refStruct.ReturnToPool(); + Assert.True(testPool.IsInPool(refStruct)); + } + + Assert.Single(seenInstances); + } + + { + var unionVector = parsedOriginal.VectorOfUnion; + HashSet seenObjects = new(); + + foreach (var item in unionVector) + { + seenObjects.Add(item); + Assert.False(testPool.IsInPool(item)); + + object value = item.Accept(default); + seenObjects.Add(value); + Assert.False(testPool.IsInPool(value)); + + // returns the underlying item, even in lazy mode. + item.ReturnToPool(); + + Assert.True(testPool.IsInPool(item)); + Assert.True(testPool.IsInPool(value)); + } + + Assert.Equal(3, seenObjects.Count); + } + + { + var indexedVector = parsedOriginal.IndexedVector; + HashSet seenObjects = new(); + + foreach (var item in indexedVector) + { + KeyValue value = item.Value; + + Assert.False(testPool.IsInPool(value)); + seenObjects.Add(value); + + // returns the underlying item, even in lazy mode. + value.ReturnToPool(); + Assert.True(testPool.IsInPool(value)); + } + + indexedVector.ReturnToPool(); + Assert.True(testPool.IsInPool(indexedVector)); + } + } +#endif + + private byte[] CreateRoot(int value) + { + Root root = new() + { + InnerTable = new InnerTable() { X = value }, + RefStruct = new RefStruct() { X = value, }, + ValueStruct = new ValueStruct() { X = value }, + VectorOfRefStruct = new[] { new RefStruct() { X = value }, new() { X = value } }, + VectorOfTable = new[] { new InnerTable() { X = value }, new() { X = value } }, + VectorOfValueStruct = new[] { new ValueStruct { X = value }, new() { X = value } }, + VectorOfUnion = new[] { new InnerUnion(new RefStruct { X = value }), new InnerUnion(new InnerTable { X = value })}, + IndexedVector = new IndexedVector + { + new KeyValue { Key = value, Value = value, }, + new KeyValue { Key = value + 1, Value = value + 1, }, + new KeyValue { Key = value + 2, Value = value + 2, }, + } + }; + + byte[] buffer = new byte[Root.Serializer.GetMaxSize(root)]; + Root.Serializer.Write(buffer, root); + + return buffer; + } + + private static void VerifyRoot(Root item, int expectedValue) + { + Assert.Equal(expectedValue, item.InnerTable.X); + Assert.Equal(expectedValue, item.RefStruct.X); + Assert.Equal(expectedValue, item.ValueStruct.Value.X); + + foreach (var temp in item.VectorOfRefStruct) + { + Assert.Equal(expectedValue, temp.X); + } + + foreach (var temp in item.VectorOfTable) + { + Assert.Equal(expectedValue, temp.X); + } + + foreach (var temp in item.VectorOfValueStruct) + { + Assert.Equal(expectedValue, temp.X); + } + + foreach (var temp in item.VectorOfUnion) + { + if (temp.TryGet(out InnerTable? value)) + { + Assert.Equal(expectedValue, value.X); + } + + if (temp.TryGet(out RefStruct? s)) + { + Assert.Equal(expectedValue, s.X); + } + } + + for (int i = 0; i < 3; ++i) + { + int expected = expectedValue + i; + Assert.True(item.IndexedVector.TryGetValue(expected, out KeyValue? value)); + Assert.Equal(expected, value.Key); + Assert.Equal(expected, value.Value); + } + } + + private struct Visitor : InnerUnion.Visitor + { + public object Visit(RefStruct item) + { + return item; + } + + public object Visit(ValueStruct item) + { + return item; + } + + public object Visit(InnerTable item) + { + return item; + } + } +} diff --git a/src/Tests/FlatSharpPoolableEndToEndTests/PoolingTests.fbs b/src/Tests/FlatSharpPoolableEndToEndTests/PoolingTests.fbs new file mode 100644 index 0000000..9d92e6c --- /dev/null +++ b/src/Tests/FlatSharpPoolableEndToEndTests/PoolingTests.fbs @@ -0,0 +1,35 @@ + +attribute "fs_serializer"; +attribute "fs_unsafeStructVector"; +attribute "fs_valueStruct"; +attribute "fs_writeThrough"; +attribute "fs_vector"; + +namespace FlatSharpEndToEndTests.PoolingTests; + +struct RefStruct { x : int; } +struct ValueStruct (fs_valueStruct) { x : int; } +table InnerTable { x : int; } + +union InnerUnion { RefStruct, ValueStruct, InnerTable } + +table KeyValue +{ + key : int (key); + value : int; +} + +table Root (fs_serializer:"Lazy") +{ + ref_struct : RefStruct; + value_struct : ValueStruct; + inner_table : InnerTable; + + vector_of_ref_struct : [ RefStruct ]; + vector_of_value_struct : [ ValueStruct ]; + vector_of_table : [ InnerTable ]; + + vector_of_union : [ InnerUnion ]; + + indexed_vector : [ KeyValue ] (fs_vector:"IIndexedVector"); +} \ No newline at end of file diff --git a/src/Tests/FlatSharpPoolableEndToEndTests/TestObjectPool.cs b/src/Tests/FlatSharpPoolableEndToEndTests/TestObjectPool.cs new file mode 100644 index 0000000..4ebc51f --- /dev/null +++ b/src/Tests/FlatSharpPoolableEndToEndTests/TestObjectPool.cs @@ -0,0 +1,100 @@ +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using FlatSharp.Internal; +using System.Collections.Concurrent; + +namespace FlatSharpEndToEndTests.PoolingTests; + +#if DEBUG +public class TestObjectPool : IObjectPool +{ + private readonly ConcurrentDictionary pool = new(); + + public void Return(T item) + { + this.GetPool(typeof(T)).Return(item); + } + + public bool TryGet(out T? value) + { + bool result = this.GetPool(typeof(T)).TryGet(out var obj); + + value = (T?)obj; + return result; + } + + public bool IsInPool(object item) + { + return GetPool(item.GetType()).Contains(item); + } + + public int Count(object item) + { + return GetPool(item.GetType()).Count; + } + + private TypedPool GetPool(Type type) + { + return pool.GetOrAdd(type, new TypedPool()); + } + + private class TypedPool + { + private readonly object syncRoot = new(); + private readonly HashSet members = new(); + private readonly Queue dequeuePool = new(); + + public int Count => members.Count; + + public void Return(object item) + { + lock (syncRoot) + { + Assert.True(members.Add(item)); + dequeuePool.Enqueue(item); + } + } + + public bool TryGet(out object? item) + { + lock (syncRoot) + { + if (dequeuePool.Count > 0) + { + item = dequeuePool.Dequeue(); + Assert.True(members.Remove(item)); + return true; + } + else + { + item = default; + return false; + } + } + } + + public bool Contains(object item) + { + lock (syncRoot) + { + return members.Contains(item); + } + } + } +} + +#endif \ No newline at end of file diff --git a/src/Tests/FlatSharpTests/ClassLib/TypeModelTests.cs b/src/Tests/FlatSharpTests/ClassLib/TypeModelTests.cs index 8f46631..2b380aa 100644 --- a/src/Tests/FlatSharpTests/ClassLib/TypeModelTests.cs +++ b/src/Tests/FlatSharpTests/ClassLib/TypeModelTests.cs @@ -253,10 +253,6 @@ public class TypeModelTests () => RuntimeTypeModel.CreateFrom(typeof(TableOnDeserialized_RefParameter))); Assert.Equal(CreateError(), ex.Message); - ex = Assert.Throws( - () => RuntimeTypeModel.CreateFrom(typeof(TableOnDeserialized_OptionalParameter))); - Assert.Equal(CreateError(), ex.Message); - ex = Assert.Throws( () => RuntimeTypeModel.CreateFrom(typeof(TableOnDeserialized_NotProtected))); Assert.Equal(CreateError(), ex.Message); @@ -1398,7 +1394,7 @@ public class TypeModelTests { protected void OnFlatSharpDeserialized(out FlatBufferDeserializationContext context) { - context = null; + context = default; } } @@ -1418,14 +1414,6 @@ public class TypeModelTests } } - [FlatBufferTable] - public class TableOnDeserialized_OptionalParameter - { - protected void OnFlatSharpDeserialized(FlatBufferDeserializationContext context = null) - { - } - } - [FlatBufferTable] public class TableOnDeserialized_Multiple { diff --git a/src/Tests/FlatSharpTests/SerializationTests/ConstructorTests.cs b/src/Tests/FlatSharpTests/SerializationTests/ConstructorTests.cs index 1a09c58..9999d49 100644 --- a/src/Tests/FlatSharpTests/SerializationTests/ConstructorTests.cs +++ b/src/Tests/FlatSharpTests/SerializationTests/ConstructorTests.cs @@ -112,7 +112,7 @@ public class ConstructorTests this.Context = context; } - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } [FlatBufferItem(0)] public virtual OuterStruct? Struct { get; set; } @@ -147,7 +147,7 @@ public class ConstructorTests { } - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } [FlatBufferItem(0)] public virtual int Item { get; set; } @@ -170,7 +170,7 @@ public class ConstructorTests this.Context = context; } - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } [FlatBufferItem(0)] public virtual int Item { get; set; } diff --git a/src/Tests/FlatSharpTests/SerializationTests/DeserializationOptionsTests.cs b/src/Tests/FlatSharpTests/SerializationTests/DeserializationOptionsTests.cs index f03b6ba..327dc84 100644 --- a/src/Tests/FlatSharpTests/SerializationTests/DeserializationOptionsTests.cs +++ b/src/Tests/FlatSharpTests/SerializationTests/DeserializationOptionsTests.cs @@ -255,7 +255,7 @@ public class DeserializationOptionsTests var table = this.SerializeAndParse>(FlatBufferDeserializationOption.GreedyMutable, Strings); InputBuffer.AsSpan().Fill(0); - Assert.Equal(typeof(List), table.Vector.GetType()); + Assert.Equal(typeof(PoolableList), table.Vector.GetType()); Assert.True(object.ReferenceEquals(table.Vector, table.Vector)); var vector = table.Vector; @@ -277,7 +277,7 @@ public class DeserializationOptionsTests var table = this.SerializeAndParse>(FlatBufferDeserializationOption.GreedyMutable, Strings); InputBuffer.AsSpan().Fill(0); - Assert.Equal(typeof(List), table.Vector.GetType()); + Assert.Equal(typeof(PoolableList), table.Vector.GetType()); Assert.True(object.ReferenceEquals(table.Vector, table.Vector)); Assert.True(object.ReferenceEquals(table.Vector[5], table.Vector[5])); Assert.True(object.ReferenceEquals(table.First, table.First)); diff --git a/src/Tests/FlatSharpTests/SerializationTests/DeserializedConstructorTests.cs b/src/Tests/FlatSharpTests/SerializationTests/DeserializedConstructorTests.cs index e2dd1ba..c12a5f6 100644 --- a/src/Tests/FlatSharpTests/SerializationTests/DeserializedConstructorTests.cs +++ b/src/Tests/FlatSharpTests/SerializationTests/DeserializedConstructorTests.cs @@ -107,15 +107,14 @@ public class DeserializedConstructorTests TTable result = serializer.Parse(data); Assert.Null(table.Context); - Assert.Equal(item, result.Context.DeserializationOption); - Assert.Equal(item, result.Struct.Context.DeserializationOption); - Assert.False(object.ReferenceEquals(result.Context, result.Struct.Context)); + Assert.Equal(item, result.Context.Value.DeserializationOption); + Assert.Equal(item, result.Struct.Context.Value.DeserializationOption); } } public interface IContextItem { - FlatBufferDeserializationContext Context { get; } + FlatBufferDeserializationContext? Context { get; } } public interface IContextTable @@ -128,7 +127,7 @@ public class DeserializedConstructorTests public class PublicContextConstructorTable : IContextItem, IContextTable where TStruct : IContextItem, new() { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public PublicContextConstructorTable() { } @@ -147,7 +146,7 @@ public class DeserializedConstructorTests [FlatBufferStruct] public class PublicContextConstructorStruct : IContextItem { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public PublicContextConstructorStruct() { } @@ -165,7 +164,7 @@ public class DeserializedConstructorTests public class ProtectedContextConstructorTable : IContextItem, IContextTable where TStruct : IContextItem, new() { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public ProtectedContextConstructorTable() { } @@ -184,7 +183,7 @@ public class DeserializedConstructorTests [FlatBufferStruct] public class ProtectedContextConstructorStruct : IContextItem { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public ProtectedContextConstructorStruct() { } @@ -201,7 +200,7 @@ public class DeserializedConstructorTests public class ProtectedInternalContextConstructorTable : IContextItem, IContextTable where TStruct : IContextItem, new() { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public ProtectedInternalContextConstructorTable() { } @@ -220,7 +219,7 @@ public class DeserializedConstructorTests [FlatBufferStruct] public class ProtectedInternalContextConstructorStruct : IContextItem { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public ProtectedInternalContextConstructorStruct() { } @@ -238,7 +237,7 @@ public class DeserializedConstructorTests public class PrivateContextConstructorTable : IContextItem, IContextTable where TStruct : IContextItem, new() { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public PrivateContextConstructorTable() { } @@ -257,7 +256,7 @@ public class DeserializedConstructorTests [FlatBufferStruct] public class PrivateContextConstructorStruct : IContextItem { - public FlatBufferDeserializationContext Context { get; } + public FlatBufferDeserializationContext? Context { get; } public PrivateContextConstructorStruct() { } diff --git a/src/Tests/FlatSharpTests/SerializationTests/VectorDeserializationTests.cs b/src/Tests/FlatSharpTests/SerializationTests/VectorDeserializationTests.cs index 351b66c..57f52b5 100644 --- a/src/Tests/FlatSharpTests/SerializationTests/VectorDeserializationTests.cs +++ b/src/Tests/FlatSharpTests/SerializationTests/VectorDeserializationTests.cs @@ -272,7 +272,7 @@ public class VectorDeserializationTests var parsed = serializer.Parse>>(buffer); - Assert.Equal(typeof(List), parsed.Vector.GetType()); + Assert.Equal(typeof(PoolableList), parsed.Vector.GetType()); Assert.False(parsed.Vector.IsReadOnly); // Shouldn't throw.