fix the onnxprocess for the empty input and name (#104)

* fixing the onnxprocess for the empty input and name

* fix the crash on onnxruntime 1.8
This commit is contained in:
Wenbing Li 2021-06-03 21:23:13 -07:00 коммит произвёл GitHub
Родитель 0851eacfeb
Коммит 88a3c0e42d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 33 добавлений и 19 удалений

Просмотреть файл

@ -11,7 +11,7 @@ endif()
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
set(CPACK_PACKAGE_VERSION_MAJOR "0")
set(CPACK_PACKAGE_VERSION_MINOR "3")
set(CPACK_PACKAGE_VERSION_PATCH "0")
set(CPACK_PACKAGE_VERSION_PATCH "1")
set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH})

Просмотреть файл

@ -7,7 +7,7 @@
#include <string.h>
// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
#define ORT_API_VERSION 5
#define ORT_API_VERSION 6
#ifdef __cplusplus
extern "C" {

Просмотреть файл

@ -7,7 +7,7 @@
The entry point to onnxruntime custom op library
"""
__version__ = "0.3.0"
__version__ = "0.3.1"
__author__ = "Microsoft"

Просмотреть файл

@ -35,18 +35,21 @@ class ONNXModelUtils:
@classmethod
def _rename_graph(cls, graph, prefix, graph_or_container):
def io_rename(node, prefix_name):
def io_rename(node, prefix_name, idx):
new_node = copy.deepcopy(node)
if not node.name:
new_node.name = "{}_op{}".format(prefix_name, idx)
del new_node.input[:]
new_node.input.extend("{}_{}".format(prefix_name, nm_) for nm_ in node.input)
new_node.input.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
del new_node.output[:]
new_node.output.extend("{}_{}".format(prefix_name, nm_) for nm_ in node.output)
new_node.output.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
return new_node
assert prefix is not None, 'The graph prefix could not be None'
graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
return list(io_rename(nd_, prefix) for nd_ in graph.node)
return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))
@classmethod
def _process_node_body(cls, node, prefix):
@ -97,6 +100,8 @@ class ONNXModelUtils:
edges = {}
for op in nodes:
for x in op.input:
if x == '':
continue
try:
predecessor = op_output_map[x]
except KeyError:
@ -125,6 +130,7 @@ class ONNXModelUtils:
unfinished_nodes.add(node.name)
if node.name in edges: # if the node's output is not in the Graph output.
assert node.name != '', 'this topological-sort depends on the unique node name.'
for successor in edges[node.name]:
recursive_helper(successor)
@ -152,7 +158,7 @@ class ONNXModelUtils:
iz_set = set(iz_.name for iz_ in container.initializer)
for op in ops:
iz_needed.update(it_ for it_ in op.input if it_ in iz_set)
all_inputs.extend(it_ for it_ in op.input if it_ not in iz_set)
all_inputs.extend(it_ for it_ in op.input if (it_ != '') and it_ not in iz_set)
all_outputs.extend(ot_ for ot_ in op.output)
intersections = set(all_inputs).intersection(set(all_outputs))
@ -205,12 +211,16 @@ class ONNXTraceSession:
else torch.tensor(x) for x in np_inputs]
itensors = [tensor_from_torch(i_, None) if isinstance(i_, torch.Tensor)
else tensor_from_onnx(i_, None, None) for i_ in np_inputs]
if names is not None:
if len(inputs) != len(names):
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
num = min(len(itensors), len(names))
for idx_ in range(num):
itensors[idx_].name = names[idx_]
if names is None:
names = []
if len(inputs) != len(names):
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
names.extend([''] * (len(inputs) - len(names)))
for idx_ in range(len(inputs)):
names[idx_] = names[idx_] if names[idx_] else "input{}".format(idx_)
num = min(len(itensors), len(names))
for idx_ in range(num):
itensors[idx_].name = names[idx_]
self.inputs = itensors
return self

Просмотреть файл

@ -45,14 +45,16 @@ struct PyCustomOpKernel {
for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
size = 0;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
if (api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
std::string error_message(api_.GetErrorMessage(status));
api_.ReleaseStatus(status);
throw std::runtime_error(MakeString(
"Unable to find attribute '", *it, "' due to '",
error_message, "'."));
}
api_.ReleaseStatus(status);
if (status != nullptr) {
api_.ReleaseStatus(status);
}
attrs_values_[*it] = "";
attrs_values_[*it].resize(size);
status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
@ -63,7 +65,9 @@ struct PyCustomOpKernel {
api_.GetErrorMessage(status), "'."));
}
attrs_values_[*it].resize(size - 1);
api_.ReleaseStatus(status);
if (status != nullptr) {
api_.ReleaseStatus(status);
}
}
}