fix the onnxprocess for the empty input and name (#104)
* fixing the onnxprocess for the empty input and name * fix the crash on onnxruntime 1.8
This commit is contained in:
Родитель
0851eacfeb
Коммит
88a3c0e42d
|
@ -11,7 +11,7 @@ endif()
|
|||
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
|
||||
set(CPACK_PACKAGE_VERSION_MAJOR "0")
|
||||
set(CPACK_PACKAGE_VERSION_MINOR "3")
|
||||
set(CPACK_PACKAGE_VERSION_PATCH "0")
|
||||
set(CPACK_PACKAGE_VERSION_PATCH "1")
|
||||
set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH})
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <string.h>
|
||||
|
||||
// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
|
||||
#define ORT_API_VERSION 5
|
||||
#define ORT_API_VERSION 6
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
The entry point to onnxruntime custom op library
|
||||
"""
|
||||
|
||||
__version__ = "0.3.0"
|
||||
__version__ = "0.3.1"
|
||||
__author__ = "Microsoft"
|
||||
|
||||
|
||||
|
|
|
@ -35,18 +35,21 @@ class ONNXModelUtils:
|
|||
|
||||
@classmethod
|
||||
def _rename_graph(cls, graph, prefix, graph_or_container):
|
||||
def io_rename(node, prefix_name):
|
||||
def io_rename(node, prefix_name, idx):
|
||||
new_node = copy.deepcopy(node)
|
||||
if not node.name:
|
||||
new_node.name = "{}_op{}".format(prefix_name, idx)
|
||||
|
||||
del new_node.input[:]
|
||||
new_node.input.extend("{}_{}".format(prefix_name, nm_) for nm_ in node.input)
|
||||
new_node.input.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
|
||||
del new_node.output[:]
|
||||
new_node.output.extend("{}_{}".format(prefix_name, nm_) for nm_ in node.output)
|
||||
new_node.output.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
|
||||
return new_node
|
||||
|
||||
assert prefix is not None, 'The graph prefix could not be None'
|
||||
graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
|
||||
graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
|
||||
return list(io_rename(nd_, prefix) for nd_ in graph.node)
|
||||
return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))
|
||||
|
||||
@classmethod
|
||||
def _process_node_body(cls, node, prefix):
|
||||
|
@ -97,6 +100,8 @@ class ONNXModelUtils:
|
|||
edges = {}
|
||||
for op in nodes:
|
||||
for x in op.input:
|
||||
if x == '':
|
||||
continue
|
||||
try:
|
||||
predecessor = op_output_map[x]
|
||||
except KeyError:
|
||||
|
@ -125,6 +130,7 @@ class ONNXModelUtils:
|
|||
|
||||
unfinished_nodes.add(node.name)
|
||||
if node.name in edges: # if the node's output is not in the Graph output.
|
||||
assert node.name != '', 'this topological-sort depends on the unique node name.'
|
||||
for successor in edges[node.name]:
|
||||
recursive_helper(successor)
|
||||
|
||||
|
@ -152,7 +158,7 @@ class ONNXModelUtils:
|
|||
iz_set = set(iz_.name for iz_ in container.initializer)
|
||||
for op in ops:
|
||||
iz_needed.update(it_ for it_ in op.input if it_ in iz_set)
|
||||
all_inputs.extend(it_ for it_ in op.input if it_ not in iz_set)
|
||||
all_inputs.extend(it_ for it_ in op.input if (it_ != '') and it_ not in iz_set)
|
||||
all_outputs.extend(ot_ for ot_ in op.output)
|
||||
|
||||
intersections = set(all_inputs).intersection(set(all_outputs))
|
||||
|
@ -205,12 +211,16 @@ class ONNXTraceSession:
|
|||
else torch.tensor(x) for x in np_inputs]
|
||||
itensors = [tensor_from_torch(i_, None) if isinstance(i_, torch.Tensor)
|
||||
else tensor_from_onnx(i_, None, None) for i_ in np_inputs]
|
||||
if names is not None:
|
||||
if len(inputs) != len(names):
|
||||
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
|
||||
num = min(len(itensors), len(names))
|
||||
for idx_ in range(num):
|
||||
itensors[idx_].name = names[idx_]
|
||||
if names is None:
|
||||
names = []
|
||||
if len(inputs) != len(names):
|
||||
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
|
||||
names.extend([''] * (len(inputs) - len(names)))
|
||||
for idx_ in range(len(inputs)):
|
||||
names[idx_] = names[idx_] if names[idx_] else "input{}".format(idx_)
|
||||
num = min(len(itensors), len(names))
|
||||
for idx_ in range(num):
|
||||
itensors[idx_].name = names[idx_]
|
||||
self.inputs = itensors
|
||||
return self
|
||||
|
||||
|
|
|
@ -45,14 +45,16 @@ struct PyCustomOpKernel {
|
|||
for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
|
||||
size = 0;
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
|
||||
if (api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
|
||||
if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
|
||||
std::string error_message(api_.GetErrorMessage(status));
|
||||
api_.ReleaseStatus(status);
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to find attribute '", *it, "' due to '",
|
||||
error_message, "'."));
|
||||
}
|
||||
api_.ReleaseStatus(status);
|
||||
if (status != nullptr) {
|
||||
api_.ReleaseStatus(status);
|
||||
}
|
||||
attrs_values_[*it] = "";
|
||||
attrs_values_[*it].resize(size);
|
||||
status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
|
||||
|
@ -63,7 +65,9 @@ struct PyCustomOpKernel {
|
|||
api_.GetErrorMessage(status), "'."));
|
||||
}
|
||||
attrs_values_[*it].resize(size - 1);
|
||||
api_.ReleaseStatus(status);
|
||||
if (status != nullptr) {
|
||||
api_.ReleaseStatus(status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче