Merge pull request #372 from qintao1976/dev

Fix the problem that accumulator can only be used in main thread
This commit is contained in:
Tao Qin 2016-04-05 10:18:27 +08:00
Родитель cdeaf2acc9 e6e5d64070
Коммит 3f6e998b53
2 изменённых файлов: 36 добавлений и 18 удалений

Просмотреть файл

@ -75,6 +75,19 @@ namespace Microsoft.Spark.CSharp.Core
accumulatorRegistry[accumulatorId] = this; accumulatorRegistry[accumulatorId] = this;
} }
[OnDeserialized()]
internal void OnDeserializedMethod(System.Runtime.Serialization.StreamingContext context)
{
if (threadLocalAccumulatorRegistry == null)
{
threadLocalAccumulatorRegistry = new Dictionary<int, Accumulator>();
}
if (!threadLocalAccumulatorRegistry.ContainsKey(accumulatorId))
{
threadLocalAccumulatorRegistry[accumulatorId] = this;
}
}
/// <summary> /// <summary>
/// Gets or sets the value of the accumulator; only usable in driver program /// Gets or sets the value of the accumulator; only usable in driver program
/// </summary> /// </summary>
@ -119,20 +132,6 @@ namespace Microsoft.Spark.CSharp.Core
/// <returns></returns> /// <returns></returns>
public static Accumulator<T> operator +(Accumulator<T> self, T term) public static Accumulator<T> operator +(Accumulator<T> self, T term)
{ {
if (self.isDriver) // this is in driver
{
if (!accumulatorRegistry.ContainsKey(self.accumulatorId))
{
accumulatorRegistry[self.accumulatorId] = self;
}
}
else // this is in executor
{
if (!threadLocalAccumulatorRegistry.ContainsKey(self.accumulatorId))
{
threadLocalAccumulatorRegistry[self.accumulatorId] = self;
}
}
self.Add(term); self.Add(term);
return self; return self;
} }

Просмотреть файл

@ -64,14 +64,30 @@ namespace Microsoft.Spark.CSharp.Samples
internal class AccumulatorHelper internal class AccumulatorHelper
{ {
private Accumulator<int> accumulator; private Accumulator<int> accumulator;
internal AccumulatorHelper(Accumulator<int> accumulator) private bool async;
internal AccumulatorHelper(Accumulator<int> accumulator, bool async = false)
{ {
this.accumulator = accumulator; this.accumulator = accumulator;
this.async = async;
} }
internal void Execute(int input) internal void Execute(int input)
{
if (async)
{
// start new task
var task = new Task(() =>
{ {
accumulator += input; accumulator += input;
});
task.Start();
task.Wait();
}
else
{
accumulator += input;
}
} }
} }
@ -79,14 +95,17 @@ namespace Microsoft.Spark.CSharp.Samples
internal static void SparkContextAccumulatorSample() internal static void SparkContextAccumulatorSample()
{ {
var a = SparkCLRSamples.SparkContext.Accumulator<int>(100); var a = SparkCLRSamples.SparkContext.Accumulator<int>(100);
SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(a).Execute); var b = SparkCLRSamples.SparkContext.Accumulator<int>(100);
Console.WriteLine("accumulator value: " + a.Value); SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(a).Execute);
SparkCLRSamples.SparkContext.Parallelize(new[] { 1, 2, 3, 4 }, 3).Foreach(new AccumulatorHelper(b, true).Execute);
Console.WriteLine("accumulator value, a: {0}, b: {1}", a.Value, b.Value);
if (SparkCLRSamples.Configuration.IsValidationEnabled) if (SparkCLRSamples.Configuration.IsValidationEnabled)
{ {
// The value is accumulated on the initial value of the Accumulator which is 100. 110 = 100 + 1 + 2 + 3 + 4 // The value is accumulated on the initial value of the Accumulator which is 100. 110 = 100 + 1 + 2 + 3 + 4
Assert.AreEqual(110, a.Value); Assert.AreEqual(110, a.Value);
Assert.AreEqual(110, b.Value);
} }
} }