Support folding OpBitcast with numeric constants (#4247)

Add constant folding rule for OpBitcast with numeric scalar or vector
constants.
This commit is contained in:
Jaebaek Seo 2021-04-27 14:24:46 -04:00 коммит произвёл GitHub
Родитель 6cdf07d2b3
Коммит 07ec4f83c5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 249 добавлений и 14 удалений

Просмотреть файл

@ -389,6 +389,36 @@ const Constant* ConstantManager::GetConstant(
return cst ? RegisterConstant(std::move(cst)) : nullptr;
}
const Constant* ConstantManager::GetNumericVectorConstantWithWords(
const Vector* type, const std::vector<uint32_t>& literal_words) {
const auto* element_type = type->element_type();
uint32_t words_per_element = 0;
if (const auto* float_type = element_type->AsFloat())
words_per_element = float_type->width() / 32;
else if (const auto* int_type = element_type->AsInteger())
words_per_element = int_type->width() / 32;
if (words_per_element != 1 && words_per_element != 2) return nullptr;
if (words_per_element * type->element_count() !=
static_cast<uint32_t>(literal_words.size())) {
return nullptr;
}
std::vector<uint32_t> element_ids;
for (uint32_t i = 0; i < type->element_count(); ++i) {
auto first_word = literal_words.begin() + (words_per_element * i);
std::vector<uint32_t> const_data(first_word,
first_word + words_per_element);
const analysis::Constant* element_constant =
GetConstant(element_type, const_data);
auto element_id = GetDefiningInstruction(element_constant)->result_id();
element_ids.push_back(element_id);
}
return GetConstant(type, element_ids);
}
uint32_t ConstantManager::GetFloatConst(float val) {
Type* float_type = context()->get_type_mgr()->GetFloatType();
utils::FloatProxy<float> v(val);

Просмотреть файл

@ -506,10 +506,11 @@ class ConstantManager {
IRContext* context() const { return ctx_; }
// Gets or creates a unique Constant instance of type |type| and a vector of
// constant defining words |words|. If a Constant instance existed already in
// the constant pool, it returns a pointer to it. Otherwise, it creates one
// using CreateConstant. If a new Constant instance cannot be created, it
// returns nullptr.
// constant defining words or ids for elements of Vector type
// |literal_words_or_ids|. If a Constant instance existed already in the
// constant pool, it returns a pointer to it. Otherwise, it creates one using
// CreateConstant. If a new Constant instance cannot be created, it returns
// nullptr.
const Constant* GetConstant(
const Type* type, const std::vector<uint32_t>& literal_words_or_ids);
@ -519,6 +520,14 @@ class ConstantManager {
literal_words_or_ids.end()));
}
// Gets or creates a unique Constant instance of Vector type |type| with
// numeric elements and a vector of constant defining words |literal_words|.
// If a Constant instance existed already in the constant pool, it returns a
// pointer to it. Otherwise, it creates one using CreateConstant. If a new
// Constant instance cannot be created, it returns nullptr.
const Constant* GetNumericVectorConstantWithWords(
const Vector* type, const std::vector<uint32_t>& literal_words);
// Gets or creates a Constant instance to hold the constant value of the given
// instruction. It returns a pointer to a Constant instance or nullptr if it
// could not create the constant.

Просмотреть файл

@ -124,6 +124,66 @@ Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
inst->GetSingleWordInOperand(in_op));
}
std::vector<uint32_t> ExtractInts(uint64_t val) {
std::vector<uint32_t> words;
words.push_back(static_cast<uint32_t>(val));
words.push_back(static_cast<uint32_t>(val >> 32));
return words;
}
std::vector<uint32_t> GetWordsFromScalarIntConstant(
const analysis::IntConstant* c) {
assert(c != nullptr);
uint32_t width = c->type()->AsInteger()->width();
assert(width == 32 || width == 64);
if (width == 64) {
uint64_t uval = static_cast<uint64_t>(c->GetU64());
return ExtractInts(uval);
}
return {c->GetU32()};
}
std::vector<uint32_t> GetWordsFromScalarFloatConstant(
const analysis::FloatConstant* c) {
assert(c != nullptr);
uint32_t width = c->type()->AsFloat()->width();
assert(width == 32 || width == 64);
if (width == 64) {
utils::FloatProxy<double> result(c->GetDouble());
return result.GetWords();
}
utils::FloatProxy<float> result(c->GetFloat());
return result.GetWords();
}
std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
if (const auto* float_constant = c->AsFloatConstant()) {
return GetWordsFromScalarFloatConstant(float_constant);
} else if (const auto* int_constant = c->AsIntConstant()) {
return GetWordsFromScalarIntConstant(int_constant);
} else if (const auto* vec_constant = c->AsVectorConstant()) {
std::vector<uint32_t> words;
for (const auto* comp : vec_constant->GetComponents()) {
auto comp_in_words =
GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
}
return words;
}
return {};
}
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
const analysis::Type* type) {
if (type->AsInteger() || type->AsFloat())
return const_mgr->GetConstant(type, words);
if (const auto* vec_type = type->AsVector())
return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
return nullptr;
}
// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
// constant.
uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
@ -146,13 +206,6 @@ uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
return const_mgr->GetDefiningInstruction(negated_const)->result_id();
}
std::vector<uint32_t> ExtractInts(uint64_t val) {
std::vector<uint32_t> words;
words.push_back(static_cast<uint32_t>(val));
words.push_back(static_cast<uint32_t>(val >> 32));
return words;
}
// Negates the integer constant |c|. Returns the id of the defining instruction.
uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
const analysis::Constant* c) {
@ -1796,6 +1849,33 @@ FoldingRule RedundantPhi() {
};
}
FoldingRule BitCastScalarOrVector() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
if (constants[0] == nullptr) return false;
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
std::vector<uint32_t> words =
GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
if (words.size() == 0) return false;
const analysis::Constant* bitcasted_constant =
ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
auto new_feeder_id =
const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
->result_id();
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
return true;
};
}
FoldingRule RedundantSelect() {
// An OpSelect instruction where both values are the same or the condition is
// constant can be replaced by one of the values
@ -2423,6 +2503,8 @@ void FoldingRules::AddFoldingRules() {
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());

Просмотреть файл

@ -141,6 +141,7 @@ OpName %main "main"
%v4int = OpTypeVector %int 4
%v4float = OpTypeVector %float 4
%v4double = OpTypeVector %double 4
%v2uint = OpTypeVector %uint 2
%v2float = OpTypeVector %float 2
%v2double = OpTypeVector %double 2
%v2half = OpTypeVector %half 2
@ -191,6 +192,7 @@ OpName %main "main"
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
%v2int_min_max = OpConstantComposite %v2int %int_min %int_max
%v2bool_null = OpConstantNull %v2bool
%v2bool_true_false = OpConstantComposite %v2bool %true %false
%v2bool_false_true = OpConstantComposite %v2bool %false %true
@ -258,6 +260,15 @@ OpName %main "main"
%v4double_1_1_1_0p5 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_0p5
%v4double_null = OpConstantNull %v4double
%v4float_n1_2_1_3 = OpConstantComposite %v4float %float_n1 %float_2 %float_1 %float_3
%uint_0x3f800000 = OpConstant %uint 0x3f800000
%uint_0xbf800000 = OpConstant %uint 0xbf800000
%v2uint_0x3f800000_0xbf800000 = OpConstantComposite %v2uint %uint_0x3f800000 %uint_0xbf800000
%long_0xbf8000003f800000 = OpConstant %long 0xbf8000003f800000
%int_0x3FF00000 = OpConstant %int 0x3FF00000
%int_0x00000000 = OpConstant %int 0x00000000
%int_0xC05FD666 = OpConstant %int 0xC05FD666
%int_0x66666666 = OpConstant %int 0x66666666
%v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666 = OpConstantComposite %v4int %int_0x00000000 %int_0x3FF00000 %int_0x66666666 %int_0xC05FD666
)";
return header;
@ -708,7 +719,31 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
"%2 = OpExtInst %uint %1 UClamp %uint_2 %undef %uint_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1)
2, 1),
// Test case 46: Bit-cast int 0 to unsigned int
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %uint %int_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 47: Bit-cast int -24 to unsigned int
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %uint %int_n24\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, static_cast<uint32_t>(-24)),
// Test case 48: Bit-cast float 1.0f to unsigned int
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %uint %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, static_cast<uint32_t>(0x3f800000))
));
// clang-format on
@ -790,10 +825,72 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntVectorInstructionFoldingTest,
"%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 4294967295 \n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {0,0})
2, {0,0}),
// Test case 4: fold bit-cast int -24 to unsigned int
InstructionFoldingCase<std::vector<uint32_t>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %int %n\n" +
"%2 = OpBitcast %v2uint %v2int_min_max\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {2147483648, 2147483647})
));
// clang-format on
using DoubleVectorInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<std::vector<double>>>;
TEST_P(DoubleVectorInstructionFoldingTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_TRUE(succeeded);
if (succeeded && inst != nullptr) {
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
EXPECT_THAT(opcodes, Contains(inst->opcode()));
analysis::ConstantManager* const_mrg = context->get_constant_mgr();
const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
EXPECT_NE(result, nullptr);
if (result != nullptr) {
const std::vector<const analysis::Constant*>& componenets =
result->AsVectorConstant()->GetComponents();
EXPECT_EQ(componenets.size(), tc.expected_result.size());
for (size_t i = 0; i < componenets.size(); i++) {
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetDouble());
}
}
}
}
// clang-format off
INSTANTIATE_TEST_SUITE_P(TestCase, DoubleVectorInstructionFoldingTest,
::testing::Values(
// Test case 0: bit-cast int {0x3FF00000,0x00000000,0xC05FD666,0x66666666}
// to double vector
InstructionFoldingCase<std::vector<double>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %v2double %v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {1.0,-127.35})
));
using FloatVectorInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
@ -843,7 +940,24 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {1.6f,1.5f})
2, {1.6f,1.5f}),
// Test case 1: bit-cast unsigned int vector {0x3f800000, 0xbf800000} to
// float vector
InstructionFoldingCase<std::vector<float>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %v2float %v2uint_0x3f800000_0xbf800000\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {1.0f,-1.0f}),
// Test case 2: bit-cast long int 0xbf8000003f800000 to float vector
InstructionFoldingCase<std::vector<float>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %v2float %long_0xbf8000003f800000\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {1.0f,-1.0f})
));
// clang-format on
using BooleanInstructionFoldingTest =