diff --git a/src/Benchmarks/Benchmark/Internal/SerializationContextBenchmark.cs b/src/Benchmarks/Benchmark/Internal/SerializationContextBenchmark.cs index 19e6490..cace7a6 100644 --- a/src/Benchmarks/Benchmark/Internal/SerializationContextBenchmark.cs +++ b/src/Benchmarks/Benchmark/Internal/SerializationContextBenchmark.cs @@ -18,6 +18,7 @@ namespace Benchmark { using BenchmarkDotNet.Attributes; using FlatSharp; + using FlatSharp.Internal; using System; using System.Collections.Generic; diff --git a/src/FlatSharp.Runtime/FlatSharpInternal.cs b/src/FlatSharp.Runtime/FlatSharpInternal.cs index d5d5c96..e6c8598 100644 --- a/src/FlatSharp.Runtime/FlatSharpInternal.cs +++ b/src/FlatSharp.Runtime/FlatSharpInternal.cs @@ -14,6 +14,8 @@ * limitations under the License. */ +using System.IO; + namespace FlatSharp; internal static class FlatSharpInternal diff --git a/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs b/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs index 3ef82aa..ab39a16 100644 --- a/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs +++ b/src/FlatSharp.Runtime/GeneratedSerializerWrapper.cs @@ -30,6 +30,7 @@ internal class GeneratedSerializerWrapper : ISerializer, ISerializer where private readonly ThreadLocal? sharedStringWriter; private readonly bool enableMemoryCopySerialization; private readonly string? fileIdentifier; + private readonly short remainingDepthLimit; public GeneratedSerializerWrapper( IGeneratedSerializer? innerSerializer, @@ -45,6 +46,7 @@ internal class GeneratedSerializerWrapper : ISerializer, ISerializer where var tableAttribute = typeof(T).GetCustomAttribute(); this.fileIdentifier = tableAttribute?.FileIdentifier; this.sharedStringWriter = new ThreadLocal(() => new SharedStringWriter()); + this.remainingDepthLimit = 1000; // sane default. } private GeneratedSerializerWrapper(GeneratedSerializerWrapper template, SerializerSettings settings) @@ -54,6 +56,7 @@ internal class GeneratedSerializerWrapper : ISerializer, ISerializer where this.AssemblyBytes = template.AssemblyBytes; this.innerSerializer = template.innerSerializer; this.fileIdentifier = template.fileIdentifier; + this.remainingDepthLimit = template.remainingDepthLimit; this.enableMemoryCopySerialization = settings.EnableMemoryCopySerialization; @@ -62,6 +65,16 @@ internal class GeneratedSerializerWrapper : ISerializer, ISerializer where { this.sharedStringWriter = new ThreadLocal(writerFactory); } + + if (settings.ObjectDepthLimit is not null) + { + if (settings.ObjectDepthLimit <= 0) + { + throw new ArgumentException("ObjectDepthLimit must be nonnegative."); + } + + this.remainingDepthLimit = settings.ObjectDepthLimit.Value; + } } Type ISerializer.RootType => typeof(T); @@ -120,7 +133,7 @@ internal class GeneratedSerializerWrapper : ISerializer, ISerializer where } // In case buffer is a reference type or is a boxed value, this allows it the opportunity to "wrap" itself in a value struct for efficiency. - return buffer.InvokeParse(this.innerSerializer, 0); + return buffer.InvokeParse(this.innerSerializer, new GeneratedSerializerParseArguments(0, this.remainingDepthLimit)); } object ISerializer.Parse(TInputBuffer buffer) => this.Parse(buffer); diff --git a/src/FlatSharp.Runtime/IGeneratedSerializer.cs b/src/FlatSharp.Runtime/IGeneratedSerializer.cs index ab8e063..29c9336 100644 --- a/src/FlatSharp.Runtime/IGeneratedSerializer.cs +++ b/src/FlatSharp.Runtime/IGeneratedSerializer.cs @@ -1,5 +1,5 @@ /* - * Copyright 2018 James Courtney + * Copyright 2022 James Courtney * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,23 @@ * limitations under the License. */ -namespace FlatSharp; +namespace FlatSharp.Internal; + +/// +/// Wrapper struct to pass arguments into . +/// +public readonly struct GeneratedSerializerParseArguments +{ + public GeneratedSerializerParseArguments(int offset, short depthLimit) + { + this.Offset = offset; + this.DepthLimit = depthLimit; + } + + public int Offset { get; } + + public short DepthLimit { get; } +} /// /// An interface implemented dynamically by FlatSharp for reading and writing data from a buffer. @@ -49,5 +65,7 @@ public interface IGeneratedSerializer /// /// Parses the given buffer as an instance of from the given offset. /// - T Parse(TInputBuffer buffer, int offset) where TInputBuffer : IInputBuffer; + T Parse( + TInputBuffer buffer, + in GeneratedSerializerParseArguments arguments) where TInputBuffer : IInputBuffer; } diff --git a/src/FlatSharp.Runtime/IO/ArrayInputBuffer.cs b/src/FlatSharp.Runtime/IO/ArrayInputBuffer.cs index c362892..e3cbf60 100644 --- a/src/FlatSharp.Runtime/IO/ArrayInputBuffer.cs +++ b/src/FlatSharp.Runtime/IO/ArrayInputBuffer.cs @@ -143,9 +143,8 @@ public struct ArrayInputBuffer : IInputBuffer, IInputBuffer2 { return this.memory; } - - public T InvokeParse(IGeneratedSerializer serializer, int offset) + public T InvokeParse(IGeneratedSerializer serializer, in GeneratedSerializerParseArguments arguments) { - return serializer.Parse(this, offset); + return serializer.Parse(this, arguments); } } diff --git a/src/FlatSharp.Runtime/IO/ArraySegmentInputBuffer.cs b/src/FlatSharp.Runtime/IO/ArraySegmentInputBuffer.cs index 0534684..7062243 100644 --- a/src/FlatSharp.Runtime/IO/ArraySegmentInputBuffer.cs +++ b/src/FlatSharp.Runtime/IO/ArraySegmentInputBuffer.cs @@ -149,10 +149,9 @@ public struct ArraySegmentInputBuffer : IInputBuffer, IInputBuffer2 { return this.pointer.segment; } - - public T InvokeParse(IGeneratedSerializer serializer, int offset) + public T InvokeParse(IGeneratedSerializer serializer, in GeneratedSerializerParseArguments arguments) { - return serializer.Parse(this, offset); + return serializer.Parse(this, arguments); } // Array Segment is a relatively heavy struct. It contains an array pointer, an int offset, and and int length. diff --git a/src/FlatSharp.Runtime/IO/IInputBuffer.cs b/src/FlatSharp.Runtime/IO/IInputBuffer.cs index f7e0b02..b0e4c93 100644 --- a/src/FlatSharp.Runtime/IO/IInputBuffer.cs +++ b/src/FlatSharp.Runtime/IO/IInputBuffer.cs @@ -97,7 +97,7 @@ public interface IInputBuffer /// Invokes the parse method on the parameter. Allows passing /// generic parameters. /// - TItem InvokeParse(IGeneratedSerializer serializer, int offset); + TItem InvokeParse(IGeneratedSerializer serializer, in GeneratedSerializerParseArguments arguments); } /// diff --git a/src/FlatSharp.Runtime/IO/MemoryInputBuffer.cs b/src/FlatSharp.Runtime/IO/MemoryInputBuffer.cs index e9ffa12..2da7a61 100644 --- a/src/FlatSharp.Runtime/IO/MemoryInputBuffer.cs +++ b/src/FlatSharp.Runtime/IO/MemoryInputBuffer.cs @@ -148,9 +148,9 @@ public struct MemoryInputBuffer : IInputBuffer, IInputBuffer2 return this.pointer.memory; } - public T InvokeParse(IGeneratedSerializer serializer, int offset) + public T InvokeParse(IGeneratedSerializer serializer, in GeneratedSerializerParseArguments arguments) { - return serializer.Parse(this, offset); + return serializer.Parse(this, arguments); } // Memory is a relatively heavy struct. It's cheaper to wrap it in a diff --git a/src/FlatSharp.Runtime/IO/ReadOnlyMemoryInputBuffer.cs b/src/FlatSharp.Runtime/IO/ReadOnlyMemoryInputBuffer.cs index f22786c..539af69 100644 --- a/src/FlatSharp.Runtime/IO/ReadOnlyMemoryInputBuffer.cs +++ b/src/FlatSharp.Runtime/IO/ReadOnlyMemoryInputBuffer.cs @@ -151,9 +151,9 @@ public struct ReadOnlyMemoryInputBuffer : IInputBuffer, IInputBuffer2 throw new InvalidOperationException(ErrorMessage); } - public T InvokeParse(IGeneratedSerializer serializer, int offset) + public T InvokeParse(IGeneratedSerializer serializer, in GeneratedSerializerParseArguments arguments) { - return serializer.Parse(this, offset); + return serializer.Parse(this, arguments); } // Memory is a relatively heavy struct. It's cheaper to wrap it in a diff --git a/src/FlatSharp.Runtime/SerializationContext.cs b/src/FlatSharp.Runtime/SerializationContext.cs index c7eeee7..9d48737 100644 --- a/src/FlatSharp.Runtime/SerializationContext.cs +++ b/src/FlatSharp.Runtime/SerializationContext.cs @@ -16,7 +16,7 @@ using System.Threading; -namespace FlatSharp; +namespace FlatSharp.Internal; /// /// A context object for a FlatBuffer serialize operation. The context is responsible for allocating space in the buffer diff --git a/src/FlatSharp.Runtime/SerializationHelpers.cs b/src/FlatSharp.Runtime/SerializationHelpers.cs index e264838..5ad5284 100644 --- a/src/FlatSharp.Runtime/SerializationHelpers.cs +++ b/src/FlatSharp.Runtime/SerializationHelpers.cs @@ -17,7 +17,7 @@ using System.IO; using System.Text; -namespace FlatSharp; +namespace FlatSharp.Internal; /// /// Collection of methods that help to serialize objects. It's kind of a hodge-podge, @@ -92,4 +92,19 @@ public static class SerializationHelpers { throw new InvalidDataException("FlatSharp encountered a null reference in an invalid context, such as a vector. Vectors are not permitted to have null objects."); } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void EnsureDepthLimit(short remainingDepth) + { + if (remainingDepth < 0) + { + ThrowDepthLimitExceededException(); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowDepthLimitExceededException() + { + throw new InvalidDataException($"FlatSharp passed the configured depth limit when deserializing. This can be configured with 'IGeneratedSerializer.WithSettings'."); + } } diff --git a/src/FlatSharp.Runtime/SerializerSettings.cs b/src/FlatSharp.Runtime/SerializerSettings.cs index 280bf7c..929d52f 100644 --- a/src/FlatSharp.Runtime/SerializerSettings.cs +++ b/src/FlatSharp.Runtime/SerializerSettings.cs @@ -41,4 +41,15 @@ public class SerializerSettings get; set; } + + /// + /// When set, specifies a depth limit for nested objects. Enforced at deserialization time. + /// If set to null, a default value of 1000 will be used. This setting may be used to prevent + /// stack overflow errors and otherwise guard against malicious inputs. + /// + public short? ObjectDepthLimit + { + get; + set; + } } diff --git a/src/FlatSharp.Runtime/Vectors/FlatBufferVector.cs b/src/FlatSharp.Runtime/Vectors/FlatBufferVector.cs index 766457c..ae85d14 100644 --- a/src/FlatSharp.Runtime/Vectors/FlatBufferVector.cs +++ b/src/FlatSharp.Runtime/Vectors/FlatBufferVector.cs @@ -29,7 +29,8 @@ public abstract class FlatBufferVector : FlatBufferVectorBase : FlatBufferVectorBase data); } diff --git a/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs b/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs index 5a731ef..ae1cfd1 100644 --- a/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs +++ b/src/FlatSharp.Runtime/Vectors/FlatBufferVectorBase.cs @@ -24,12 +24,15 @@ public abstract class FlatBufferVectorBase : IList, IReadOnl { protected readonly TInputBuffer memory; protected readonly TableFieldContext fieldContext; + protected readonly short remainingDepth; protected FlatBufferVectorBase( TInputBuffer memory, + short remainingDepth, TableFieldContext fieldContext) { this.memory = memory; + this.remainingDepth = remainingDepth; this.fieldContext = fieldContext; } diff --git a/src/FlatSharp.Runtime/Vectors/FlatBufferVectorOfUnion.cs b/src/FlatSharp.Runtime/Vectors/FlatBufferVectorOfUnion.cs index 981afc3..6f91ee5 100644 --- a/src/FlatSharp.Runtime/Vectors/FlatBufferVectorOfUnion.cs +++ b/src/FlatSharp.Runtime/Vectors/FlatBufferVectorOfUnion.cs @@ -32,7 +32,8 @@ public abstract class FlatBufferVectorOfUnion : FlatBufferVecto TInputBuffer memory, int discriminatorOffset, int offsetVectorOffset, - in TableFieldContext fieldContext) : base(memory, fieldContext) + short remainingDepth, + in TableFieldContext fieldContext) : base(memory, remainingDepth, fieldContext) { uint discriminatorCount = memory.ReadUInt(discriminatorOffset); uint offsetCount = memory.ReadUInt(offsetVectorOffset); @@ -58,6 +59,7 @@ public abstract class FlatBufferVectorOfUnion : FlatBufferVecto this.memory, this.discriminatorVectorOffset + index, this.offsetVectorOffset + (index * sizeof(int)), + base.remainingDepth, this.fieldContext, out item); } @@ -67,6 +69,7 @@ public abstract class FlatBufferVectorOfUnion : FlatBufferVecto TInputBuffer buffer, int discriminatorOffset, int offsetOffset, + short objectDepth, TableFieldContext fieldContext, out T item); } diff --git a/src/FlatSharp/FlatBufferVectorHelpers.cs b/src/FlatSharp/FlatBufferVectorHelpers.cs index f939b1b..c41ba2e 100644 --- a/src/FlatSharp/FlatBufferVectorHelpers.cs +++ b/src/FlatSharp/FlatBufferVectorHelpers.cs @@ -31,6 +31,7 @@ internal static class FlatBufferVectorHelpers InputBufferVariableName = "memory", IsOffsetByRef = false, TableFieldContextVariableName = "fieldContext", + RemainingDepthVariableName = "remainingDepth", }; var serializeContext = parseContext.GetWriteThroughContext("data", "item", "0"); @@ -48,13 +49,15 @@ internal static class FlatBufferVectorHelpers {parseContext.InputBufferTypeName} memory, int offset, int itemSize, - {nameof(TableFieldContext)} fieldContext) : base(memory, offset, itemSize, fieldContext) + short remainingDepth, + {nameof(TableFieldContext)} fieldContext) : base(memory, offset, itemSize, remainingDepth, fieldContext) {{ }} protected override void ParseItem( {parseContext.InputBufferTypeName} memory, int offset, + short remainingDepth, {nameof(TableFieldContext)} fieldContext, out {itemType.GetGlobalCompilableTypeName()} item) {{ @@ -84,6 +87,7 @@ internal static class FlatBufferVectorHelpers IsOffsetByRef = true, TableFieldContextVariableName = "fieldContext", OffsetVariableName = "temp", + RemainingDepthVariableName = "remainingDepth", }; string classDef = $@" @@ -94,7 +98,8 @@ internal static class FlatBufferVectorHelpers {context.InputBufferTypeName} memory, int discriminatorOffset, int offsetVectorOffset, - {nameof(TableFieldContext)} fieldContext) : base(memory, discriminatorOffset, offsetVectorOffset, fieldContext) + short remainingDepth, + {nameof(TableFieldContext)} fieldContext) : base(memory, discriminatorOffset, offsetVectorOffset, remainingDepth, fieldContext) {{ }} @@ -102,6 +107,7 @@ internal static class FlatBufferVectorHelpers {context.InputBufferTypeName} memory, int discriminatorOffset, int offsetOffset, + short remainingDepth, {nameof(TableFieldContext)} {context.TableFieldContextVariableName}, out {typeModel.GetGlobalCompilableTypeName()} item) {{ diff --git a/src/FlatSharp/Serialization/DeserializeClassDefinition.cs b/src/FlatSharp/Serialization/DeserializeClassDefinition.cs index 2e83b5f..e724ef2 100644 --- a/src/FlatSharp/Serialization/DeserializeClassDefinition.cs +++ b/src/FlatSharp/Serialization/DeserializeClassDefinition.cs @@ -23,6 +23,7 @@ internal class DeserializeClassDefinition protected const string InputBufferVariableName = "__buffer"; protected const string OffsetVariableName = "__offset"; protected const string VTableVariableName = "__vtable"; + protected const string RemainingDepthVariableName = "__remainingDepth"; protected readonly ITypeModel typeModel; protected readonly FlatBufferSerializerOptions options; @@ -37,6 +38,7 @@ internal class DeserializeClassDefinition protected readonly MethodInfo? onDeserializeMethod; protected readonly string vtableTypeName; protected readonly string vtableAccessor; + protected readonly string remainingDepthAccessor; private DeserializeClassDefinition( string className, @@ -51,27 +53,35 @@ internal class DeserializeClassDefinition this.vtableTypeName = GetVTableTypeName(maxVtableIndex); this.onDeserializeMethod = onDeserializeMethod; - if (!this.options.GreedyDeserialize) - { - // maintain reference to buffer. - this.instanceFieldDefinitions[InputBufferVariableName] = $"private TInputBuffer {InputBufferVariableName};"; - this.instanceFieldDefinitions[OffsetVariableName] = $"private int {OffsetVariableName};"; - this.initializeStatements.Add($"this.{InputBufferVariableName} = buffer;"); - this.initializeStatements.Add($"this.{OffsetVariableName} = offset;"); - } - this.vtableAccessor = "default"; - if (this.typeModel.SchemaType == FlatBufferSchemaType.Table) + + if (this.options.GreedyDeserialize) { - if (this.options.GreedyDeserialize) + this.remainingDepthAccessor = "remainingDepth"; + + if (this.typeModel.SchemaType == FlatBufferSchemaType.Table) { - // Greedy tables decode a vtable in the constructor but don't stor eit. + // Greedy tables decode a vtable in the constructor but don't store it. this.initializeStatements.Add($"{this.vtableTypeName}.Create(buffer, offset, out var vtable);"); this.vtableAccessor = "vtable"; } - else + } + else + { + this.remainingDepthAccessor = $"this.{RemainingDepthVariableName}"; + + // maintain reference to buffer. + this.instanceFieldDefinitions[InputBufferVariableName] = $"private TInputBuffer {InputBufferVariableName};"; + this.instanceFieldDefinitions[OffsetVariableName] = $"private int {OffsetVariableName};"; + this.instanceFieldDefinitions[RemainingDepthVariableName] = $"private short {RemainingDepthVariableName};"; + + this.initializeStatements.Add($"this.{InputBufferVariableName} = buffer;"); + this.initializeStatements.Add($"this.{OffsetVariableName} = offset;"); + this.initializeStatements.Add($"{this.remainingDepthAccessor} = remainingDepth;"); + + if (this.typeModel.SchemaType == FlatBufferSchemaType.Table) { - // non-greedy tables also carry a vtable. + // Non-greedy tables store the vtable. this.vtableAccessor = $"this.{VTableVariableName}"; this.initializeStatements.Add($"{this.vtableTypeName}.Create(buffer, offset, out {this.vtableAccessor});"); this.instanceFieldDefinitions[VTableVariableName] = $"private {this.vtableTypeName} {VTableVariableName};"; @@ -139,6 +149,7 @@ internal class DeserializeClassDefinition InputBufferTypeName = "TInputBuffer", OffsetVariableName = "offset", InputBufferVariableName = "buffer", + RemainingDepthVariableName = "remainingDepth", }; string body = itemModel.CreateReadItemBody( @@ -151,7 +162,8 @@ internal class DeserializeClassDefinition private static {typeName} {GetReadIndexMethodName(itemModel)}( TInputBuffer buffer, int offset, - {this.vtableTypeName} vtable) + {this.vtableTypeName} vtable, + short remainingDepth) {{ {body} }}"); @@ -227,7 +239,7 @@ internal class DeserializeClassDefinition if (this.options.GreedyDeserialize || !itemModel.IsVirtual) { - this.initializeStatements.Add($"{assignment} = {GetReadIndexMethodName(itemModel)}(buffer, offset, {this.vtableAccessor});"); + this.initializeStatements.Add($"{assignment} = {GetReadIndexMethodName(itemModel)}(buffer, offset, {this.vtableAccessor}, {this.remainingDepthAccessor});"); } else if (!this.options.Lazy) { @@ -271,7 +283,7 @@ internal class DeserializeClassDefinition {string.Join("\r\n", this.instanceFieldDefinitions.Values)} - public static {this.ClassName} GetOrCreate(TInputBuffer buffer, int offset) + public static {this.ClassName} GetOrCreate(TInputBuffer buffer, int offset, short remainingDepth) {{ {this.GetGetOrCreateMethodBody()} }} @@ -325,7 +337,7 @@ internal class DeserializeClassDefinition protected virtual string GetGetterBody(ItemMemberModel itemModel) { - string readUnderlyingInvocation = $"{GetReadIndexMethodName(itemModel)}(this.{InputBufferVariableName}, this.{OffsetVariableName}, {this.vtableAccessor})"; + string readUnderlyingInvocation = $"{GetReadIndexMethodName(itemModel)}(this.{InputBufferVariableName}, this.{OffsetVariableName}, {this.vtableAccessor}, {this.remainingDepthAccessor})"; if (this.options.GreedyDeserialize) { return $"return this.{GetFieldName(itemModel)};"; @@ -350,7 +362,7 @@ internal class DeserializeClassDefinition protected virtual string GetGetOrCreateMethodBody() { return $@" - var item = new {this.ClassName}(buffer, offset); + var item = new {this.ClassName}(buffer, offset, remainingDepth); return item; "; } @@ -358,7 +370,7 @@ internal class DeserializeClassDefinition protected virtual string GetCtorMethodDefinition(string onDeserializedStatement, string baseCtorParams) { return $@" - private {this.ClassName}(TInputBuffer buffer, int offset) : base({baseCtorParams}) + private {this.ClassName}(TInputBuffer buffer, int offset, short remainingDepth) : base({baseCtorParams}) {{ {string.Join("\r\n", this.initializeStatements)} {onDeserializedStatement} diff --git a/src/FlatSharp/Serialization/ParserCodeGenContext.cs b/src/FlatSharp/Serialization/ParserCodeGenContext.cs index ddce54f..e49fe65 100644 --- a/src/FlatSharp/Serialization/ParserCodeGenContext.cs +++ b/src/FlatSharp/Serialization/ParserCodeGenContext.cs @@ -27,6 +27,7 @@ public record ParserCodeGenContext public ParserCodeGenContext( string inputBufferVariableName, string offsetVariableName, + string remainingDepthVariableName, string inputBufferTypeName, bool isOffsetByRef, string tableFieldContextVariableName, @@ -39,6 +40,7 @@ public record ParserCodeGenContext this.InputBufferVariableName = inputBufferVariableName; this.OffsetVariableName = offsetVariableName; this.InputBufferTypeName = inputBufferTypeName; + this.RemainingDepthVariableName = remainingDepthVariableName; this.MethodNameMap = methodNameMap; this.SerializeMethodNameMap = serializeMethodNameMap; this.IsOffsetByRef = isOffsetByRef; @@ -63,6 +65,11 @@ public record ParserCodeGenContext /// public string OffsetVariableName { get; init; } + /// + /// The name of the variable that tracks the remaining depth limit. Decremented down the stack. + /// + public string RemainingDepthVariableName { get; init; } + /// /// Indicates if the offset variable is passed by reference. /// @@ -112,6 +119,8 @@ public record ParserCodeGenContext } sb.Append(this.OffsetVariableName); + sb.Append(", "); + sb.Append(this.RemainingDepthVariableName); if (typeModel.TableFieldContextRequirements.HasFlag(TableFieldContextRequirements.Parse)) { diff --git a/src/FlatSharp/Serialization/RoslynSerializerGenerator.cs b/src/FlatSharp/Serialization/RoslynSerializerGenerator.cs index 20d114e..ed06026 100644 --- a/src/FlatSharp/Serialization/RoslynSerializerGenerator.cs +++ b/src/FlatSharp/Serialization/RoslynSerializerGenerator.cs @@ -173,11 +173,12 @@ $@" this.Write(default!, new byte[10], default!, default!, default!); this.Write(default!, new byte[10], default!, default!, default!); - this.Parse(default!, 0); - this.Parse(default!, 0); - this.Parse(default!, 0); - this.Parse(default!, 0); - this.Parse(default!, 0); + this.Parse(default!, default); + this.Parse(default!, default); + this.Parse(default!, default); + this.Parse(default!, default); + this.Parse(default!, default); + this.Parse(default!, default); throw new InvalidOperationException(""__AotHelper is not intended to be invoked""); }} @@ -396,7 +397,7 @@ $@" /// private void DefineMethods(ITypeModel rootModel) { - HashSet types = new HashSet(); + HashSet types = new(); rootModel.TraverseObjectGraph(types); foreach (var type in types) @@ -413,7 +414,7 @@ $@" var rootModel = this.typeModelContainer.CreateTypeModel(typeof(TRoot)); // all type model types. - HashSet types = new HashSet(); + HashSet types = new(); rootModel.TraverseObjectGraph(types); foreach (var type in types.ToArray()) @@ -485,10 +486,10 @@ $@" { string methodText = $@" - public {CSharpHelpers.GetGlobalCompilableTypeName(rootType)} Parse(TInputBuffer buffer, int offset) + public {CSharpHelpers.GetGlobalCompilableTypeName(rootType)} Parse(TInputBuffer buffer, in {typeof(GeneratedSerializerParseArguments).GetGlobalCompilableTypeName()} args) where TInputBuffer : IInputBuffer {{ - return {this.readMethods[rootType]}(buffer, offset); + return {this.readMethods[rootType]}(buffer, args.{nameof(GeneratedSerializerParseArguments.Offset)}, args.{nameof(GeneratedSerializerParseArguments.DepthLimit)}); }} "; this.methodDeclarations.Add(CSharpSyntaxTree.ParseText(methodText, ParseOptions).GetRoot()); @@ -497,6 +498,7 @@ $@" private void ImplementMethods(ITypeModel rootTypeModel) { + bool requiresDepthTracking = rootTypeModel.IsDeepEnoughToRequireDepthTracking(); List<(ITypeModel, TableFieldContext)> allContexts = rootTypeModel.GetAllTableFieldContexts(); Dictionary> allContextsMap = new(); @@ -531,7 +533,7 @@ $@" : string.Empty; var maxSizeContext = new GetMaxSizeCodeGenContext("value", getMaxSizeFieldContextVariableName, this.maxSizeMethods, this.options, this.typeModelContainer, allContextsMap); - var parseContext = new ParserCodeGenContext("buffer", "offset", "TInputBuffer", isOffsetByRef, parseFieldContextVariableName, this.readMethods, this.writeMethods, this.options, this.typeModelContainer, allContextsMap); + var parseContext = new ParserCodeGenContext("buffer", "offset", "remainingDepth", "TInputBuffer", isOffsetByRef, parseFieldContextVariableName, this.readMethods, this.writeMethods, this.options, this.typeModelContainer, allContextsMap); var serializeContext = new SerializationCodeGenContext("context", "span", "spanWriter", "value", "offset", serializeFieldContextVariableName, isOffsetByRef, this.writeMethods, this.typeModelContainer, this.options, allContextsMap); var maxSizeMethod = typeModel.CreateGetMaxSizeMethodBody(maxSizeContext); @@ -539,7 +541,7 @@ $@" var writeMethod = typeModel.CreateSerializeMethodBody(serializeContext); this.GenerateGetMaxSizeMethod(typeModel, maxSizeMethod, maxSizeContext); - this.GenerateParseMethod(typeModel, parseMethod, parseContext); + this.GenerateParseMethod(requiresDepthTracking, typeModel, parseMethod, parseContext); this.GenerateSerializeMethod(typeModel, writeMethod, serializeContext); string? extraClasses = typeModel.CreateExtraClasses(); @@ -615,7 +617,7 @@ $@" this.AddMethod(method, declaration); } - private void GenerateParseMethod(ITypeModel typeModel, CodeGeneratedMethod method, ParserCodeGenContext context) + private void GenerateParseMethod(bool requiresDepthTracking, ITypeModel typeModel, CodeGeneratedMethod method, ParserCodeGenContext context) { string tableFieldContextParameter = string.Empty; if (typeModel.TableFieldContextRequirements.HasFlag(TableFieldContextRequirements.Parse)) @@ -625,14 +627,26 @@ $@" string clrType = typeModel.GetGlobalCompilableTypeName(); + // If we require depth tracking due to the schema, inject the if statement and the decrement instruction. + string depthCheck = string.Empty; + if (requiresDepthTracking) + { + depthCheck = $@" + --{context.RemainingDepthVariableName}; + {typeof(SerializationHelpers).GetGlobalCompilableTypeName()}.{nameof(SerializationHelpers.EnsureDepthLimit)}({context.RemainingDepthVariableName}); + "; + } + string declaration = $@" {method.GetMethodImplAttribute()} private static {clrType} {this.readMethods[typeModel.ClrType]}( TInputBuffer {context.InputBufferVariableName}, - {GetVTableOffsetVariableType(typeModel.PhysicalLayout.Length)} {context.OffsetVariableName} + {GetVTableOffsetVariableType(typeModel.PhysicalLayout.Length)} {context.OffsetVariableName}, + short {context.RemainingDepthVariableName} {tableFieldContextParameter}) where TInputBuffer : IInputBuffer {{ + {depthCheck} {method.MethodBody} }}"; diff --git a/src/FlatSharp/TypeModel/ITypeModelExtensions.cs b/src/FlatSharp/TypeModel/ITypeModelExtensions.cs index ead293e..6b4f9ae 100644 --- a/src/FlatSharp/TypeModel/ITypeModelExtensions.cs +++ b/src/FlatSharp/TypeModel/ITypeModelExtensions.cs @@ -14,6 +14,8 @@ * limitations under the License. */ +using System.Linq; + namespace FlatSharp.TypeModel; [Flags] @@ -190,16 +192,50 @@ internal static class ITypeModelExtensions } } + /// + /// Indicates if the given type model has enough recursive depth to require object depth tracking (ie, there + /// is a risk of stack overflow). This can be due to an excessively deep object graph or a cycle (we do not care which). + /// + public static bool IsDeepEnoughToRequireDepthTracking(this ITypeModel typeModel) + { + static bool Recurse(ITypeModel model, int depthRemaining) + { + if (depthRemaining <= 0) + { + return true; + } + + foreach (var child in model.Children) + { + if (Recurse(child, depthRemaining - 1)) + { + return true; + } + } + + return false; + } + + return Recurse(typeModel, 500); + } + /// /// Recursively traverses the full object graph for the given type model. /// public static void TraverseObjectGraph(this ITypeModel model, HashSet seenTypes) { - if (seenTypes.Add(model.ClrType)) + Queue discoveryQueue = new(); + discoveryQueue.Enqueue(model); + + while (discoveryQueue.Count > 0) { - foreach (var child in model.Children) + ITypeModel next = discoveryQueue.Dequeue(); + if (seenTypes.Add(next.ClrType)) { - child.TraverseObjectGraph(seenTypes); + foreach (var child in next.Children) + { + discoveryQueue.Enqueue(child); + } } } } diff --git a/src/FlatSharp/TypeModel/StructTypeModel.cs b/src/FlatSharp/TypeModel/StructTypeModel.cs index 22fc104..11de8a0 100644 --- a/src/FlatSharp/TypeModel/StructTypeModel.cs +++ b/src/FlatSharp/TypeModel/StructTypeModel.cs @@ -124,7 +124,7 @@ public class StructTypeModel : RuntimeTypeModel classDef.AddProperty(value, context); } - return new CodeGeneratedMethod($"return {className}<{context.InputBufferTypeName}>.GetOrCreate({context.InputBufferVariableName}, {context.OffsetVariableName});") + return new CodeGeneratedMethod($"return {className}<{context.InputBufferTypeName}>.GetOrCreate({context.InputBufferVariableName}, {context.OffsetVariableName}, {context.RemainingDepthVariableName});") { ClassDefinition = classDef.ToString(), }; diff --git a/src/FlatSharp/TypeModel/TableTypeModel.cs b/src/FlatSharp/TypeModel/TableTypeModel.cs index ee3892b..cb724cf 100644 --- a/src/FlatSharp/TypeModel/TableTypeModel.cs +++ b/src/FlatSharp/TypeModel/TableTypeModel.cs @@ -696,7 +696,7 @@ $@" classDef.AddProperty(value, tempContext); } - string body = $"return {this.tableReaderClassName}<{context.InputBufferTypeName}>.GetOrCreate({context.InputBufferVariableName}, {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}));"; + string body = $"return {this.tableReaderClassName}<{context.InputBufferTypeName}>.GetOrCreate({context.InputBufferVariableName}, {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}), {context.RemainingDepthVariableName});"; return new CodeGeneratedMethod(body) { ClassDefinition = classDef.ToString(), diff --git a/src/FlatSharp/TypeModel/ValueStructTypeModel.cs b/src/FlatSharp/TypeModel/ValueStructTypeModel.cs index b5583eb..60a3019 100644 --- a/src/FlatSharp/TypeModel/ValueStructTypeModel.cs +++ b/src/FlatSharp/TypeModel/ValueStructTypeModel.cs @@ -110,7 +110,8 @@ public class ValueStructTypeModel : RuntimeTypeModel propertyStatements.Add($@" item.{member.accessor} = {context.MethodNameMap[member.model.ClrType]}<{context.InputBufferTypeName}>( {context.InputBufferVariableName}, - {context.OffsetVariableName} + {member.offset});"); + {context.OffsetVariableName} + {member.offset}, + {context.RemainingDepthVariableName});"); } string nonMarshalBody = $@" diff --git a/src/FlatSharp/TypeModel/Vectors/ArrayVectorTypeModel.cs b/src/FlatSharp/TypeModel/Vectors/ArrayVectorTypeModel.cs index d034c86..0451874 100644 --- a/src/FlatSharp/TypeModel/Vectors/ArrayVectorTypeModel.cs +++ b/src/FlatSharp/TypeModel/Vectors/ArrayVectorTypeModel.cs @@ -93,6 +93,7 @@ public class ArrayVectorTypeModel : BaseVectorTypeModel {context.InputBufferVariableName}, {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}), {this.PaddedMemberInlineSize}, + {context.RemainingDepthVariableName}, {context.TableFieldContextVariableName})"; body = $"return ({createFlatBufferVector}).ToArray();"; diff --git a/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs b/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs index 6c43729..0e8a642 100644 --- a/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs +++ b/src/FlatSharp/TypeModel/Vectors/IndexedVectorTypeModel.cs @@ -106,6 +106,7 @@ public class IndexedVectorTypeModel : BaseVectorTypeModel {context.InputBufferVariableName}, {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}), {this.PaddedMemberInlineSize}, + {context.RemainingDepthVariableName}, {context.TableFieldContextVariableName})"; string mutable = context.Options.GenerateMutableObjects.ToString().ToLowerInvariant(); diff --git a/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs b/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs index 72b1c65..380eb5f 100644 --- a/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs +++ b/src/FlatSharp/TypeModel/Vectors/ListVectorTypeModel.cs @@ -114,6 +114,7 @@ public class ListVectorTypeModel : BaseVectorTypeModel {context.InputBufferVariableName}, {context.OffsetVariableName} + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}), {this.PaddedMemberInlineSize}, + {context.RemainingDepthVariableName}, {context.TableFieldContextVariableName})"; return new CodeGeneratedMethod(CreateParseBody(this.ItemTypeModel, createFlatBufferVector, context)) { ClassDefinition = vectorClassDef }; diff --git a/src/FlatSharp/TypeModel/VectorsOfUnion/ArrayVectorOfUnionTypeModel.cs b/src/FlatSharp/TypeModel/VectorsOfUnion/ArrayVectorOfUnionTypeModel.cs index 7b4e7e4..355cc7c 100644 --- a/src/FlatSharp/TypeModel/VectorsOfUnion/ArrayVectorOfUnionTypeModel.cs +++ b/src/FlatSharp/TypeModel/VectorsOfUnion/ArrayVectorOfUnionTypeModel.cs @@ -46,6 +46,7 @@ public class ArrayVectorOfUnionTypeModel : BaseVectorOfUnionTypeModel {context.InputBufferVariableName}, {context.OffsetVariableName}.offset0 + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}.offset0), {context.OffsetVariableName}.offset1 + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}.offset1), + {context.RemainingDepthVariableName}, {context.TableFieldContextVariableName})"; string body = $"return ({createFlatBufferVector}).ToArray();"; diff --git a/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs b/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs index 85ede0e..2716b74 100644 --- a/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs +++ b/src/FlatSharp/TypeModel/VectorsOfUnion/ListVectorOfUnionTypeModel.cs @@ -39,6 +39,7 @@ public class ListVectorOfUnionTypeModel : BaseVectorOfUnionTypeModel {context.InputBufferVariableName}, {context.OffsetVariableName}.offset0 + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}.offset0), {context.OffsetVariableName}.offset1 + {context.InputBufferVariableName}.{nameof(InputBufferExtensions.ReadUOffset)}({context.OffsetVariableName}.offset1), + {context.RemainingDepthVariableName}, {context.TableFieldContextVariableName})"; return new CodeGeneratedMethod(ListVectorTypeModel.CreateParseBody( diff --git a/src/Tests/FlatSharpCompilerTests/DepthLimitTests.cs b/src/Tests/FlatSharpCompilerTests/DepthLimitTests.cs new file mode 100644 index 0000000..ea44bc2 --- /dev/null +++ b/src/Tests/FlatSharpCompilerTests/DepthLimitTests.cs @@ -0,0 +1,124 @@ +/* + * Copyright 2020 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using FlatSharp.TypeModel; +using System.Linq; +using System.Text; + +namespace FlatSharpTests.Compiler; + +public class DepthLimitTests +{ + [Fact] + public void TableReferencesSelf() + { + string fbs = $@" + {MetadataHelpers.AllAttributes} + namespace Foo.Bar; + table ListNode {{ Next : ListNode; }} + "; + + CompileAndVerify(fbs, "Foo.Bar.ListNode", true); + } + + [Fact] + public void TableCycle() + { + string fbs = $@" + {MetadataHelpers.AllAttributes} + namespace Foo.Bar; + table A {{ Next : B; }} + table B {{ Next : C; }} + table C {{ Next : A; }} + "; + + CompileAndVerify(fbs, "Foo.Bar.A", true); + } + + [Theory] + [InlineData(499, false)] + [InlineData(500, true)] + public void DeepTable_NoCycle(int depth, bool expectCycleTracking) + { + StringBuilder sb = new(); + sb.Append($@" + {MetadataHelpers.AllAttributes} + namespace Foo.Bar; + table T0 {{ Value : int; }} + "); + + for (int i = 1; i < depth; ++i) + { + sb.AppendLine($"table T{i} {{ Previous : T{i - 1}; }}"); + } + + CompileAndVerify(sb.ToString(), $"Foo.Bar.T{depth - 1}", expectCycleTracking); + } + + [Fact] + public void TableCycleWithListVector() + { + string fbs = $@" + {MetadataHelpers.AllAttributes} + namespace Foo.Bar; + table A {{ Next : B; }} + table B {{ Next : C; }} + table C {{ Next : [A] (fs_vector:""IList""); }} + "; + + CompileAndVerify(fbs, "Foo.Bar.A", true); + } + + [Fact] + public void TableCycleWithArrayVector() + { + string fbs = $@" + {MetadataHelpers.AllAttributes} + namespace Foo.Bar; + table A {{ Next : B; }} + table B {{ Next : C; }} + table C {{ Next : [A] (fs_vector:""Array""); }} + "; + + CompileAndVerify(fbs, "Foo.Bar.A", true); + } + + /* + [Fact] + public void TableCycleWithIndexedVector() + { + string fbs = $@" + {MetadataHelpers.AllAttributes} + namespace Foo.Bar; + table A {{ Next : B; Key : string (key); }} + table B {{ Next : C; }} + table C {{ Next : [A] (fs_vector:""IIndexedVector""); }} + "; + + CompileAndVerify(fbs, "Foo.Bar.A", true); + } + */ + + private static void CompileAndVerify(string fbs, string typeName, bool needsTracking) + { + Assembly asm = FlatSharpCompiler.CompileAndLoadAssembly(fbs, new()); + Type t = asm.GetTypes().Single(x => x.FullName == typeName); + Assert.NotNull(t); + + ITypeModel typeModel = RuntimeTypeModel.CreateFrom(t); + Assert.Equal(needsTracking, typeModel.IsDeepEnoughToRequireDepthTracking()); + } +} diff --git a/src/Tests/FlatSharpTests/ClassLib/InputBufferTests.cs b/src/Tests/FlatSharpTests/ClassLib/InputBufferTests.cs index 89b6a67..1aba2e3 100644 --- a/src/Tests/FlatSharpTests/ClassLib/InputBufferTests.cs +++ b/src/Tests/FlatSharpTests/ClassLib/InputBufferTests.cs @@ -421,9 +421,9 @@ public class InputBufferTests return ((IInputBuffer)innerBuffer).GetReadOnlyByteMemory(start, length); } - public TItem InvokeParse(IGeneratedSerializer serializer, int offset) + public TItem InvokeParse(IGeneratedSerializer serializer, in GeneratedSerializerParseArguments arguments) { - return ((IInputBuffer)innerBuffer).InvokeParse(serializer, offset); + return ((IInputBuffer)innerBuffer).InvokeParse(serializer, arguments); } public byte ReadByte(int offset) diff --git a/src/Tests/FlatSharpTests/SerializationTests/DepthLimitTests.cs b/src/Tests/FlatSharpTests/SerializationTests/DepthLimitTests.cs new file mode 100644 index 0000000..bfdd52f --- /dev/null +++ b/src/Tests/FlatSharpTests/SerializationTests/DepthLimitTests.cs @@ -0,0 +1,80 @@ +/* + * Copyright 2021 James Courtney + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace FlatSharpTests; + +public class DepthLimitTests +{ + [Fact] + public void InvalidDepthLimit() + { + Assert.Throws(() => + FlatBufferSerializer.Default + .Compile() + .WithSettings(new SerializerSettings { ObjectDepthLimit = -1 })); + } + + [Theory] + [InlineData(FlatBufferDeserializationOption.GreedyMutable, 1000, 999, false)] + [InlineData(FlatBufferDeserializationOption.GreedyMutable, 1000, 1000, true)] + [InlineData(FlatBufferDeserializationOption.Lazy, 1000, 999, false)] + [InlineData(FlatBufferDeserializationOption.Lazy, 1000, 1000, true)] + public void LinkedListDepth(FlatBufferDeserializationOption option, short limit, int nodes, bool expectException) + { + FlatBufferSerializer fbs = new FlatBufferSerializer(option); + ISerializer serializer = fbs.Compile().WithSettings(new SerializerSettings { ObjectDepthLimit = limit }); + + LinkedListNode head = new LinkedListNode(); + LinkedListNode current = head; + + for (int i = 0; i < nodes; ++i) + { + current.Next = new LinkedListNode(); + current = current.Next; + } + + byte[] buffer = new byte[serializer.GetMaxSize(head)]; + serializer.Write(buffer, head); + + Action callback = () => + { + LinkedListNode node = serializer.Parse(buffer); + while (node != null) + { + node = node.Next; + } + }; + + if (expectException) + { + Assert.Throws(callback); + } + else + { + callback(); + } + } + + [FlatBufferTable] + public class LinkedListNode + { + [FlatBufferItem(0)] + public virtual LinkedListNode? Next { get; set; } + + [FlatBufferItem(1)] + public virtual int Value { get; set; } + } +} \ No newline at end of file diff --git a/src/Tests/FlatSharpTests/Util/ContextHelpers.cs b/src/Tests/FlatSharpTests/Util/ContextHelpers.cs index 5311984..68d245a 100644 --- a/src/Tests/FlatSharpTests/Util/ContextHelpers.cs +++ b/src/Tests/FlatSharpTests/Util/ContextHelpers.cs @@ -26,6 +26,7 @@ public static class ContextHelpers return new ParserCodeGenContext( "a", "b", + "e", "c", false, "d", diff --git a/src/common.props b/src/common.props index a345c8c..cd14643 100644 --- a/src/common.props +++ b/src/common.props @@ -9,12 +9,12 @@ true - 6.2.1 - 6.2.1 + 6.3.0 + 6.3.0 $(Version) James Courtney FlatSharp is a fast, idiomatic implementation of the FlatBuffer binary format. - 2021 + 2022 https://github.com/jamescourtney/FlatSharp/ flatbuffers serialization flatbuffer flatsharp Release notes at https://github.com/jamescourtney/FlatSharp/releases