refactoring to change low level enumerator back to normal

This commit is contained in:
renyi 2016-07-28 17:01:07 -07:00
Родитель d20d2812b6
Коммит e38b0c2b98
1 изменённых файлов: 90 добавлений и 192 удалений

Просмотреть файл

@ -254,7 +254,6 @@ namespace Microsoft.Spark.CSharp
if (lengthOfCommandByteArray > 0)
{
var commandProcessWatch = new Stopwatch();
var funcProcessWatch = new Stopwatch();
commandProcessWatch.Start();
int stageId = ReadDiagnosticsInfo(networkStream);
@ -279,30 +278,11 @@ namespace Microsoft.Spark.CSharp
"--------------------------------------------------------------------------------------------------------------");
DateTime initTime = DateTime.UtcNow;
// here we use low level API because we need to get perf metrics
var inputEnumerator = new WorkerInputEnumerator(networkStream, deserializerMode);
IEnumerable<dynamic> inputEnumerable = inputEnumerator.Cast<dynamic>();
funcProcessWatch.Start();
IEnumerable<dynamic> outputEnumerable = func(splitIndex, inputEnumerable);
var outputEnumerator = outputEnumerable.GetEnumerator();
funcProcessWatch.Stop();
int count = 0;
int nullMessageCount = 0;
while (true)
var funcProcessWatch = Stopwatch.StartNew();
foreach (var message in func(splitIndex, GetIterator(networkStream, deserializerMode)))
{
funcProcessWatch.Start();
bool hasNext = outputEnumerator.MoveNext();
funcProcessWatch.Stop();
if (!hasNext)
{
break;
}
funcProcessWatch.Start();
var message = outputEnumerator.Current;
funcProcessWatch.Stop();
if (object.ReferenceEquals(null, message))
@ -313,6 +293,7 @@ namespace Microsoft.Spark.CSharp
WriteOutput(networkStream, serializerMode, message, formatter);
count++;
funcProcessWatch.Start();
}
logger.LogDebug("Output entries count: " + count);
@ -328,7 +309,6 @@ namespace Microsoft.Spark.CSharp
commandProcessWatch.Stop();
// log statistics
inputEnumerator.LogStatistic();
logger.LogInfo(string.Format("func process time: {0}", funcProcessWatch.ElapsedMilliseconds));
logger.LogInfo(string.Format("stage {0}, command process time: {1}", stageId, commandProcessWatch.ElapsedMilliseconds));
}
@ -455,185 +435,103 @@ namespace Microsoft.Spark.CSharp
{
return (long)(dt - UnixTimeEpoch).TotalMilliseconds;
}
}
// Get worker input data from input stream
internal class WorkerInputEnumerator : IEnumerator, IEnumerable
{
private static readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(WorkerInputEnumerator));
private readonly Stream inputStream;
private readonly string deserializedMode;
// cache deserialized object read from input stream
private object[] items = null;
private int pos = 0;
private readonly IFormatter formatter = new BinaryFormatter();
private readonly Stopwatch watch = new Stopwatch();
public WorkerInputEnumerator(Stream inputStream, string deserializedMode)
private static IEnumerable<dynamic> GetIterator(Stream inputStream, string serializedMode)
{
this.inputStream = inputStream;
this.deserializedMode = deserializedMode;
}
public bool MoveNext()
{
watch.Start();
bool hasNext;
if ((items != null) && (pos < items.Length))
logger.LogInfo("Serialized mode in GetIterator: " + serializedMode);
IFormatter formatter = new BinaryFormatter();
var mode = (SerializedMode)Enum.Parse(typeof(SerializedMode), serializedMode);
int messageLength;
Stopwatch watch = Stopwatch.StartNew();
while ((messageLength = SerDe.ReadInt(inputStream)) != (int)SpecialLengths.END_OF_DATA_SECTION)
{
hasNext = true;
}
else
{
int messageLength = SerDe.ReadInt(inputStream);
if (messageLength == (int)SpecialLengths.END_OF_DATA_SECTION)
watch.Stop();
if (messageLength > 0 || messageLength == (int)SpecialLengths.NULL)
{
hasNext = false;
logger.LogDebug("END_OF_DATA_SECTION");
}
else if ((messageLength > 0) || (messageLength == (int)SpecialLengths.NULL))
{
items = GetNext(messageLength);
Debug.Assert(items != null);
Debug.Assert(items.Any());
pos = 0;
hasNext = true;
}
else
{
//unexpected behavior.
//Try moving on to the next item. This might throw exception but
//it might also fetch the next items successfully
//So it is better not to throw exception right way
return MoveNext();
}
}
watch.Stop();
return hasNext;
}
public object Current
{
get
{
int currPos = pos;
pos++;
return items[currPos];
}
}
public void Reset()
{
throw new NotImplementedException();
}
public IEnumerator GetEnumerator()
{
return this;
}
public void LogStatistic()
{
logger.LogInfo(string.Format("total elapsed time: {0}", watch.ElapsedMilliseconds));
}
private object[] GetNext(int messageLength)
{
object[] result = null;
switch ((SerializedMode)Enum.Parse(typeof(SerializedMode), deserializedMode))
{
case SerializedMode.String:
watch.Start();
byte[] buffer = messageLength > 0 ? SerDe.ReadBytes(inputStream, messageLength) : null;
watch.Stop();
switch (mode)
{
result = new object[1];
if (messageLength > 0)
{
byte[] buffer = SerDe.ReadBytes(inputStream, messageLength);
if (buffer == null)
case SerializedMode.String:
{
logger.LogDebug("Buffer is null. Message length is {0}", messageLength);
if (messageLength > 0)
{
if (buffer == null)
{
logger.LogDebug("Buffer is null. Message length is {0}", messageLength);
}
yield return SerDe.ToString(buffer);
}
else
{
yield return null;
}
break;
}
case SerializedMode.Row:
{
Debug.Assert(messageLength > 0);
var unpickledObjects = PythonSerDe.GetUnpickledObjects(buffer);
foreach (var row in unpickledObjects.Select(item => (item as RowConstructor).GetRow()))
{
yield return row;
}
break;
}
case SerializedMode.Pair:
{
byte[] pairKey = buffer;
byte[] pairValue = null;
watch.Start();
int valueLength = SerDe.ReadInt(inputStream);
if (valueLength > 0)
{
pairValue = SerDe.ReadBytes(inputStream, valueLength);
}
else if (valueLength == (int)SpecialLengths.NULL)
{
pairValue = null;
}
else
{
throw new Exception(string.Format("unexpected valueLength: {0}", valueLength));
}
watch.Stop();
yield return new KeyValuePair<byte[], byte[]>(pairKey, pairValue);
break;
}
case SerializedMode.None: //just return raw bytes
{
yield return buffer;
break;
}
case SerializedMode.Byte:
default:
{
if (buffer != null)
{
var ms = new MemoryStream(buffer);
yield return formatter.Deserialize(ms);
}
else
{
yield return null;
}
break;
}
result[0] = SerDe.ToString(buffer);
}
else
{
result[0] = null;
}
break;
}
case SerializedMode.Row:
{
Debug.Assert(messageLength > 0);
byte[] buffer = SerDe.ReadBytes(inputStream, messageLength);
var unpickledObjects = PythonSerDe.GetUnpickledObjects(buffer);
var rows = unpickledObjects.Select(item => (item as RowConstructor).GetRow()).ToList();
result = rows.Cast<object>().ToArray();
break;
}
case SerializedMode.Pair:
{
byte[] pairKey = (messageLength > 0) ? SerDe.ReadBytes(inputStream, messageLength) : null;
byte[] pairValue = null;
int valueLength = SerDe.ReadInt(inputStream);
if (valueLength > 0)
{
pairValue = SerDe.ReadBytes(inputStream, valueLength);
}
else if (valueLength == (int)SpecialLengths.NULL)
{
pairValue = null;
}
else
{
throw new Exception(string.Format("unexpected valueLength: {0}", valueLength));
}
result = new object[1];
result[0] = new KeyValuePair<byte[], byte[]>(pairKey, pairValue);
break;
}
case SerializedMode.None: //just read raw bytes
{
result = new object[1];
if (messageLength > 0)
{
result[0] = SerDe.ReadBytes(inputStream, messageLength);
}
else
{
result[0] = null;
}
break;
}
case SerializedMode.Byte:
default:
{
result = new object[1];
if (messageLength > 0)
{
byte[] buffer = SerDe.ReadBytes(inputStream, messageLength);
var ms = new MemoryStream(buffer);
result[0] = formatter.Deserialize(ms);
}
else
{
result[0] = null;
}
break;
}
}
watch.Start();
}
return result;
logger.LogInfo(string.Format("total receive time: {0}", watch.ElapsedMilliseconds));
}
}
}