Refactoring to change the low-level enumerator back to a normal iterator method
This commit is contained in:
Родитель
d20d2812b6
Коммит
e38b0c2b98
|
@ -254,7 +254,6 @@ namespace Microsoft.Spark.CSharp
|
|||
if (lengthOfCommandByteArray > 0)
|
||||
{
|
||||
var commandProcessWatch = new Stopwatch();
|
||||
var funcProcessWatch = new Stopwatch();
|
||||
commandProcessWatch.Start();
|
||||
|
||||
int stageId = ReadDiagnosticsInfo(networkStream);
|
||||
|
@ -279,30 +278,11 @@ namespace Microsoft.Spark.CSharp
|
|||
"--------------------------------------------------------------------------------------------------------------");
|
||||
DateTime initTime = DateTime.UtcNow;
|
||||
|
||||
// Here we use the low-level API because we need to collect perf metrics
|
||||
var inputEnumerator = new WorkerInputEnumerator(networkStream, deserializerMode);
|
||||
IEnumerable<dynamic> inputEnumerable = inputEnumerator.Cast<dynamic>();
|
||||
|
||||
funcProcessWatch.Start();
|
||||
IEnumerable<dynamic> outputEnumerable = func(splitIndex, inputEnumerable);
|
||||
var outputEnumerator = outputEnumerable.GetEnumerator();
|
||||
funcProcessWatch.Stop();
|
||||
|
||||
int count = 0;
|
||||
int nullMessageCount = 0;
|
||||
while (true)
|
||||
var funcProcessWatch = Stopwatch.StartNew();
|
||||
foreach (var message in func(splitIndex, GetIterator(networkStream, deserializerMode)))
|
||||
{
|
||||
funcProcessWatch.Start();
|
||||
bool hasNext = outputEnumerator.MoveNext();
|
||||
funcProcessWatch.Stop();
|
||||
|
||||
if (!hasNext)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
funcProcessWatch.Start();
|
||||
var message = outputEnumerator.Current;
|
||||
funcProcessWatch.Stop();
|
||||
|
||||
if (object.ReferenceEquals(null, message))
|
||||
|
@ -313,6 +293,7 @@ namespace Microsoft.Spark.CSharp
|
|||
|
||||
WriteOutput(networkStream, serializerMode, message, formatter);
|
||||
count++;
|
||||
funcProcessWatch.Start();
|
||||
}
|
||||
|
||||
logger.LogDebug("Output entries count: " + count);
|
||||
|
@ -328,7 +309,6 @@ namespace Microsoft.Spark.CSharp
|
|||
commandProcessWatch.Stop();
|
||||
|
||||
// log statistics
|
||||
inputEnumerator.LogStatistic();
|
||||
logger.LogInfo(string.Format("func process time: {0}", funcProcessWatch.ElapsedMilliseconds));
|
||||
logger.LogInfo(string.Format("stage {0}, command process time: {1}", stageId, commandProcessWatch.ElapsedMilliseconds));
|
||||
}
|
||||
|
@ -455,185 +435,103 @@ namespace Microsoft.Spark.CSharp
|
|||
{
|
||||
return (long)(dt - UnixTimeEpoch).TotalMilliseconds;
|
||||
}
|
||||
}
|
||||
|
||||
// Get worker input data from input stream
|
||||
internal class WorkerInputEnumerator : IEnumerator, IEnumerable
|
||||
{
|
||||
private static readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(WorkerInputEnumerator));
|
||||
|
||||
private readonly Stream inputStream;
|
||||
private readonly string deserializedMode;
|
||||
|
||||
// Cache of deserialized objects read from the input stream
|
||||
private object[] items = null;
|
||||
private int pos = 0;
|
||||
|
||||
private readonly IFormatter formatter = new BinaryFormatter();
|
||||
private readonly Stopwatch watch = new Stopwatch();
|
||||
|
||||
public WorkerInputEnumerator(Stream inputStream, string deserializedMode)
|
||||
private static IEnumerable<dynamic> GetIterator(Stream inputStream, string serializedMode)
|
||||
{
|
||||
this.inputStream = inputStream;
|
||||
this.deserializedMode = deserializedMode;
|
||||
}
|
||||
|
||||
public bool MoveNext()
|
||||
{
|
||||
watch.Start();
|
||||
bool hasNext;
|
||||
|
||||
if ((items != null) && (pos < items.Length))
|
||||
logger.LogInfo("Serialized mode in GetIterator: " + serializedMode);
|
||||
IFormatter formatter = new BinaryFormatter();
|
||||
var mode = (SerializedMode)Enum.Parse(typeof(SerializedMode), serializedMode);
|
||||
int messageLength;
|
||||
Stopwatch watch = Stopwatch.StartNew();
|
||||
while ((messageLength = SerDe.ReadInt(inputStream)) != (int)SpecialLengths.END_OF_DATA_SECTION)
|
||||
{
|
||||
hasNext = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
int messageLength = SerDe.ReadInt(inputStream);
|
||||
if (messageLength == (int)SpecialLengths.END_OF_DATA_SECTION)
|
||||
watch.Stop();
|
||||
if (messageLength > 0 || messageLength == (int)SpecialLengths.NULL)
|
||||
{
|
||||
hasNext = false;
|
||||
logger.LogDebug("END_OF_DATA_SECTION");
|
||||
}
|
||||
else if ((messageLength > 0) || (messageLength == (int)SpecialLengths.NULL))
|
||||
{
|
||||
items = GetNext(messageLength);
|
||||
Debug.Assert(items != null);
|
||||
Debug.Assert(items.Any());
|
||||
pos = 0;
|
||||
hasNext = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
//unexpected behavior.
|
||||
//Try moving on to the next item. This might throw an exception, but
|
||||
//it might also fetch the next items successfully
|
||||
//So it is better not to throw an exception right away
|
||||
return MoveNext();
|
||||
}
|
||||
}
|
||||
|
||||
watch.Stop();
|
||||
return hasNext;
|
||||
}
|
||||
|
||||
public object Current
|
||||
{
|
||||
get
|
||||
{
|
||||
int currPos = pos;
|
||||
pos++;
|
||||
return items[currPos];
|
||||
}
|
||||
}
|
||||
|
||||
public void Reset()
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public IEnumerator GetEnumerator()
|
||||
{
|
||||
return this;
|
||||
}
|
||||
|
||||
public void LogStatistic()
|
||||
{
|
||||
logger.LogInfo(string.Format("total elapsed time: {0}", watch.ElapsedMilliseconds));
|
||||
}
|
||||
|
||||
private object[] GetNext(int messageLength)
|
||||
{
|
||||
object[] result = null;
|
||||
switch ((SerializedMode)Enum.Parse(typeof(SerializedMode), deserializedMode))
|
||||
{
|
||||
case SerializedMode.String:
|
||||
watch.Start();
|
||||
byte[] buffer = messageLength > 0 ? SerDe.ReadBytes(inputStream, messageLength) : null;
|
||||
watch.Stop();
|
||||
switch (mode)
|
||||
{
|
||||
result = new object[1];
|
||||
if (messageLength > 0)
|
||||
{
|
||||
byte[] buffer = SerDe.ReadBytes(inputStream, messageLength);
|
||||
if (buffer == null)
|
||||
case SerializedMode.String:
|
||||
{
|
||||
logger.LogDebug("Buffer is null. Message length is {0}", messageLength);
|
||||
if (messageLength > 0)
|
||||
{
|
||||
if (buffer == null)
|
||||
{
|
||||
logger.LogDebug("Buffer is null. Message length is {0}", messageLength);
|
||||
}
|
||||
yield return SerDe.ToString(buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
yield return null;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.Row:
|
||||
{
|
||||
Debug.Assert(messageLength > 0);
|
||||
var unpickledObjects = PythonSerDe.GetUnpickledObjects(buffer);
|
||||
foreach (var row in unpickledObjects.Select(item => (item as RowConstructor).GetRow()))
|
||||
{
|
||||
yield return row;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.Pair:
|
||||
{
|
||||
byte[] pairKey = buffer;
|
||||
byte[] pairValue = null;
|
||||
|
||||
watch.Start();
|
||||
int valueLength = SerDe.ReadInt(inputStream);
|
||||
if (valueLength > 0)
|
||||
{
|
||||
pairValue = SerDe.ReadBytes(inputStream, valueLength);
|
||||
}
|
||||
else if (valueLength == (int)SpecialLengths.NULL)
|
||||
{
|
||||
pairValue = null;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception(string.Format("unexpected valueLength: {0}", valueLength));
|
||||
}
|
||||
watch.Stop();
|
||||
|
||||
yield return new KeyValuePair<byte[], byte[]>(pairKey, pairValue);
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.None: //just return raw bytes
|
||||
{
|
||||
yield return buffer;
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.Byte:
|
||||
default:
|
||||
{
|
||||
if (buffer != null)
|
||||
{
|
||||
var ms = new MemoryStream(buffer);
|
||||
yield return formatter.Deserialize(ms);
|
||||
}
|
||||
else
|
||||
{
|
||||
yield return null;
|
||||
}
|
||||
break;
|
||||
}
|
||||
result[0] = SerDe.ToString(buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
result[0] = null;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.Row:
|
||||
{
|
||||
Debug.Assert(messageLength > 0);
|
||||
byte[] buffer = SerDe.ReadBytes(inputStream, messageLength);
|
||||
var unpickledObjects = PythonSerDe.GetUnpickledObjects(buffer);
|
||||
var rows = unpickledObjects.Select(item => (item as RowConstructor).GetRow()).ToList();
|
||||
result = rows.Cast<object>().ToArray();
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.Pair:
|
||||
{
|
||||
byte[] pairKey = (messageLength > 0) ? SerDe.ReadBytes(inputStream, messageLength) : null;
|
||||
byte[] pairValue = null;
|
||||
|
||||
int valueLength = SerDe.ReadInt(inputStream);
|
||||
if (valueLength > 0)
|
||||
{
|
||||
pairValue = SerDe.ReadBytes(inputStream, valueLength);
|
||||
}
|
||||
else if (valueLength == (int)SpecialLengths.NULL)
|
||||
{
|
||||
pairValue = null;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception(string.Format("unexpected valueLength: {0}", valueLength));
|
||||
}
|
||||
|
||||
result = new object[1];
|
||||
result[0] = new KeyValuePair<byte[], byte[]>(pairKey, pairValue);
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.None: //just read raw bytes
|
||||
{
|
||||
result = new object[1];
|
||||
if (messageLength > 0)
|
||||
{
|
||||
result[0] = SerDe.ReadBytes(inputStream, messageLength);
|
||||
}
|
||||
else
|
||||
{
|
||||
result[0] = null;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SerializedMode.Byte:
|
||||
default:
|
||||
{
|
||||
result = new object[1];
|
||||
if (messageLength > 0)
|
||||
{
|
||||
byte[] buffer = SerDe.ReadBytes(inputStream, messageLength);
|
||||
var ms = new MemoryStream(buffer);
|
||||
result[0] = formatter.Deserialize(ms);
|
||||
}
|
||||
else
|
||||
{
|
||||
result[0] = null;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
watch.Start();
|
||||
}
|
||||
|
||||
return result;
|
||||
logger.LogInfo(string.Format("total receive time: {0}", watch.ElapsedMilliseconds));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче