Enable OpTypeCooperativeMatrix specialization (#2927)
This commit is contained in:
Родитель
c18c9ff6bc
Коммит
3c7ff8d4f0
|
@ -409,6 +409,22 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
|
|||
{static_cast<uint32_t>(
|
||||
type->AsForwardPointer()->storage_class())}}});
|
||||
break;
|
||||
case Type::kCooperativeMatrixNV: {
|
||||
auto coop_mat = type->AsCooperativeMatrixNV();
|
||||
uint32_t const component_type =
|
||||
GetTypeInstruction(coop_mat->component_type());
|
||||
if (component_type == 0) {
|
||||
return 0;
|
||||
}
|
||||
typeInst = MakeUnique<Instruction>(
|
||||
context(), SpvOpTypeCooperativeMatrixNV, 0, id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {component_type}},
|
||||
{SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(false && "Unexpected type");
|
||||
break;
|
||||
|
@ -604,6 +620,14 @@ Type* TypeManager::RebuildType(const Type& type) {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case Type::kCooperativeMatrixNV: {
|
||||
const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV();
|
||||
const Type* component_type = cm_type->component_type();
|
||||
rebuilt_ty = MakeUnique<CooperativeMatrixNV>(
|
||||
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
|
||||
cm_type->columns_id());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
return nullptr;
|
||||
|
@ -832,6 +856,12 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
|
|||
case SpvOpTypeAccelerationStructureNV:
|
||||
type = new AccelerationStructureNV();
|
||||
break;
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
type = new CooperativeMatrixNV(GetType(inst.GetSingleWordInOperand(0)),
|
||||
inst.GetSingleWordInOperand(1),
|
||||
inst.GetSingleWordInOperand(2),
|
||||
inst.GetSingleWordInOperand(3));
|
||||
break;
|
||||
default:
|
||||
SPIRV_UNIMPLEMENTED(consumer_, "unhandled type");
|
||||
break;
|
||||
|
|
|
@ -127,6 +127,7 @@ std::unique_ptr<Type> Type::Clone() const {
|
|||
DeclareKindCase(PipeStorage);
|
||||
DeclareKindCase(NamedBarrier);
|
||||
DeclareKindCase(AccelerationStructureNV);
|
||||
DeclareKindCase(CooperativeMatrixNV);
|
||||
#undef DeclareKindCase
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
|
@ -171,6 +172,7 @@ bool Type::operator==(const Type& other) const {
|
|||
DeclareKindCase(PipeStorage);
|
||||
DeclareKindCase(NamedBarrier);
|
||||
DeclareKindCase(AccelerationStructureNV);
|
||||
DeclareKindCase(CooperativeMatrixNV);
|
||||
#undef DeclareKindCase
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
|
@ -220,6 +222,7 @@ void Type::GetHashWords(std::vector<uint32_t>* words,
|
|||
DeclareKindCase(PipeStorage);
|
||||
DeclareKindCase(NamedBarrier);
|
||||
DeclareKindCase(AccelerationStructureNV);
|
||||
DeclareKindCase(CooperativeMatrixNV);
|
||||
#undef DeclareKindCase
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
|
@ -654,6 +657,44 @@ void ForwardPointer::GetExtraHashWords(
|
|||
if (pointer_) pointer_->GetHashWords(words, seen);
|
||||
}
|
||||
|
||||
CooperativeMatrixNV::CooperativeMatrixNV(const Type* type, const uint32_t scope,
|
||||
const uint32_t rows,
|
||||
const uint32_t columns)
|
||||
: Type(kCooperativeMatrixNV),
|
||||
component_type_(type),
|
||||
scope_id_(scope),
|
||||
rows_id_(rows),
|
||||
columns_id_(columns) {
|
||||
assert(type != nullptr);
|
||||
assert(scope != 0);
|
||||
assert(rows != 0);
|
||||
assert(columns != 0);
|
||||
}
|
||||
|
||||
std::string CooperativeMatrixNV::str() const {
|
||||
std::ostringstream oss;
|
||||
oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
|
||||
<< ", " << columns_id_ << ">";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void CooperativeMatrixNV::GetExtraHashWords(
|
||||
std::vector<uint32_t>* words, std::unordered_set<const Type*>* pSet) const {
|
||||
component_type_->GetHashWords(words, pSet);
|
||||
words->push_back(scope_id_);
|
||||
words->push_back(rows_id_);
|
||||
words->push_back(columns_id_);
|
||||
}
|
||||
|
||||
bool CooperativeMatrixNV::IsSameImpl(const Type* that,
|
||||
IsSameCache* seen) const {
|
||||
const CooperativeMatrixNV* mt = that->AsCooperativeMatrixNV();
|
||||
if (!mt) return false;
|
||||
return component_type_->IsSameImpl(mt->component_type_, seen) &&
|
||||
scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
|
||||
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
|
|
@ -58,6 +58,7 @@ class ForwardPointer;
|
|||
class PipeStorage;
|
||||
class NamedBarrier;
|
||||
class AccelerationStructureNV;
|
||||
class CooperativeMatrixNV;
|
||||
|
||||
// Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
|
||||
// which is used as a way to probe the actual <subclass>.
|
||||
|
@ -93,6 +94,7 @@ class Type {
|
|||
kPipeStorage,
|
||||
kNamedBarrier,
|
||||
kAccelerationStructureNV,
|
||||
kCooperativeMatrixNV
|
||||
};
|
||||
|
||||
Type(Kind k) : kind_(k) {}
|
||||
|
@ -196,6 +198,7 @@ class Type {
|
|||
DeclareCastMethod(PipeStorage)
|
||||
DeclareCastMethod(NamedBarrier)
|
||||
DeclareCastMethod(AccelerationStructureNV)
|
||||
DeclareCastMethod(CooperativeMatrixNV)
|
||||
#undef DeclareCastMethod
|
||||
|
||||
protected:
|
||||
|
@ -597,6 +600,36 @@ class ForwardPointer : public Type {
|
|||
const Pointer* pointer_;
|
||||
};
|
||||
|
||||
class CooperativeMatrixNV : public Type {
|
||||
public:
|
||||
CooperativeMatrixNV(const Type* type, const uint32_t scope,
|
||||
const uint32_t rows, const uint32_t columns);
|
||||
CooperativeMatrixNV(const CooperativeMatrixNV&) = default;
|
||||
|
||||
std::string str() const override;
|
||||
|
||||
CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; }
|
||||
const CooperativeMatrixNV* AsCooperativeMatrixNV() const override {
|
||||
return this;
|
||||
}
|
||||
|
||||
void GetExtraHashWords(std::vector<uint32_t>*,
|
||||
std::unordered_set<const Type*>*) const override;
|
||||
|
||||
const Type* component_type() const { return component_type_; }
|
||||
uint32_t scope_id() const { return scope_id_; }
|
||||
uint32_t rows_id() const { return rows_id_; }
|
||||
uint32_t columns_id() const { return columns_id_; }
|
||||
|
||||
private:
|
||||
bool IsSameImpl(const Type* that, IsSameCache*) const override;
|
||||
|
||||
const Type* component_type_;
|
||||
const uint32_t scope_id_;
|
||||
const uint32_t rows_id_;
|
||||
const uint32_t columns_id_;
|
||||
};
|
||||
|
||||
#define DefineParameterlessType(type, name) \
|
||||
class type : public Type { \
|
||||
public: \
|
||||
|
|
|
@ -156,7 +156,8 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
|
|||
types.emplace_back(new ReserveId());
|
||||
types.emplace_back(new Queue());
|
||||
|
||||
// Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV
|
||||
// Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV,
|
||||
// CooperativeMatrixNV
|
||||
types.emplace_back(new Pipe(SpvAccessQualifierReadWrite));
|
||||
types.emplace_back(new Pipe(SpvAccessQualifierReadOnly));
|
||||
types.emplace_back(new ForwardPointer(1, SpvStorageClassInput));
|
||||
|
@ -165,6 +166,7 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
|
|||
types.emplace_back(new PipeStorage());
|
||||
types.emplace_back(new NamedBarrier());
|
||||
types.emplace_back(new AccelerationStructureNV());
|
||||
types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
|
||||
|
||||
return types;
|
||||
}
|
||||
|
@ -214,6 +216,7 @@ TEST(TypeManager, TypeStrings) {
|
|||
%arr_spec_const_with_id = OpTypeArray %s32 %spec_const_with_id
|
||||
%arr_long_constant = OpTypeArray %s32 %long_constant
|
||||
%arr_spec_const_op = OpTypeArray %s32 %spec_const_op
|
||||
%cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
|
||||
)";
|
||||
|
||||
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
|
||||
|
@ -251,6 +254,7 @@ TEST(TypeManager, TypeStrings) {
|
|||
{36, "[sint32, id(1), words(1,99,42)]"},
|
||||
{37, "[sint32, id(33), words(0,705032704,1)]"},
|
||||
{38, "[sint32, id(34), words(2,34)]"},
|
||||
{39, "<float64, 6, 6, 6>"},
|
||||
};
|
||||
|
||||
std::unique_ptr<IRContext> context =
|
||||
|
@ -1060,6 +1064,7 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
|
|||
; CHECK: OpTypePipeStorage
|
||||
; CHECK: OpTypeNamedBarrier
|
||||
; CHECK: OpTypeAccelerationStructureNV
|
||||
; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
|
||||
OpCapability Shader
|
||||
OpCapability Int64
|
||||
OpCapability Linkage
|
||||
|
|
Загрузка…
Ссылка в новой задаче