Support folding OpBitcast with numeric constants (#4247)
Add constant folding rule for OpBitcast with numeric scalar or vector constants.
This commit is contained in:
Родитель
6cdf07d2b3
Коммит
07ec4f83c5
|
@ -389,6 +389,36 @@ const Constant* ConstantManager::GetConstant(
|
|||
return cst ? RegisterConstant(std::move(cst)) : nullptr;
|
||||
}
|
||||
|
||||
const Constant* ConstantManager::GetNumericVectorConstantWithWords(
|
||||
const Vector* type, const std::vector<uint32_t>& literal_words) {
|
||||
const auto* element_type = type->element_type();
|
||||
uint32_t words_per_element = 0;
|
||||
if (const auto* float_type = element_type->AsFloat())
|
||||
words_per_element = float_type->width() / 32;
|
||||
else if (const auto* int_type = element_type->AsInteger())
|
||||
words_per_element = int_type->width() / 32;
|
||||
|
||||
if (words_per_element != 1 && words_per_element != 2) return nullptr;
|
||||
|
||||
if (words_per_element * type->element_count() !=
|
||||
static_cast<uint32_t>(literal_words.size())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> element_ids;
|
||||
for (uint32_t i = 0; i < type->element_count(); ++i) {
|
||||
auto first_word = literal_words.begin() + (words_per_element * i);
|
||||
std::vector<uint32_t> const_data(first_word,
|
||||
first_word + words_per_element);
|
||||
const analysis::Constant* element_constant =
|
||||
GetConstant(element_type, const_data);
|
||||
auto element_id = GetDefiningInstruction(element_constant)->result_id();
|
||||
element_ids.push_back(element_id);
|
||||
}
|
||||
|
||||
return GetConstant(type, element_ids);
|
||||
}
|
||||
|
||||
uint32_t ConstantManager::GetFloatConst(float val) {
|
||||
Type* float_type = context()->get_type_mgr()->GetFloatType();
|
||||
utils::FloatProxy<float> v(val);
|
||||
|
|
|
@ -506,10 +506,11 @@ class ConstantManager {
|
|||
IRContext* context() const { return ctx_; }
|
||||
|
||||
// Gets or creates a unique Constant instance of type |type| and a vector of
|
||||
// constant defining words |words|. If a Constant instance existed already in
|
||||
// the constant pool, it returns a pointer to it. Otherwise, it creates one
|
||||
// using CreateConstant. If a new Constant instance cannot be created, it
|
||||
// returns nullptr.
|
||||
// constant defining words or ids for elements of Vector type
|
||||
// |literal_words_or_ids|. If a Constant instance existed already in the
|
||||
// constant pool, it returns a pointer to it. Otherwise, it creates one using
|
||||
// CreateConstant. If a new Constant instance cannot be created, it returns
|
||||
// nullptr.
|
||||
const Constant* GetConstant(
|
||||
const Type* type, const std::vector<uint32_t>& literal_words_or_ids);
|
||||
|
||||
|
@ -519,6 +520,14 @@ class ConstantManager {
|
|||
literal_words_or_ids.end()));
|
||||
}
|
||||
|
||||
// Gets or creates a unique Constant instance of Vector type |type| with
|
||||
// numeric elements and a vector of constant defining words |literal_words|.
|
||||
// If a Constant instance existed already in the constant pool, it returns a
|
||||
// pointer to it. Otherwise, it creates one using CreateConstant. If a new
|
||||
// Constant instance cannot be created, it returns nullptr.
|
||||
const Constant* GetNumericVectorConstantWithWords(
|
||||
const Vector* type, const std::vector<uint32_t>& literal_words);
|
||||
|
||||
// Gets or creates a Constant instance to hold the constant value of the given
|
||||
// instruction. It returns a pointer to a Constant instance or nullptr if it
|
||||
// could not create the constant.
|
||||
|
|
|
@ -124,6 +124,66 @@ Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
|
|||
inst->GetSingleWordInOperand(in_op));
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ExtractInts(uint64_t val) {
|
||||
std::vector<uint32_t> words;
|
||||
words.push_back(static_cast<uint32_t>(val));
|
||||
words.push_back(static_cast<uint32_t>(val >> 32));
|
||||
return words;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> GetWordsFromScalarIntConstant(
|
||||
const analysis::IntConstant* c) {
|
||||
assert(c != nullptr);
|
||||
uint32_t width = c->type()->AsInteger()->width();
|
||||
assert(width == 32 || width == 64);
|
||||
if (width == 64) {
|
||||
uint64_t uval = static_cast<uint64_t>(c->GetU64());
|
||||
return ExtractInts(uval);
|
||||
}
|
||||
return {c->GetU32()};
|
||||
}
|
||||
|
||||
std::vector<uint32_t> GetWordsFromScalarFloatConstant(
|
||||
const analysis::FloatConstant* c) {
|
||||
assert(c != nullptr);
|
||||
uint32_t width = c->type()->AsFloat()->width();
|
||||
assert(width == 32 || width == 64);
|
||||
if (width == 64) {
|
||||
utils::FloatProxy<double> result(c->GetDouble());
|
||||
return result.GetWords();
|
||||
}
|
||||
utils::FloatProxy<float> result(c->GetFloat());
|
||||
return result.GetWords();
|
||||
}
|
||||
|
||||
std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
|
||||
analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
|
||||
if (const auto* float_constant = c->AsFloatConstant()) {
|
||||
return GetWordsFromScalarFloatConstant(float_constant);
|
||||
} else if (const auto* int_constant = c->AsIntConstant()) {
|
||||
return GetWordsFromScalarIntConstant(int_constant);
|
||||
} else if (const auto* vec_constant = c->AsVectorConstant()) {
|
||||
std::vector<uint32_t> words;
|
||||
for (const auto* comp : vec_constant->GetComponents()) {
|
||||
auto comp_in_words =
|
||||
GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
|
||||
words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
|
||||
}
|
||||
return words;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
|
||||
analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
|
||||
const analysis::Type* type) {
|
||||
if (type->AsInteger() || type->AsFloat())
|
||||
return const_mgr->GetConstant(type, words);
|
||||
if (const auto* vec_type = type->AsVector())
|
||||
return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
|
||||
// constant.
|
||||
uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
|
||||
|
@ -146,13 +206,6 @@ uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
|
|||
return const_mgr->GetDefiningInstruction(negated_const)->result_id();
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ExtractInts(uint64_t val) {
|
||||
std::vector<uint32_t> words;
|
||||
words.push_back(static_cast<uint32_t>(val));
|
||||
words.push_back(static_cast<uint32_t>(val >> 32));
|
||||
return words;
|
||||
}
|
||||
|
||||
// Negates the integer constant |c|. Returns the id of the defining instruction.
|
||||
uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
|
||||
const analysis::Constant* c) {
|
||||
|
@ -1796,6 +1849,33 @@ FoldingRule RedundantPhi() {
|
|||
};
|
||||
}
|
||||
|
||||
FoldingRule BitCastScalarOrVector() {
|
||||
return [](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants) {
|
||||
assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
|
||||
if (constants[0] == nullptr) return false;
|
||||
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
|
||||
return false;
|
||||
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
std::vector<uint32_t> words =
|
||||
GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
|
||||
if (words.size() == 0) return false;
|
||||
|
||||
const analysis::Constant* bitcasted_constant =
|
||||
ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
|
||||
auto new_feeder_id =
|
||||
const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
|
||||
->result_id();
|
||||
inst->SetOpcode(SpvOpCopyObject);
|
||||
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
|
||||
return true;
|
||||
};
|
||||
}
|
||||
|
||||
FoldingRule RedundantSelect() {
|
||||
// An OpSelect instruction where both values are the same or the condition is
|
||||
// constant can be replaced by one of the values
|
||||
|
@ -2423,6 +2503,8 @@ void FoldingRules::AddFoldingRules() {
|
|||
// Note that the order in which rules are added to the list matters. If a rule
|
||||
// applies to the instruction, the rest of the rules will not be attempted.
|
||||
// Take that into consideration.
|
||||
rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
|
||||
|
||||
rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
|
||||
|
||||
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
|
||||
|
|
|
@ -141,6 +141,7 @@ OpName %main "main"
|
|||
%v4int = OpTypeVector %int 4
|
||||
%v4float = OpTypeVector %float 4
|
||||
%v4double = OpTypeVector %double 4
|
||||
%v2uint = OpTypeVector %uint 2
|
||||
%v2float = OpTypeVector %float 2
|
||||
%v2double = OpTypeVector %double 2
|
||||
%v2half = OpTypeVector %half 2
|
||||
|
@ -191,6 +192,7 @@ OpName %main "main"
|
|||
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
|
||||
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
|
||||
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
|
||||
%v2int_min_max = OpConstantComposite %v2int %int_min %int_max
|
||||
%v2bool_null = OpConstantNull %v2bool
|
||||
%v2bool_true_false = OpConstantComposite %v2bool %true %false
|
||||
%v2bool_false_true = OpConstantComposite %v2bool %false %true
|
||||
|
@ -258,6 +260,15 @@ OpName %main "main"
|
|||
%v4double_1_1_1_0p5 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_0p5
|
||||
%v4double_null = OpConstantNull %v4double
|
||||
%v4float_n1_2_1_3 = OpConstantComposite %v4float %float_n1 %float_2 %float_1 %float_3
|
||||
%uint_0x3f800000 = OpConstant %uint 0x3f800000
|
||||
%uint_0xbf800000 = OpConstant %uint 0xbf800000
|
||||
%v2uint_0x3f800000_0xbf800000 = OpConstantComposite %v2uint %uint_0x3f800000 %uint_0xbf800000
|
||||
%long_0xbf8000003f800000 = OpConstant %long 0xbf8000003f800000
|
||||
%int_0x3FF00000 = OpConstant %int 0x3FF00000
|
||||
%int_0x00000000 = OpConstant %int 0x00000000
|
||||
%int_0xC05FD666 = OpConstant %int 0xC05FD666
|
||||
%int_0x66666666 = OpConstant %int 0x66666666
|
||||
%v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666 = OpConstantComposite %v4int %int_0x00000000 %int_0x3FF00000 %int_0x66666666 %int_0xC05FD666
|
||||
)";
|
||||
|
||||
return header;
|
||||
|
@ -708,7 +719,31 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
|
|||
"%2 = OpExtInst %uint %1 UClamp %uint_2 %undef %uint_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 1)
|
||||
2, 1),
|
||||
// Test case 46: Bit-cast int 0 to unsigned int
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpBitcast %uint %int_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 47: Bit-cast int -24 to unsigned int
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpBitcast %uint %int_n24\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, static_cast<uint32_t>(-24)),
|
||||
// Test case 48: Bit-cast float 1.0f to unsigned int
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpBitcast %uint %float_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, static_cast<uint32_t>(0x3f800000))
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
|
@ -790,10 +825,72 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntVectorInstructionFoldingTest,
|
|||
"%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 4294967295 \n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {0,0})
|
||||
2, {0,0}),
|
||||
// Test case 4: fold bit-cast int -24 to unsigned int
|
||||
InstructionFoldingCase<std::vector<uint32_t>>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_int Function\n" +
|
||||
"%load = OpLoad %int %n\n" +
|
||||
"%2 = OpBitcast %v2uint %v2int_min_max\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {2147483648, 2147483647})
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
using DoubleVectorInstructionFoldingTest =
|
||||
::testing::TestWithParam<InstructionFoldingCase<std::vector<double>>>;
|
||||
|
||||
TEST_P(DoubleVectorInstructionFoldingTest, Case) {
|
||||
const auto& tc = GetParam();
|
||||
|
||||
// Build module.
|
||||
std::unique_ptr<IRContext> context =
|
||||
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
|
||||
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||
ASSERT_NE(nullptr, context);
|
||||
|
||||
// Fold the instruction to test.
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
|
||||
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
|
||||
|
||||
// Make sure the instruction folded as expected.
|
||||
EXPECT_TRUE(succeeded);
|
||||
if (succeeded && inst != nullptr) {
|
||||
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
|
||||
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
|
||||
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
|
||||
EXPECT_THAT(opcodes, Contains(inst->opcode()));
|
||||
analysis::ConstantManager* const_mrg = context->get_constant_mgr();
|
||||
const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
|
||||
EXPECT_NE(result, nullptr);
|
||||
if (result != nullptr) {
|
||||
const std::vector<const analysis::Constant*>& componenets =
|
||||
result->AsVectorConstant()->GetComponents();
|
||||
EXPECT_EQ(componenets.size(), tc.expected_result.size());
|
||||
for (size_t i = 0; i < componenets.size(); i++) {
|
||||
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetDouble());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
INSTANTIATE_TEST_SUITE_P(TestCase, DoubleVectorInstructionFoldingTest,
|
||||
::testing::Values(
|
||||
// Test case 0: bit-cast int {0x3FF00000,0x00000000,0xC05FD666,0x66666666}
|
||||
// to double vector
|
||||
InstructionFoldingCase<std::vector<double>>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpBitcast %v2double %v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {1.0,-127.35})
|
||||
));
|
||||
|
||||
using FloatVectorInstructionFoldingTest =
|
||||
::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
|
||||
|
||||
|
@ -843,7 +940,24 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
|
|||
"%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {1.6f,1.5f})
|
||||
2, {1.6f,1.5f}),
|
||||
// Test case 1: bit-cast unsigned int vector {0x3f800000, 0xbf800000} to
|
||||
// float vector
|
||||
InstructionFoldingCase<std::vector<float>>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpBitcast %v2float %v2uint_0x3f800000_0xbf800000\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {1.0f,-1.0f}),
|
||||
// Test case 2: bit-cast long int 0xbf8000003f800000 to float vector
|
||||
InstructionFoldingCase<std::vector<float>>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpBitcast %v2float %long_0xbf8000003f800000\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {1.0f,-1.0f})
|
||||
));
|
||||
// clang-format on
|
||||
using BooleanInstructionFoldingTest =
|
||||
|
|
Загрузка…
Ссылка в новой задаче