[VitisAI] 1. KernelDef supports StartVersion and EndVersion (#21519)
### Description <!-- Describe your changes. --> [VitisAI] 1. KernelDef supports StartVersion and EndVersion 2. CapabilityOps checks domain ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Zhenze Wang <zhenzew@xilinx.com>
This commit is contained in:
Родитель
5af423c7c0
Коммит
690d745cbf
|
@ -51,7 +51,11 @@ GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph,
|
|||
|
||||
std::vector<NodeIndex> node_indexs = graph.GetNodesInTopologicalOrder();
|
||||
node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end());
|
||||
node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end());
|
||||
node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(),
|
||||
[&](NodeIndex index) {
|
||||
auto node = graph.GetNode(index);
|
||||
return all_support_optypes_by_eps.count(node->Domain() + ":" + node->OpType()) == 0; }),
|
||||
node_indexs.end());
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
for (auto& n : node_indexs) {
|
||||
|
|
|
@ -173,7 +173,7 @@ void create_kernel_registry(std::vector<OrtCustomOpDomain*> domains) {
|
|||
auto def_builder = KernelDefBuilder::Create();
|
||||
def_builder->SetName(op->GetName(op));
|
||||
def_builder->SetDomain(domain->domain_.c_str());
|
||||
def_builder->SinceVersion(1);
|
||||
def_builder->SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op));
|
||||
if (op->version > 12) {
|
||||
auto input_count = op->GetInputTypeCount(op);
|
||||
for (auto i = 0u; i < input_count; i++) {
|
||||
|
@ -183,7 +183,7 @@ void create_kernel_registry(std::vector<OrtCustomOpDomain*> domains) {
|
|||
def_builder->Provider(onnxruntime::kVitisAIExecutionProvider);
|
||||
KernelCreateFn kernel_create_fn =
|
||||
[op](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
|
||||
// out = std::make_unique<MyCustomOpKernel>(info, *op);
|
||||
out = std::make_unique<MyCustomOpKernel>(info, *op);
|
||||
return Status::OK();
|
||||
};
|
||||
std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn));
|
||||
|
|
|
@ -44,7 +44,7 @@ VitisAIExecutionProvider::VitisAIExecutionProvider(
|
|||
void VitisAIExecutionProvider::CreateKernelRegistry() {
|
||||
for (const auto& domain : get_domains_vitisaiep()) {
|
||||
for (const auto* op : domain->custom_ops_) {
|
||||
vitisai_optypes_.insert(op->GetName(op));
|
||||
vitisai_optypes_.insert(domain->domain_ + ":" + op->GetName(op));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче