Merge pull request #372 from qintao1976/dev
Fix the problem that accumulator can only be used in main thread
This commit is contained in:
Коммит
3f6e998b53
|
@ -75,6 +75,19 @@ namespace Microsoft.Spark.CSharp.Core
|
|||
accumulatorRegistry[accumulatorId] = this;
|
||||
}
|
||||
|
||||
[OnDeserialized()]
|
||||
internal void OnDeserializedMethod(System.Runtime.Serialization.StreamingContext context)
|
||||
{
|
||||
if (threadLocalAccumulatorRegistry == null)
|
||||
{
|
||||
threadLocalAccumulatorRegistry = new Dictionary<int, Accumulator>();
|
||||
}
|
||||
if (!threadLocalAccumulatorRegistry.ContainsKey(accumulatorId))
|
||||
{
|
||||
threadLocalAccumulatorRegistry[accumulatorId] = this;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the value of the accumulator; only usable in driver program
|
||||
/// </summary>
|
||||
|
@ -119,20 +132,6 @@ namespace Microsoft.Spark.CSharp.Core
|
|||
/// <returns></returns>
|
||||
public static Accumulator<T> operator +(Accumulator<T> self, T term)
|
||||
{
|
||||
if (self.isDriver) // this is in driver
|
||||
{
|
||||
if (!accumulatorRegistry.ContainsKey(self.accumulatorId))
|
||||
{
|
||||
accumulatorRegistry[self.accumulatorId] = self;
|
||||
}
|
||||
}
|
||||
else // this is in executor
|
||||
{
|
||||
if (!threadLocalAccumulatorRegistry.ContainsKey(self.accumulatorId))
|
||||
{
|
||||
threadLocalAccumulatorRegistry[self.accumulatorId] = self;
|
||||
}
|
||||
}
|
||||
self.Add(term);
|
||||
return self;
|
||||
}
|
||||
|
|
|
@ -64,14 +64,30 @@ namespace Microsoft.Spark.CSharp.Samples
|
|||
internal class AccumulatorHelper
|
||||
{
|
||||
private Accumulator<int> accumulator;
|
||||
internal AccumulatorHelper(Accumulator<int> accumulator)
|
||||
private bool async;
|
||||
internal AccumulatorHelper(Accumulator<int> accumulator, bool async = false)
|
||||
{
|
||||
this.accumulator = accumulator;
|
||||
this.async = async;
|
||||
}
|
||||
|
||||
internal void Execute(int input)
|
||||
{
|
||||
if (async)
|
||||
{
|
||||
// start new task
|
||||
var task = new Task(() =>
|
||||
{
|
||||
accumulator += input;
|
||||
});
|
||||
task.Start();
|
||||
task.Wait();
|
||||
}
|
||||
else
|
||||
{
|
||||
accumulator += input;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,14 +95,17 @@ namespace Microsoft.Spark.CSharp.Samples
|
|||
internal static void SparkContextAccumulatorSample()
|
||||
{
|
||||
var a = SparkCLRSamples.SparkContext.Accumulator<int>(100);
|
||||
SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(a).Execute);
|
||||
var b = SparkCLRSamples.SparkContext.Accumulator<int>(100);
|
||||
|
||||
Console.WriteLine("accumulator value: " + a.Value);
|
||||
SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(a).Execute);
|
||||
SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(b, true).Execute);
|
||||
Console.WriteLine("accumulator value, a: {0}, b: {1}", a.Value, b.Value);
|
||||
|
||||
if (SparkCLRSamples.Configuration.IsValidationEnabled)
|
||||
{
|
||||
// The value is accumulated on the initial value of the Accumulator which is 100. 110 = 100 + 1 + 2 + 3 + 4
|
||||
Assert.AreEqual(110, a.Value);
|
||||
Assert.AreEqual(110, b.Value);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче