Support 16-bit float in assembler and disassembler
This adds half-precision constants to spirv-tools. 16-bit floats are always disassembled into hex-float format, but can be assembled from floating point or hex-float inputs.
This commit is contained in:
Родитель
b6fe02fc39
Коммит
43401d2ed0
|
@ -193,7 +193,6 @@ It supports the standard `googletest` command line options.
|
|||
|
||||
### Assembler and disassembler
|
||||
|
||||
* Support 16-bit floating point literals.
|
||||
* The disassembler could emit helpful annotations in comments. For example:
|
||||
* Use variable name information from debug instructions to annotate
|
||||
key operations on variables.
|
||||
|
|
|
@ -698,6 +698,10 @@ std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
|
|||
return os;
|
||||
}
|
||||
|
||||
// Parses a floating point number from the given stream and stores it into the
|
||||
// value parameter.
|
||||
// If the negate_value parameter is true then the number is negated before
|
||||
// it is stored into the value parameter.
|
||||
template <typename T, typename Traits>
|
||||
inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
|
||||
HexFloat<T, Traits>& value) {
|
||||
|
@ -710,6 +714,26 @@ inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
|
|||
return is;
|
||||
}
|
||||
|
||||
// Specialization of ParseNormalFloat for FloatProxy<Float16> values.
|
||||
// This will parse the float as if it were a 32-bit floating point number,
|
||||
// and then narrow it to fit into a Float16 value.
|
||||
// The number is rounded towards zero.
|
||||
// Any floating point number that is too large will be rounded to +- infinity.
|
||||
template <>
|
||||
inline std::istream&
|
||||
ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
|
||||
std::istream& is, bool negate_value,
|
||||
HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
|
||||
float f;
|
||||
is >> f;
|
||||
if (negate_value) {
|
||||
f = -f;
|
||||
}
|
||||
HexFloat<FloatProxy<float>> float_val(f);
|
||||
float_val.castTo(value, round_direction::kToZero);
|
||||
return is;
|
||||
}
|
||||
|
||||
// Reads a HexFloat from the given stream.
|
||||
// If the float is not encoded as a hex-float then it will be parsed
|
||||
// as a regular float.
|
||||
|
@ -940,6 +964,12 @@ std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
|
|||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline std::ostream& operator<< <Float16>(std::ostream& os, const FloatProxy<Float16>& value) {
|
||||
os << HexFloat<FloatProxy<Float16>>(value);
|
||||
return os;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _LIBSPIRV_UTIL_HEX_FLOAT_H_
|
||||
|
|
|
@ -233,9 +233,12 @@ void Disassembler::EmitOperand(const spv_parsed_instruction_t& inst,
|
|||
stream_ << word;
|
||||
break;
|
||||
case SPV_NUMBER_FLOATING:
|
||||
// Assume only 32-bit floats.
|
||||
// TODO(dneto): Handle 16-bit floats also.
|
||||
stream_ << spvutils::FloatProxy<float>(word);
|
||||
if (operand.number_bit_width == 16) {
|
||||
stream_ << spvutils::FloatProxy<spvutils::Float16>(uint16_t(word & 0xFFFF));
|
||||
} else {
|
||||
// Assume 32-bit floats.
|
||||
stream_ << spvutils::FloatProxy<float>(word);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
assert(false && "Unreachable");
|
||||
|
|
|
@ -387,9 +387,18 @@ spv_result_t AssemblyContext::binaryEncodeFloatingPointLiteral(
|
|||
spv_instruction_t* pInst) {
|
||||
const auto bit_width = assumedBitWidth(type);
|
||||
switch (bit_width) {
|
||||
case 16:
|
||||
return diagnostic(SPV_ERROR_INTERNAL)
|
||||
<< "Unsupported yet: 16-bit float constants.";
|
||||
case 16: {
|
||||
spvutils::HexFloat<FloatProxy<spvutils::Float16>> hVal(0);
|
||||
if (auto error = parseNumber(val, error_code, &hVal,
|
||||
"Invalid 16-bit float literal: "))
|
||||
return error;
|
||||
// getAsFloat will return the spvutils::Float16 value, and get_value
|
||||
// will return a uint16_t representing the bits of the float.
|
||||
// The encoding is therefore correct from the perspective of the SPIR-V
|
||||
// spec since the top 16 bits will be 0.
|
||||
return binaryEncodeU32(
|
||||
static_cast<uint32_t>(hVal.value().getAsFloat().get_value()), pInst);
|
||||
} break;
|
||||
case 32: {
|
||||
spvutils::HexFloat<FloatProxy<float>> fVal(0.0f);
|
||||
if (auto error = parseNumber(val, error_code, &fVal,
|
||||
|
|
|
@ -183,7 +183,9 @@ class TextToBinaryTestBase : public T {
|
|||
};
|
||||
|
||||
using TextToBinaryTest = TextToBinaryTestBase<::testing::Test>;
|
||||
|
||||
} // namespace spvtest
|
||||
|
||||
using RoundTripTest =
|
||||
spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
|
||||
|
||||
#endif // LIBSPIRV_TEST_TEST_FIXTURE_H_
|
||||
|
|
|
@ -341,18 +341,10 @@ INSTANTIATE_TEST_CASE_P(
|
|||
}));
|
||||
// clang-format on
|
||||
|
||||
using RoundTripTest =
|
||||
spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
|
||||
|
||||
const int64_t kMaxUnsigned48Bit = (int64_t(1) << 48) - 1;
|
||||
const int64_t kMaxSigned48Bit = (int64_t(1) << 47) - 1;
|
||||
const int64_t kMinSigned48Bit = -kMaxSigned48Bit - 1;
|
||||
|
||||
TEST_P(RoundTripTest, Sample) {
|
||||
EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
|
||||
<< GetParam();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
OpConstantRoundTrip, RoundTripTest,
|
||||
::testing::ValuesIn(std::vector<std::string>{
|
||||
|
@ -396,6 +388,41 @@ INSTANTIATE_TEST_CASE_P(
|
|||
"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -1.79769e+308\n",
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
OpConstantHalfRoundTrip, RoundTripTest,
|
||||
::testing::ValuesIn(std::vector<std::string>{
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x0p+0\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x0p+0\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+0\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.1p+0\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p-1\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.8p+1\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+1\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+0\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.1p+0\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p-1\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.8p+1\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+1\n",
|
||||
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-16\n", // some denorms
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-24\n",
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p-24\n",
|
||||
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+16\n", // +inf
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+16\n", // -inf
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p+16\n", // nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.11p+16\n", // nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffp+16\n", // nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+16\n", // nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.004p+16\n", // nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.11p+16\n", // -nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffp+16\n", // -nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+16\n", // -nan
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.004p+16\n", // -nan
|
||||
}));
|
||||
|
||||
// clang-format off
|
||||
// (Clang-format really wants to break up these strings across lines.)
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
|
|
|
@ -272,13 +272,6 @@ INSTANTIATE_TEST_CASE_P(
|
|||
MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}),
|
||||
})));
|
||||
|
||||
using RoundTripTest =
|
||||
spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
|
||||
|
||||
TEST_P(RoundTripTest, Sample) {
|
||||
EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
OpSwitchRoundTripUnsignedIntegers, RoundTripTest,
|
||||
::testing::ValuesIn(std::vector<std::string>({
|
||||
|
|
|
@ -245,6 +245,38 @@ INSTANTIATE_TEST_CASE_P(
|
|||
{"!0xff800001", 0xff800001}, // NaN
|
||||
}));
|
||||
|
||||
using TextToBinaryHalfValueTest = spvtest::TextToBinaryTestBase<
|
||||
::testing::TestWithParam<std::pair<std::string, uint32_t>>>;
|
||||
|
||||
TEST_P(TextToBinaryHalfValueTest, Samples) {
|
||||
const std::string input =
|
||||
"%1 = OpTypeFloat 16\n%2 = OpConstant %1 " + GetParam().first;
|
||||
EXPECT_THAT(CompiledInstructions(input),
|
||||
Eq(Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 16}),
|
||||
MakeInstruction(SpvOpConstant,
|
||||
{1, 2, GetParam().second})})));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
HalfValues, TextToBinaryHalfValueTest,
|
||||
::testing::ValuesIn(std::vector<std::pair<std::string, uint32_t>>{
|
||||
{"0.0", 0x00000000},
|
||||
{"1.0", 0x00003c00},
|
||||
{"1.000844", 0x00003c00}, // Truncate to 1.0
|
||||
{"1.000977", 0x00003c01}, // Don't have to truncate
|
||||
{"1.001465", 0x00003c01}, // Truncate to 1.000977
|
||||
{"1.5", 0x00003e00},
|
||||
{"-1.0", 0x0000bc00},
|
||||
{"2.0", 0x00004000},
|
||||
{"-2.0", 0x0000c000},
|
||||
{"0x1p1", 0x00004000},
|
||||
{"-0x1p1", 0x0000c000},
|
||||
{"0x1.8p1", 0x00004200},
|
||||
{"0x1.8p4", 0x00004e00},
|
||||
{"0x1.801p4", 0x00004e00},
|
||||
{"0x1.804p4", 0x00004e01},
|
||||
}));
|
||||
|
||||
TEST(AssemblyContextParseNarrowSignedIntegers, Sample) {
|
||||
AssemblyContext context(AutoText(""), nullptr);
|
||||
const spv_result_t ec = SPV_FAILED_MATCH;
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "UnitSPIRV.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "TestFixture.h"
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -56,4 +57,9 @@ TEST(WordVectorPrintTo, PreservesFlagsAndFill) {
|
|||
EXPECT_THAT(s.str(), Eq("xx10 0x0000000a 0x00000010 xx11"));
|
||||
}
|
||||
|
||||
TEST_P(RoundTripTest, Sample) {
|
||||
EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
|
||||
<< GetParam();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
|
|
@ -215,5 +215,4 @@ inline std::string MakeLongUTF8String(size_t num_4_byte_chars) {
|
|||
}
|
||||
|
||||
} // namespace spvtest
|
||||
|
||||
#endif // LIBSPIRV_TEST_UNITSPIRV_H_
|
||||
|
|
Загрузка…
Ссылка в новой задаче