diff --git a/src/OrleansCodeGenerator/SerializerGenerator.cs b/src/OrleansCodeGenerator/SerializerGenerator.cs index 1a07c23ef..8ed050f64 100644 --- a/src/OrleansCodeGenerator/SerializerGenerator.cs +++ b/src/OrleansCodeGenerator/SerializerGenerator.cs @@ -600,18 +600,22 @@ namespace Orleans.CodeGenerator this.FieldInfo.FieldType); // If the value is not a GrainReference, convert it to a strongly-typed GrainReference. - // C#: !(value is GrainReference) ? value.AsReference() : value; + // C#: (value == null || value is GrainReference) ? value : value.AsReference() deepCopyValueExpression = SF.ConditionalExpression( - SF.PrefixUnaryExpression( - SyntaxKind.LogicalNotExpression, - SF.ParenthesizedExpression( + SF.ParenthesizedExpression( + SF.BinaryExpression( + SyntaxKind.LogicalOrExpression, + SF.BinaryExpression( + SyntaxKind.EqualsExpression, + getValueExpression, + SF.LiteralExpression(SyntaxKind.NullLiteralExpression)), SF.BinaryExpression( SyntaxKind.IsExpression, getValueExpression, typeof(GrainReference).GetTypeSyntax()))), - SF.InvocationExpression(getAsReference), - getValueExpression); + getValueExpression, + SF.InvocationExpression(getAsReference)); } else { diff --git a/test/DefaultCluster.Tests/GrainReferenceTest.cs b/test/DefaultCluster.Tests/GrainReferenceTest.cs index 00338f212..3dcdbee89 100644 --- a/test/DefaultCluster.Tests/GrainReferenceTest.cs +++ b/test/DefaultCluster.Tests/GrainReferenceTest.cs @@ -63,6 +63,19 @@ namespace DefaultCluster.Tests.General g1.PassThisNested(new ChainGrainHolder { Next = g2 }).Wait(); } + [Fact, TestCategory("BVT"), TestCategory("Functional"), TestCategory("GrainReference")] + public async Task GrainReference_Pass_Null() + { + IChainedGrain g1 = this.GrainFactory.GetGrain(GetRandomGrainId()); + IChainedGrain g2 = this.GrainFactory.GetGrain(GetRandomGrainId()); + + // g1 will pass a null reference to g2 + await g1.PassNullNested(new ChainGrainHolder { Next = g2 }); + Assert.Null(await g2.GetNext()); + await g1.PassNull(g2); + Assert.Null(await g2.GetNext()); + } + [Fact, TestCategory("BVT"), TestCategory("Functional"), TestCategory("Serialization"), TestCategory("GrainReference")] public void GrainReference_DotNet_Serialization() { diff --git a/test/NonSiloTests/Serialization/BuiltInSerializerTests.cs b/test/NonSiloTests/Serialization/BuiltInSerializerTests.cs index f259309b1..66e7d5050 100644 --- a/test/NonSiloTests/Serialization/BuiltInSerializerTests.cs +++ b/test/NonSiloTests/Serialization/BuiltInSerializerTests.cs @@ -135,6 +135,7 @@ namespace UnitTests.Serialization Assert.Equal(expected.Classes[1].Interfaces[0].Int, actual.Classes[1].Interfaces[0].Int); Assert.Equal(0, actual.NonSerializedInt); Assert.Equal(expected.GetObsoleteInt(), actual.GetObsoleteInt()); + Assert.Null(actual.SomeGrainReference); } [Theory, TestCategory("BVT"), TestCategory("Serialization")] @@ -169,6 +170,7 @@ namespace UnitTests.Serialization Assert.Equal(expected.ReadonlyField, actual.ReadonlyField); Assert.Equal(expected.PublicValue, actual.PublicValue); Assert.Equal(expected.ValueWithPrivateSetter, actual.ValueWithPrivateSetter); + Assert.Null(actual.SomeGrainReference); Assert.Equal(expected.GetPrivateValue(), actual.GetPrivateValue()); Assert.Equal(expected.GetValueWithPrivateGetter(), actual.GetValueWithPrivateGetter()); } diff --git a/test/TestGrainInterfaces/CodegenTestInterfaces.cs b/test/TestGrainInterfaces/CodegenTestInterfaces.cs index 83a6a8d19..1de2f55ad 100644 --- a/test/TestGrainInterfaces/CodegenTestInterfaces.cs +++ b/test/TestGrainInterfaces/CodegenTestInterfaces.cs @@ -127,6 +127,8 @@ namespace UnitTests.GrainInterfaces private int PrivateValue { get; set; } public readonly int ReadonlyField; + public IEchoGrain SomeGrainReference { get; set; } + public SomeStruct(int readonlyField) : this() { @@ -173,9 +175,10 @@ namespace UnitTests.GrainInterfaces [Obsolete("This field should be serialized")] public int ObsoleteInt { get; set; } - - #pragma warning disable 618 + public IEchoGrain SomeGrainReference { get; set; } + +#pragma warning disable 618 public int GetObsoleteInt() => this.ObsoleteInt; public void SetObsoleteInt(int value) { diff --git a/test/TestGrainInterfaces/IChainedGrain.cs b/test/TestGrainInterfaces/IChainedGrain.cs index 0ad2e230f..7fc6a87d9 100644 --- a/test/TestGrainInterfaces/IChainedGrain.cs +++ b/test/TestGrainInterfaces/IChainedGrain.cs @@ -15,7 +15,9 @@ namespace UnitTests.GrainInterfaces //[ReadOnly] Task Validate(bool nextIsSet); Task PassThis(IChainedGrain next); + Task PassNull(IChainedGrain next); Task PassThisNested(ChainGrainHolder next); + Task PassNullNested(ChainGrainHolder next); } public class ChainGrainHolder diff --git a/test/TestGrains/ChainedGrain.cs b/test/TestGrains/ChainedGrain.cs index 657e9bb93..5544d077f 100644 --- a/test/TestGrains/ChainedGrain.cs +++ b/test/TestGrains/ChainedGrain.cs @@ -85,11 +85,21 @@ namespace UnitTests.Grains return next.SetNext(this); } + public Task PassNull(IChainedGrain next) + { + return next.SetNext(null); + } + public Task PassThisNested(ChainGrainHolder next) { return next.Next.SetNextNested(new ChainGrainHolder { Next = this }); } + public Task PassNullNested(ChainGrainHolder next) + { + return next.Next.SetNextNested(new ChainGrainHolder { Next = null }); + } + #endregion } }