diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Core/Accumulator.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Core/Accumulator.cs index 0d26b06..816002e 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Core/Accumulator.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Core/Accumulator.cs @@ -75,6 +75,19 @@ namespace Microsoft.Spark.CSharp.Core accumulatorRegistry[accumulatorId] = this; } + [OnDeserialized()] + internal void OnDeserializedMethod(System.Runtime.Serialization.StreamingContext context) + { + if (threadLocalAccumulatorRegistry == null) + { + threadLocalAccumulatorRegistry = new Dictionary(); + } + if (!threadLocalAccumulatorRegistry.ContainsKey(accumulatorId)) + { + threadLocalAccumulatorRegistry[accumulatorId] = this; + } + } + /// /// Gets or sets the value of the accumulator; only usable in driver program /// @@ -119,20 +132,6 @@ namespace Microsoft.Spark.CSharp.Core /// public static Accumulator operator +(Accumulator self, T term) { - if (self.isDriver) // this is in driver - { - if (!accumulatorRegistry.ContainsKey(self.accumulatorId)) - { - accumulatorRegistry[self.accumulatorId] = self; - } - } - else // this is in executor - { - if (!threadLocalAccumulatorRegistry.ContainsKey(self.accumulatorId)) - { - threadLocalAccumulatorRegistry[self.accumulatorId] = self; - } - } self.Add(term); return self; } diff --git a/csharp/Samples/Microsoft.Spark.CSharp/SparkContextSamples.cs b/csharp/Samples/Microsoft.Spark.CSharp/SparkContextSamples.cs index c60fcfb..2769134 100644 --- a/csharp/Samples/Microsoft.Spark.CSharp/SparkContextSamples.cs +++ b/csharp/Samples/Microsoft.Spark.CSharp/SparkContextSamples.cs @@ -64,14 +64,30 @@ namespace Microsoft.Spark.CSharp.Samples internal class AccumulatorHelper { private Accumulator accumulator; - internal AccumulatorHelper(Accumulator accumulator) + private bool async; + internal AccumulatorHelper(Accumulator accumulator, bool async = false) { this.accumulator = accumulator; + this.async = async; } internal void Execute(int input) { - accumulator += input; + if (async) + { + // start new task + var task = new Task(() => + { + accumulator += input; + }); + task.Start(); + task.Wait(); + } + else + { + accumulator += input; + } + } } @@ -79,14 +95,17 @@ namespace Microsoft.Spark.CSharp.Samples internal static void SparkContextAccumulatorSample() { var a = SparkCLRSamples.SparkContext.Accumulator(100); - SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(a).Execute); + var b = SparkCLRSamples.SparkContext.Accumulator(100); - Console.WriteLine("accumulator value: " + a.Value); + SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(a).Execute); + SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(b, true).Execute); + Console.WriteLine("accumulator value, a: {0}, b: {1}", a.Value, b.Value); if (SparkCLRSamples.Configuration.IsValidationEnabled) { // The value is accumulated on the initial value of the Accumulator which is 100. 110 = 100 + 1 + 2 + 3 + 4 Assert.AreEqual(110, a.Value); + Assert.AreEqual(110, b.Value); } }