Merge pull request #946 from AArnott/fix924

Fix MessagePackStreamReader reading when string or binary headers are incomplete
This commit is contained in:
Andrew Arnott 2020-06-12 07:03:42 -06:00 коммит произвёл GitHub
Родитель bac7020bcb 7c07da2f4b
Коммит c2ee6bca14
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 92 добавлений и 31 удалений

Просмотреть файл

@ -149,6 +149,7 @@ namespace MessagePack
/// <remarks> /// <remarks>
/// The entire primitive is skipped, including content of maps or arrays, or any other type with payloads. /// The entire primitive is skipped, including content of maps or arrays, or any other type with payloads.
/// To get the raw MessagePack sequence that was skipped, use <see cref="ReadRaw()"/> instead. /// To get the raw MessagePack sequence that was skipped, use <see cref="ReadRaw()"/> instead.
/// WARNING: when false is returned, the position of the reader is undefined.
/// </remarks> /// </remarks>
internal bool TrySkip() internal bool TrySkip()
{ {
@ -187,11 +188,11 @@ namespace MessagePack
case MessagePackCode.Str8: case MessagePackCode.Str8:
case MessagePackCode.Str16: case MessagePackCode.Str16:
case MessagePackCode.Str32: case MessagePackCode.Str32:
return this.reader.TryAdvance(this.GetStringLengthInBytes()); return this.TryGetStringLengthInBytes(out int length) && this.reader.TryAdvance(length);
case MessagePackCode.Bin8: case MessagePackCode.Bin8:
case MessagePackCode.Bin16: case MessagePackCode.Bin16:
case MessagePackCode.Bin32: case MessagePackCode.Bin32:
return this.reader.TryAdvance(this.GetBytesLength()); return this.TryGetBytesLength(out length) && this.reader.TryAdvance(length);
case MessagePackCode.FixExt1: case MessagePackCode.FixExt1:
case MessagePackCode.FixExt2: case MessagePackCode.FixExt2:
case MessagePackCode.FixExt4: case MessagePackCode.FixExt4:
@ -220,7 +221,7 @@ namespace MessagePack
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr) if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{ {
return this.reader.TryAdvance(this.GetStringLengthInBytes()); return this.TryGetStringLengthInBytes(out length) && this.reader.TryAdvance(length);
} }
// We don't actually expect to ever hit this point, since every code is supported. // We don't actually expect to ever hit this point, since every code is supported.
@ -956,77 +957,135 @@ namespace MessagePack
private int GetBytesLength() private int GetBytesLength()
{ {
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte code)); ThrowInsufficientBufferUnless(this.TryGetBytesLength(out int length));
return length;
}
private bool TryGetBytesLength(out int length)
{
if (!this.reader.TryRead(out byte code))
{
length = 0;
return false;
}
// In OldSpec mode, Bin didn't exist, so Str was used. Str8 didn't exist either. // In OldSpec mode, Bin didn't exist, so Str was used. Str8 didn't exist either.
int length;
switch (code) switch (code)
{ {
case MessagePackCode.Bin8: case MessagePackCode.Bin8:
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte byteLength)); if (this.reader.TryRead(out byte byteLength))
length = byteLength; {
length = byteLength;
return true;
}
break; break;
case MessagePackCode.Bin16: case MessagePackCode.Bin16:
case MessagePackCode.Str16: // OldSpec compatibility case MessagePackCode.Str16: // OldSpec compatibility
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out short shortLength)); if (this.reader.TryReadBigEndian(out short shortLength))
length = unchecked((ushort)shortLength); {
length = unchecked((ushort)shortLength);
return true;
}
break; break;
case MessagePackCode.Bin32: case MessagePackCode.Bin32:
case MessagePackCode.Str32: // OldSpec compatibility case MessagePackCode.Str32: // OldSpec compatibility
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out length)); if (this.reader.TryReadBigEndian(out length))
{
return true;
}
break; break;
default: default:
// OldSpec compatibility // OldSpec compatibility
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr) if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{ {
length = code & 0x1F; length = code & 0x1F;
break; return true;
} }
throw ThrowInvalidCode(code); throw ThrowInvalidCode(code);
} }
return length; length = 0;
return false;
}
/// <summary>
/// Gets the length of the next string.
/// </summary>
/// <param name="length">Receives the length of the next string, if there were enough bytes to read it.</param>
/// <returns><c>true</c> if there were enough bytes to read the length of the next string; <c>false</c> otherwise.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool TryGetStringLengthInBytes(out int length)
{
if (!this.reader.TryRead(out byte code))
{
length = 0;
return false;
}
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{
length = code & 0x1F;
return true;
}
return this.TryGetStringLengthInBytesSlow(code, out length);
} }
/// <summary> /// <summary>
/// Gets the length of the next string. /// Gets the length of the next string.
/// </summary> /// </summary>
/// <returns>The length of the next string.</returns> /// <returns>The length of the next string.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int GetStringLengthInBytes() private int GetStringLengthInBytes()
{ {
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte code)); ThrowInsufficientBufferUnless(this.TryGetStringLengthInBytes(out int length));
return length;
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{
return code & 0x1F;
}
return this.GetStringLengthInBytesSlow(code);
} }
private int GetStringLengthInBytesSlow(byte code) [MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool TryGetStringLengthInBytesSlow(byte code, out int length)
{ {
switch (code) switch (code)
{ {
case MessagePackCode.Str8: case MessagePackCode.Str8:
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte byteValue)); if (this.reader.TryRead(out byte byteValue))
return byteValue; {
length = byteValue;
return true;
}
break;
case MessagePackCode.Str16: case MessagePackCode.Str16:
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out short shortValue)); if (this.reader.TryReadBigEndian(out short shortValue))
return unchecked((ushort)shortValue); {
length = unchecked((ushort)shortValue);
return true;
}
break;
case MessagePackCode.Str32: case MessagePackCode.Str32:
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out int intValue)); if (this.reader.TryReadBigEndian(out int intValue))
return intValue; {
length = intValue;
return true;
}
break;
default: default:
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr) if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{ {
return code & 0x1F; length = code & 0x1F;
return true;
} }
throw ThrowInvalidCode(code); throw ThrowInvalidCode(code);
} }
length = 0;
return false;
} }
/// <summary> /// <summary>

Просмотреть файл

@ -31,9 +31,11 @@ namespace MessagePack.Tests
positions.Add(sequence.AsReadOnlySequence.End); positions.Add(sequence.AsReadOnlySequence.End);
// Second message is more interesting. // Second message is more interesting.
writer.WriteArrayHeader(2); writer.WriteArrayHeader(4);
writer.Write("Hi"); writer.Write("Hi");
writer.Write("There"); writer.Write("There + " + new string('3', 300)); // a long enough string that a multi-byte header is required.
writer.Write(new byte[300]); // a long enough byte array that a multi-byte header is required.
writer.WriteExtensionFormat(new ExtensionResult(1, new byte[300]));
writer.Flush(); writer.Flush();
positions.Add(sequence.AsReadOnlySequence.End); positions.Add(sequence.AsReadOnlySequence.End);