Support 16-bit float in assembler and disassembler

This adds half-precision constants to spirv-tools.
16-bit floats are always disassembled into hex-float format,
but can be assembled from floating point or hex-float inputs.
This commit is contained in:
Andrew Woloszyn 2016-01-08 09:54:42 -05:00 коммит произвёл David Neto
Родитель b6fe02fc39
Коммит 43401d2ed0
10 изменённых файлов: 124 добавлений и 24 удалений

Просмотреть файл

@ -193,7 +193,6 @@ It supports the standard `googletest` command line options.
### Assembler and disassembler
* Support 16-bit floating point literals.
* The disassembler could emit helpful annotations in comments. For example:
* Use variable name information from debug instructions to annotate
key operations on variables.

Просмотреть файл

@ -698,6 +698,10 @@ std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
return os;
}
// Parses a floating point number from the given stream and stores it into the
// value parameter.
// If the negate_value parameter is true then the number is negated before
// it is stored into the value parameter.
template <typename T, typename Traits>
inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
HexFloat<T, Traits>& value) {
@ -710,6 +714,26 @@ inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
return is;
}
// Specialization of ParseNormalFloat for FloatProxy<Float16> values.
// This will parse the float as it were a 32-bit floating point number,
// and then round it down to fit into a Float16 value.
// The number is rounded towards zero.
// Any floating point number that is too large will be rounded to +- infinity.
template <>
inline std::istream&
ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
std::istream& is, bool negate_value,
HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
float f;
is >> f;
if (negate_value) {
f = -f;
}
HexFloat<FloatProxy<float>> float_val(f);
float_val.castTo(value, round_direction::kToZero);
return is;
}
// Reads a HexFloat from the given stream.
// If the float is not encoded as a hex-float then it will be parsed
// as a regular float.
@ -940,6 +964,12 @@ std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
}
return os;
}
template<>
inline std::ostream& operator<< <Float16>(std::ostream& os, const FloatProxy<Float16>& value) {
os << HexFloat<FloatProxy<Float16>>(value);
return os;
}
}
#endif // _LIBSPIRV_UTIL_HEX_FLOAT_H_

Просмотреть файл

@ -233,9 +233,12 @@ void Disassembler::EmitOperand(const spv_parsed_instruction_t& inst,
stream_ << word;
break;
case SPV_NUMBER_FLOATING:
// Assume only 32-bit floats.
// TODO(dneto): Handle 16-bit floats also.
stream_ << spvutils::FloatProxy<float>(word);
if (operand.number_bit_width == 16) {
stream_ << spvutils::FloatProxy<spvutils::Float16>(uint16_t(word & 0xFFFF));
} else {
// Assume 32-bit floats.
stream_ << spvutils::FloatProxy<float>(word);
}
break;
default:
assert(false && "Unreachable");

Просмотреть файл

@ -387,9 +387,18 @@ spv_result_t AssemblyContext::binaryEncodeFloatingPointLiteral(
spv_instruction_t* pInst) {
const auto bit_width = assumedBitWidth(type);
switch (bit_width) {
case 16:
return diagnostic(SPV_ERROR_INTERNAL)
<< "Unsupported yet: 16-bit float constants.";
case 16: {
spvutils::HexFloat<FloatProxy<spvutils::Float16>> hVal(0);
if (auto error = parseNumber(val, error_code, &hVal,
"Invalid 16-bit float literal: "))
return error;
// getAsFloat will return the spvutils::Float16 value, and get_value
// will return a uint16_t representing the bits of the float.
// The encoding is therefore correct from the perspective of the SPIR-V
// spec since the top 16 bits will be 0.
return binaryEncodeU32(
static_cast<uint32_t>(hVal.value().getAsFloat().get_value()), pInst);
} break;
case 32: {
spvutils::HexFloat<FloatProxy<float>> fVal(0.0f);
if (auto error = parseNumber(val, error_code, &fVal,

Просмотреть файл

@ -183,7 +183,9 @@ class TextToBinaryTestBase : public T {
};
using TextToBinaryTest = TextToBinaryTestBase<::testing::Test>;
} // namespace spvtest
using RoundTripTest =
spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
#endif // LIBSPIRV_TEST_TEST_FIXTURE_H_

Просмотреть файл

@ -341,18 +341,10 @@ INSTANTIATE_TEST_CASE_P(
}));
// clang-format on
using RoundTripTest =
spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
const int64_t kMaxUnsigned48Bit = (int64_t(1) << 48) - 1;
const int64_t kMaxSigned48Bit = (int64_t(1) << 47) - 1;
const int64_t kMinSigned48Bit = -kMaxSigned48Bit - 1;
TEST_P(RoundTripTest, Sample) {
EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
<< GetParam();
}
INSTANTIATE_TEST_CASE_P(
OpConstantRoundTrip, RoundTripTest,
::testing::ValuesIn(std::vector<std::string>{
@ -396,6 +388,41 @@ INSTANTIATE_TEST_CASE_P(
"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -1.79769e+308\n",
}));
INSTANTIATE_TEST_CASE_P(
OpConstantHalfRoundTrip, RoundTripTest,
::testing::ValuesIn(std::vector<std::string>{
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x0p+0\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x0p+0\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+0\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.1p+0\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p-1\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.8p+1\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+1\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+0\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.1p+0\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p-1\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.8p+1\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+1\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-16\n", // some denorms
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-24\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p-24\n",
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+16\n", // +inf
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+16\n", // -inf
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -inf
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p+16\n", // nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.11p+16\n", // nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffp+16\n", // nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+16\n", // nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.004p+16\n", // nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.11p+16\n", // -nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffp+16\n", // -nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+16\n", // -nan
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.004p+16\n", // -nan
}));
// clang-format off
// (Clang-format really wants to break up these strings across lines.
INSTANTIATE_TEST_CASE_P(

Просмотреть файл

@ -272,13 +272,6 @@ INSTANTIATE_TEST_CASE_P(
MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}),
})));
using RoundTripTest =
spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
TEST_P(RoundTripTest, Sample) {
EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()));
}
INSTANTIATE_TEST_CASE_P(
OpSwitchRoundTripUnsignedIntegers, RoundTripTest,
::testing::ValuesIn(std::vector<std::string>({

Просмотреть файл

@ -245,6 +245,38 @@ INSTANTIATE_TEST_CASE_P(
{"!0xff800001", 0xff800001}, // NaN
}));
using TextToBinaryHalfValueTest = spvtest::TextToBinaryTestBase<
::testing::TestWithParam<std::pair<std::string, uint32_t>>>;
TEST_P(TextToBinaryHalfValueTest, Samples) {
const std::string input =
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 " + GetParam().first;
EXPECT_THAT(CompiledInstructions(input),
Eq(Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 16}),
MakeInstruction(SpvOpConstant,
{1, 2, GetParam().second})})));
}
INSTANTIATE_TEST_CASE_P(
HalfValues, TextToBinaryHalfValueTest,
::testing::ValuesIn(std::vector<std::pair<std::string, uint32_t>>{
{"0.0", 0x00000000},
{"1.0", 0x00003c00},
{"1.000844", 0x00003c00}, // Truncate to 1.0
{"1.000977", 0x00003c01}, // Don't have to truncate
{"1.001465", 0x00003c01}, // Truncate to 1.0000977
{"1.5", 0x00003e00},
{"-1.0", 0x0000bc00},
{"2.0", 0x00004000},
{"-2.0", 0x0000c000},
{"0x1p1", 0x00004000},
{"-0x1p1", 0x0000c000},
{"0x1.8p1", 0x00004200},
{"0x1.8p4", 0x00004e00},
{"0x1.801p4", 0x00004e00},
{"0x1.804p4", 0x00004e01},
}));
TEST(AssemblyContextParseNarrowSignedIntegers, Sample) {
AssemblyContext context(AutoText(""), nullptr);
const spv_result_t ec = SPV_FAILED_MATCH;

Просмотреть файл

@ -27,6 +27,7 @@
#include "UnitSPIRV.h"
#include "gmock/gmock.h"
#include "TestFixture.h"
namespace {
@ -56,4 +57,9 @@ TEST(WordVectorPrintTo, PreservesFlagsAndFill) {
EXPECT_THAT(s.str(), Eq("xx10 0x0000000a 0x00000010 xx11"));
}
TEST_P(RoundTripTest, Sample) {
EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
<< GetParam();
}
} // anonymous namespace

Просмотреть файл

@ -215,5 +215,4 @@ inline std::string MakeLongUTF8String(size_t num_4_byte_chars) {
}
} // namespace spvtest
#endif // LIBSPIRV_TEST_UNITSPIRV_H_