Support ONNX Scan op
This commit is contained in:
Родитель
a55e871ec8
Коммит
e940605f6b
|
@ -178,6 +178,7 @@
|
|||
<ClInclude Include="Logger.h" />
|
||||
<ClInclude Include="MinibatchSource.h" />
|
||||
<ClInclude Include="proto\onnx\CNTKToONNX.h" />
|
||||
<ClInclude Include="proto\onnx\ControlFlowHelper.h" />
|
||||
<ClInclude Include="proto\onnx\core\common\profiler.h" />
|
||||
<ClInclude Include="proto\onnx\core\common\task_thread_pool.h" />
|
||||
<ClInclude Include="proto\onnx\core\framework\tensorutils.h" />
|
||||
|
|
|
@ -391,6 +391,9 @@
|
|||
<ClInclude Include="proto\onnx\core\platform\windows\debug_alloc.h">
|
||||
<Filter>proto\onnx\core\platform\windows</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\ControlFlowHelper.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Filter Include="API">
|
||||
|
@ -509,11 +512,5 @@
|
|||
<Proto Include="tensorboard\tensorboard.proto">
|
||||
<Filter>tensorboard</Filter>
|
||||
</Proto>
|
||||
<Proto Include="proto\onnx\onnx_repo\onnx\onnx-ml.proto">
|
||||
<Filter>proto\onnx\onnx_repo\onnx</Filter>
|
||||
</Proto>
|
||||
<Proto Include="proto\onnx\onnx_repo\onnx\onnx-operators-ml.proto">
|
||||
<Filter>proto\onnx\onnx_repo\onnx</Filter>
|
||||
</Proto>
|
||||
</ItemGroup>
|
||||
</Project>
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,208 @@
|
|||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
#include "Internals/ComputationGraphAlgorithms.h"
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
class ScanLoopState
|
||||
{
|
||||
public:
|
||||
ScanLoopState(const Variable initialState, onnxruntime::NodeArg *initialStateNodeArg, const Variable stateOutput, int delay) :
|
||||
m_initialState(initialState),
|
||||
m_initialStateNodeArg(initialStateNodeArg),
|
||||
m_stateOutput(stateOutput),
|
||||
m_delay(delay),
|
||||
m_hasInitializer(false)
|
||||
{}
|
||||
|
||||
Variable m_initialState;
|
||||
onnxruntime::NodeArg *m_initialStateNodeArg;
|
||||
Variable m_stateOutput;
|
||||
onnx::TensorProto m_initialStateTensor;
|
||||
bool m_hasInitializer;
|
||||
int m_delay;
|
||||
};
|
||||
|
||||
class ScanLoop
|
||||
{
|
||||
public:
|
||||
ScanLoop(const std::vector<Variable> &inputs, const std::vector<Variable> &outputs,
|
||||
const std::vector<Variable> &scanInputs, const std::vector<Variable> &scanOutputs, const std::vector<FunctionPtr> &body) :
|
||||
m_inputs(inputs),
|
||||
m_outputs(outputs),
|
||||
m_scanInputs(scanInputs),
|
||||
m_scanOutputs(scanOutputs),
|
||||
m_body(body),
|
||||
m_scanOpCreated(false)
|
||||
{}
|
||||
|
||||
std::vector<Variable> m_inputs, m_outputs, m_scanInputs, m_scanOutputs;
|
||||
std::vector<FunctionPtr> m_body;
|
||||
std::vector<ScanLoopState> scanLoopStates;
|
||||
std::vector<FunctionPtr> m_visited;
|
||||
bool m_scanOpCreated;
|
||||
};
|
||||
|
||||
// Implementation of a graph based on ComputationNodes.
|
||||
class CNTKModelGraph : public DirectedGraph<FunctionPtr>
|
||||
{
|
||||
std::vector<FunctionPtr> m_roots;
|
||||
|
||||
public:
|
||||
CNTKModelGraph(const std::vector<FunctionPtr>& roots) : m_roots(roots) {}
|
||||
|
||||
std::vector<FunctionPtr> Predecessors(const FunctionPtr& node) const override
|
||||
{
|
||||
std::vector<FunctionPtr> predecessors;
|
||||
for (auto &input : node->Inputs())
|
||||
{
|
||||
if (input.Owner())
|
||||
predecessors.push_back(input.Owner());
|
||||
}
|
||||
return predecessors;
|
||||
}
|
||||
|
||||
const std::vector<FunctionPtr>& Roots() const override
|
||||
{
|
||||
return m_roots;
|
||||
}
|
||||
};
|
||||
|
||||
std::wstring ToString(FunctionPtr f)
|
||||
{
|
||||
return L"( " + f->Name() + L": " + f->Uid() + L")";
|
||||
}
|
||||
|
||||
void BuildLoops(const std::vector<FunctionPtr>& roots,
|
||||
std::vector<ScanLoop> &scanLoops)
|
||||
{
|
||||
std::vector<StrongComponent<FunctionPtr>> loops;
|
||||
CNTKModelGraph cntkModelGraph(roots);
|
||||
loops = StrongComponents<FunctionPtr>(cntkModelGraph);
|
||||
|
||||
// Sort nodes inside the strong components in the evaluation order.
|
||||
std::function<bool(const FunctionPtr&)> delay
|
||||
= [](const FunctionPtr& f)
|
||||
{
|
||||
if (f->OpName() == L"PastValue")
|
||||
return 1;
|
||||
if (f->OpName() == L"FutureValue")
|
||||
return -1;
|
||||
else
|
||||
return 0;
|
||||
};
|
||||
|
||||
EvaluationSort(cntkModelGraph, delay, loops);
|
||||
|
||||
// Attributes:
|
||||
// body: N+M inputs, N+K outputs, N is the # of states, M inputs, K outputs)
|
||||
// directions(M):
|
||||
// num_scan_inputs(M):
|
||||
//
|
||||
// Inputs:
|
||||
// sequence_length:
|
||||
// max sequence length if not specified
|
||||
// initial_state_and_scan_inputs(N + M):
|
||||
// initial_states are constant attributes from step functions
|
||||
// scan_inputs are input to this loop body with sequence axis
|
||||
//
|
||||
// Outputs:
|
||||
// final_state_and_scan_outputs(N + K):
|
||||
// final_state: ?
|
||||
// scan_outputs are outputs from this loop body with sequence axis
|
||||
|
||||
std::vector<std::vector<Variable>> loopinputs, loopoutputs, scaninputs, scanoutputs;
|
||||
loopinputs.resize(loops.size());
|
||||
loopoutputs.resize(loops.size());
|
||||
scaninputs.resize(loops.size());
|
||||
scanoutputs.resize(loops.size());
|
||||
bool nestedSearchInsideBlockFunction = false;
|
||||
std::vector<FunctionPtr> visited;
|
||||
for (auto &root : roots)
|
||||
{
|
||||
root->PreorderTraverse([&root, &loops, &loopinputs, &loopoutputs, &scaninputs, &scanoutputs, &visited](const FunctionPtr& function) {
|
||||
if (std::find(visited.begin(), visited.end(), function) != visited.end())
|
||||
return;
|
||||
|
||||
for (int l = 0; l < loops.size(); l++)
|
||||
{
|
||||
const StrongComponent<FunctionPtr> &loop = loops[l];
|
||||
std::vector<Variable> &inputs = loopinputs[l];
|
||||
std::vector<Variable> &outputs = loopoutputs[l];
|
||||
const std::vector<FunctionPtr> &nodes = loop.Nodes();
|
||||
if (std::find(nodes.begin(), nodes.end(), function) != nodes.end())
|
||||
{
|
||||
// if a function is part of a loop, any its inputs that are not from the loop body
|
||||
// is an input.
|
||||
for (auto &input : function->Inputs())
|
||||
{
|
||||
if (!input.Owner() || (input.Owner() && std::find(nodes.begin(), nodes.end(), input.Owner()) == nodes.end()))
|
||||
{
|
||||
if (std::find(inputs.begin(), inputs.end(), input) == inputs.end())
|
||||
{
|
||||
inputs.push_back(input);
|
||||
if (input.DynamicAxes().size() == 2)
|
||||
scaninputs[l].push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// if a function is not part of the loop and any of its inputs is from the loop
|
||||
// that input variable is an output from the loop
|
||||
for (auto &input : function->Inputs())
|
||||
{
|
||||
if (input.Owner() && std::find(nodes.begin(), nodes.end(), input.Owner()) != nodes.end())
|
||||
{
|
||||
if (std::find(outputs.begin(), outputs.end(), input) == outputs.end())
|
||||
{
|
||||
outputs.push_back(input);
|
||||
if (input.DynamicAxes().size() == 2)
|
||||
scanoutputs[l].push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}, nestedSearchInsideBlockFunction);
|
||||
|
||||
// a corner case: if root src is in the loop body, it shall be an output as well.
|
||||
for (int l = 0; l < loops.size(); l++)
|
||||
{
|
||||
const StrongComponent<FunctionPtr> &loop = loops[l];
|
||||
if (std::find(loop.Nodes().begin(), loop.Nodes().end(), root) != loop.Nodes().end())
|
||||
for (auto output : root->Outputs())
|
||||
if (std::find(scanoutputs[l].begin(), scanoutputs[l].end(), output) == scanoutputs[l].end())
|
||||
scanoutputs[l].push_back(output);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<FunctionPtr>> loopstepfunctions;
|
||||
std::vector<std::vector<Variable>> loopStates;
|
||||
std::vector<bool> filterOutBlockRNNs(loops.size(), false);
|
||||
loopstepfunctions.resize(loops.size());
|
||||
for (int l = 0; l < loops.size(); l++)
|
||||
{
|
||||
const StrongComponent<FunctionPtr> &loop = loops[l];
|
||||
const std::vector<FunctionPtr> &nodes = loop.Nodes();
|
||||
for (auto &f : nodes)
|
||||
{
|
||||
if (f->OpName() == L"PastValue" || f->OpName() == L"FutureValue")
|
||||
loopstepfunctions[l].push_back(f);
|
||||
else if (f->OpName() != L"LSTM" && f->OpName() != L"GRU" && f->OpName() != L"RNNStep")
|
||||
filterOutBlockRNNs[l] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < loops.size(); l++)
|
||||
{
|
||||
if (filterOutBlockRNNs[l])
|
||||
{
|
||||
ScanLoop scanLoop(loopinputs[l], loopoutputs[l], scaninputs[l], scanoutputs[l], loops[l].Nodes());
|
||||
scanLoops.push_back(scanLoop);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2170,6 +2170,7 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
|
|||
vector<pair<Variable, Variable>> argsMap{ pair<Variable, Variable>{operands[0], inputs[0]} };
|
||||
for (int i = 1; i < 5; ++i)
|
||||
{
|
||||
// TODO: this does not work if mean/var inputs are not constant/parameters.
|
||||
argsMap.push_back(pair<Variable, Variable>{ operands[i], inputs[0].GetDataType() == DataType::Float16 ? Utils::ConvertVariableType<float16, float>(inputs[i], true) : inputs[i]});
|
||||
}
|
||||
|
||||
|
|
|
@ -231,6 +231,9 @@ namespace ONNX
|
|||
{ L"StableSigmoid", { {
|
||||
{ L"StableSigmoid", "Sigmoid" },
|
||||
} } },
|
||||
{ L"Sigmoid", { {
|
||||
{ L"Sigmoid", "Sigmoid" },
|
||||
} } },
|
||||
{ L"ElementMax", { {
|
||||
{ L"ElementMax", "Max" },
|
||||
} } },
|
||||
|
@ -362,6 +365,11 @@ namespace ONNX
|
|||
{ L"axis", "axes" },
|
||||
{ L"keepdims", "keepdims" },
|
||||
} } },
|
||||
{ L"Sequence::ReduceElements",{ {
|
||||
{ L"Sequence::ReduceElements", "ReduceSum" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
|
||||
// From tensor
|
||||
{ L"Cast", { {
|
||||
|
@ -517,7 +525,13 @@ namespace ONNX
|
|||
{
|
||||
return opName == "LSTM" || opName == "GRU" || opName == "RNN" || opName == "RNNStep";
|
||||
}
|
||||
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
|
||||
|
||||
bool Operators::IsSequenceBlockOp(const std::string &opName)
|
||||
{
|
||||
return opName == "Sequence::ReduceElements" || opName == "Sequence::BroadcastAs";
|
||||
}
|
||||
|
||||
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
|
||||
{ L"Clip",{ 1, 2 } },
|
||||
{ L"ELU",{ 0, 1 } },
|
||||
{ L"LeakyReLU",{ 0, 1 } },
|
||||
|
|
|
@ -152,6 +152,7 @@ public:
|
|||
|
||||
static bool IsLoopOp(const std::string &opName);
|
||||
static bool IsRNNOp(const std::string &opName);
|
||||
static bool IsSequenceBlockOp(const std::string &opName);
|
||||
|
||||
private:
|
||||
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
|
||||
|
|
|
@ -703,76 +703,88 @@ void GraphBase::ReverseDFSFrom(const std::vector<const Node*>& from,
|
|||
|
||||
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
|
||||
Status GraphBase::CheckIsAcyclic(std::vector<NodeIndex>& nodes_in_topological_order) const {
|
||||
nodes_in_topological_order.clear();
|
||||
// nodes that have been processed and added to nodes_in_topological_order.
|
||||
std::unordered_set<NodeIndex> visited_nodes;
|
||||
std::unordered_set<NodeIndex> ancestor_nodes;
|
||||
// tracks nodes whose child nodes have been processed.
|
||||
std::unordered_set<NodeIndex> children_visited_nodes;
|
||||
std::stack<NodeIndex> stack;
|
||||
stack.push(sink_node_index_);
|
||||
nodes_in_topological_order.clear();
|
||||
|
||||
while (!stack.empty()) {
|
||||
const NodeIndex current = stack.top();
|
||||
stack.pop();
|
||||
// nodes that have been processed and added to nodes_in_topological_order.
|
||||
std::unordered_set<NodeIndex> processed_nodes;
|
||||
std::unordered_set<NodeIndex> output_nodes;
|
||||
std::unordered_set<NodeIndex> nodes_added_for_processing;
|
||||
std::stack<NodeIndex> stack;
|
||||
|
||||
if (visited_nodes.end() != visited_nodes.find(current)) {
|
||||
// The node has been visited before
|
||||
continue;
|
||||
// push the top level nodes into nodes_in_topological_order in the order they were added
|
||||
// to ensure that is consistent.
|
||||
auto& nodes_in_original_order = Nodes();
|
||||
for (GraphNodes::ConstNodeIterator it = nodes_in_original_order.cbegin(); it != nodes_in_original_order.cend(); ++it)
|
||||
{
|
||||
const Node& node = *it;
|
||||
auto index = node.Index();
|
||||
|
||||
// find the top level nodes in the graph
|
||||
if (node.GetRelationships().input_edges.size() == 0 && index != sink_node_index_) {
|
||||
// add to the topological list, and ensure we skip these nodes when walking the graph
|
||||
nodes_in_topological_order.push_back(index);
|
||||
processed_nodes.insert(index);
|
||||
|
||||
// mark this as added as we've fully processed it and don't need to do it again later
|
||||
nodes_added_for_processing.insert(index);
|
||||
}
|
||||
}
|
||||
|
||||
if (children_visited_nodes.end() != children_visited_nodes.find(current)) {
|
||||
// children are done so we mark this one complete.
|
||||
visited_nodes.insert(current);
|
||||
nodes_in_topological_order.push_back(current);
|
||||
ancestor_nodes.erase(current);
|
||||
continue;
|
||||
// start at the bottom and work our way up the graph
|
||||
stack.push(sink_node_index_);
|
||||
|
||||
while (!stack.empty()) {
|
||||
const NodeIndex current = stack.top();
|
||||
stack.pop();
|
||||
|
||||
if (processed_nodes.find(current) != processed_nodes.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) {
|
||||
// we popped the stack and are back to a node that was added previously,
|
||||
// so we know all the upstream nodes from it have been fully processed,
|
||||
nodes_in_topological_order.push_back(current);
|
||||
processed_nodes.insert(current);
|
||||
output_nodes.erase(current);
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node* node = GetNode(current);
|
||||
if (!node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
stack.push(current);
|
||||
output_nodes.insert(current);
|
||||
|
||||
// push the node's inputs onto the stack in reverse order so that when we finish processing each one
|
||||
// and pop them from the stack they get added to nodes_in_topological_order in their original order
|
||||
for (auto iter = std::make_reverse_iterator(node->InputNodesEnd()),
|
||||
end = std::make_reverse_iterator(node->InputNodesBegin());
|
||||
iter != end; ++iter) {
|
||||
const NodeIndex idx = (*iter)->Index();
|
||||
if (output_nodes.find(idx) != output_nodes.end()) {
|
||||
Status status(LOTUS, FAIL, "Error: the graph is not acyclic.");
|
||||
return status;
|
||||
}
|
||||
|
||||
// avoid re-processing nodes
|
||||
if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) {
|
||||
stack.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
nodes_added_for_processing.insert(current);
|
||||
}
|
||||
|
||||
const Node* node = GetNode(current);
|
||||
if (!node) {
|
||||
continue;
|
||||
if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order.size()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (node->InputNodesBegin() == node->InputNodesEnd()) {
|
||||
// no children
|
||||
children_visited_nodes.insert(current);
|
||||
visited_nodes.insert(current);
|
||||
nodes_in_topological_order.push_back(current);
|
||||
ancestor_nodes.erase(current);
|
||||
continue;
|
||||
else {
|
||||
return Status(LOTUS, FAIL, "Error: the graph is not acyclic.");
|
||||
}
|
||||
|
||||
stack.push(current);
|
||||
|
||||
// mark as children done. by the time the node is popped off the stack again,
|
||||
// its children will have been processed
|
||||
children_visited_nodes.insert(current);
|
||||
|
||||
ancestor_nodes.insert(current);
|
||||
|
||||
// check children
|
||||
for (auto iter = node->InputNodesBegin(); iter != node->InputNodesEnd(); ++iter) {
|
||||
const NodeIndex idx = (*iter)->Index();
|
||||
if (ancestor_nodes.end() != ancestor_nodes.find(idx)) {
|
||||
Status status(LOTUS, FAIL, "Error: the graph is not acyclic.");
|
||||
return status;
|
||||
}
|
||||
|
||||
// avoid re-processing nodes
|
||||
if (children_visited_nodes.end() == children_visited_nodes.find(idx)) {
|
||||
stack.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order.size()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status(LOTUS, FAIL, "Error: the graph is not acyclic.");
|
||||
}
|
||||
}
|
||||
|
||||
bool FullyDefinedType(const TypeProto& type_proto) {
|
||||
switch (type_proto.value_case()) {
|
||||
case TypeProto::kTensorType: {
|
||||
|
@ -1559,112 +1571,34 @@ void Graph::CleanUnusedInitializers() {
|
|||
}
|
||||
}
|
||||
|
||||
GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...)
|
||||
Status Graph::SetGraphInputsOutputs() {
|
||||
// Reset graphInputs/graphOutputs/valueInfo state.
|
||||
auto& graph_inputs = MutableInputs();
|
||||
auto& graph_outputs = MutableOutputs();
|
||||
|
||||
graph_inputs.clear();
|
||||
graph_outputs.clear();
|
||||
value_info_.clear();
|
||||
|
||||
// Flag indicates that this graph is loaded from model file.
|
||||
// If it's true, then graph inputs and outputs will keep the same
|
||||
// as what are specified in the model, otherwise, graph inputs
|
||||
// and outputs will be inferred.
|
||||
const bool loaded_from_model_file = graph_proto_->input_size() != 0 ||
|
||||
graph_proto_->output_size() != 0 ||
|
||||
graph_proto_->value_info_size() != 0;
|
||||
|
||||
std::unordered_set<std::string> added_input_names{};
|
||||
|
||||
if (loaded_from_model_file) {
|
||||
// Collect all graph inputs/outputs specified in original graph proto
|
||||
std::unordered_set<std::string> specified_graph_inputs;
|
||||
std::unordered_set<std::string> specified_graph_outputs;
|
||||
std::unordered_set<std::string> specified_graph_value_info;
|
||||
std::unordered_set<std::string> specified_initializers;
|
||||
|
||||
for (auto& graph_output : graph_proto_->output()) {
|
||||
specified_graph_outputs.insert(graph_output.name());
|
||||
}
|
||||
|
||||
for (auto& graph_value_info : graph_proto_->value_info()) {
|
||||
specified_graph_value_info.insert(graph_value_info.name());
|
||||
}
|
||||
|
||||
for (auto& initializer : graph_proto_->initializer()) {
|
||||
specified_initializers.insert(initializer.name());
|
||||
}
|
||||
|
||||
// only add non-initializer to inputs
|
||||
for (auto& graph_input : graph_proto_->input()) {
|
||||
if (specified_initializers.find(graph_input.name()) == specified_initializers.end())
|
||||
specified_graph_inputs.insert(graph_input.name());
|
||||
void AssignNodeArgsIfChanged(const std::vector<const NodeArg*> new_graph_inputs, std::vector<const NodeArg*> &graph_inputs)
|
||||
{
|
||||
if (true || graph_inputs.size() != new_graph_inputs.size() ||
|
||||
std::any_of(graph_inputs.begin(), graph_inputs.end(), [new_graph_inputs](const NodeArg *input_arg)
|
||||
{
|
||||
for (auto new_input_arg : new_graph_inputs)
|
||||
{
|
||||
if (input_arg->Name() == new_input_arg->Name())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}))
|
||||
{
|
||||
graph_inputs = new_graph_inputs;
|
||||
}
|
||||
}
|
||||
|
||||
void Graph::ComputeGraphInputsOutputsAndResetValues(std::vector<const NodeArg*> &new_graph_inputs,
|
||||
std::vector<const NodeArg*> &new_graph_outputs)
|
||||
{
|
||||
value_info_.clear();
|
||||
std::unordered_set<std::string> added_input_names{};
|
||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||
for (const auto& node : Nodes()) {
|
||||
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||
if (specified_graph_outputs.erase(output_def->Name()) >= 1) {
|
||||
graph_outputs.push_back(output_def);
|
||||
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||
if (output_def->Exists())
|
||||
output_name_to_node_arg.insert({ output_def->Name(), output_def });
|
||||
}
|
||||
output_name_to_node_arg.insert({output_def->Name(), output_def});
|
||||
}
|
||||
}
|
||||
// for any outputs using initializer, add to graph_outputs
|
||||
if (specified_graph_outputs.size() > 0) {
|
||||
for (const auto& name : specified_initializers) {
|
||||
if (specified_graph_outputs.erase(name) >= 1) {
|
||||
graph_outputs.push_back(FindNodeArg(name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!specified_graph_outputs.empty()) {
|
||||
std::string missing_list;
|
||||
for (auto& name : specified_graph_outputs)
|
||||
missing_list += name + " ";
|
||||
return Status(LOTUS, FAIL, "Some graph outputs do not exist in the graph. (" + missing_list + ")");
|
||||
}
|
||||
|
||||
for (const auto& node : Nodes()) {
|
||||
// Go thru all node's inputs.
|
||||
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||
if (!input_arg->Exists()) {
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
|
||||
if (specified_graph_inputs.end() != specified_graph_inputs.find(input_arg->Name())) {
|
||||
if (added_input_names.insert(input_arg->Name()).second) {
|
||||
// The node input is specified as graph input.
|
||||
graph_inputs.push_back(input_arg);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
|
||||
if (output_name_to_node_arg.end() == output_arg_iter &&
|
||||
specified_initializers.end() == specified_initializers.find(input_arg->Name())) {
|
||||
// The node input is not specified as graph input,
|
||||
// and it's not fed by another node neither.
|
||||
return Status(LOTUS, FAIL, "Node input (" + input_arg->Name() + ") should be a graph input or initializer.");
|
||||
}
|
||||
|
||||
if (specified_graph_value_info.erase(input_arg->Name()) >= 1) {
|
||||
value_info_.push_back(input_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||
for (const auto& node : Nodes()) {
|
||||
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||
if (output_def->Exists())
|
||||
output_name_to_node_arg.insert({output_def->Name(), output_def});
|
||||
}
|
||||
}
|
||||
|
||||
// Init graph output args with all node output args.
|
||||
|
@ -1672,40 +1606,224 @@ Status Graph::SetGraphInputsOutputs() {
|
|||
|
||||
std::unordered_set<Node*> inner_nodes;
|
||||
for (const auto& node : Nodes()) {
|
||||
// Go thru all node's inputs.
|
||||
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||
if (!input_arg->Exists()) {
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
// Go thru all node's inputs.
|
||||
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||
if (!input_arg->Exists()) {
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
|
||||
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
|
||||
if (output_name_to_node_arg.end() == output_arg_iter) {
|
||||
// This input arg should be fed when running evaluation.
|
||||
// it should be a graph input.
|
||||
const std::string& name = input_arg->Name();
|
||||
if (added_input_names.end() == added_input_names.find(name)) {
|
||||
// This graph input has not been added into <graph_inputs_>.
|
||||
if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
|
||||
graph_inputs.push_back(input_arg);
|
||||
added_input_names.insert(input_arg->Name());
|
||||
}
|
||||
} else if (graph_output_args.erase(output_arg_iter->first) >= 1) {
|
||||
// Remove the output arg name from graph outputs since it's
|
||||
// the input of another node, which we call it intermediate result
|
||||
// and store it in <m_valueinfo>.
|
||||
value_info_.push_back(input_arg);
|
||||
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
|
||||
if (output_name_to_node_arg.end() == output_arg_iter) {
|
||||
// This input arg should be fed when running evaluation.
|
||||
// it should be a graph input.
|
||||
const std::string& name = input_arg->Name();
|
||||
if (added_input_names.end() == added_input_names.find(name)) {
|
||||
// This graph input has not been added into <graph_inputs_>.
|
||||
if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
|
||||
new_graph_inputs.push_back(input_arg);
|
||||
added_input_names.insert(input_arg->Name());
|
||||
}
|
||||
}
|
||||
else if (graph_output_args.erase(output_arg_iter->first) >= 1) {
|
||||
// Remove the output arg name from graph outputs since it's
|
||||
// the input of another node, which we call it intermediate result
|
||||
// and store it in <m_valueinfo>.
|
||||
value_info_.push_back(input_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set graph outputs.
|
||||
for (auto& output_arg : graph_output_args) {
|
||||
graph_outputs.push_back(output_arg.second);
|
||||
auto nodes = Nodes();
|
||||
std::vector<const NodeArg*> sorted_new_graph_outputs;
|
||||
for (GraphNodes::ConstNodeIterator it = nodes.cbegin(); it != nodes.cend(); ++it)
|
||||
{
|
||||
const Node &node = *it;
|
||||
auto nodeOutputNodeArgs = node.OutputDefs();
|
||||
for (std::unordered_map<std::string, const NodeArg*>::iterator itPair = graph_output_args.begin();
|
||||
itPair != graph_output_args.end(); ++itPair)
|
||||
{
|
||||
const NodeArg* outputNodeArg = itPair->second;
|
||||
for (int i = 0; i < nodeOutputNodeArgs.size(); i++)
|
||||
{
|
||||
if (nodeOutputNodeArgs[i]->Name() == outputNodeArg->Name())
|
||||
{
|
||||
if (std::find_if(new_graph_outputs.begin(), new_graph_outputs.end(), [outputNodeArg](const NodeArg *nodeArg)
|
||||
{
|
||||
return outputNodeArg->Name() == nodeArg->Name();
|
||||
}) == new_graph_outputs.end())
|
||||
new_graph_outputs.push_back(outputNodeArg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...)
|
||||
Status Graph::SetGraphInputsOutputs() {
|
||||
// Reset graphInputs/graphOutputs/valueInfo state.
|
||||
auto& graph_inputs = MutableInputs();
|
||||
auto& graph_outputs = MutableOutputs();
|
||||
|
||||
graph_inputs.clear();
|
||||
graph_outputs.clear();
|
||||
value_info_.clear();
|
||||
|
||||
// Flag indicates that this graph is loaded from model file.
|
||||
// If it's true, then graph inputs and outputs will keep the same
|
||||
// as what are specified in the model, otherwise, graph inputs
|
||||
// and outputs will be inferred.
|
||||
const bool loaded_from_model_file = graph_proto_->input_size() != 0 ||
|
||||
graph_proto_->output_size() != 0 ||
|
||||
graph_proto_->value_info_size() != 0;
|
||||
|
||||
std::unordered_set<std::string> added_input_names{};
|
||||
|
||||
if (loaded_from_model_file) {
|
||||
// Collect all graph inputs/outputs specified in original graph proto
|
||||
std::unordered_set<std::string> specified_graph_inputs;
|
||||
std::unordered_set<std::string> specified_graph_outputs;
|
||||
std::unordered_set<std::string> specified_graph_value_info;
|
||||
std::unordered_set<std::string> specified_initializers;
|
||||
|
||||
for (auto& graph_output : graph_proto_->output()) {
|
||||
specified_graph_outputs.insert(graph_output.name());
|
||||
}
|
||||
|
||||
for (auto& graph_value_info : graph_proto_->value_info()) {
|
||||
specified_graph_value_info.insert(graph_value_info.name());
|
||||
}
|
||||
|
||||
for (auto& initializer : graph_proto_->initializer()) {
|
||||
specified_initializers.insert(initializer.name());
|
||||
}
|
||||
|
||||
// only add non-initializer to inputs
|
||||
for (auto& graph_input : graph_proto_->input()) {
|
||||
if (specified_initializers.find(graph_input.name()) == specified_initializers.end())
|
||||
specified_graph_inputs.insert(graph_input.name());
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||
|
||||
// add non-initializer outputs
|
||||
for (const auto& node : Nodes()) {
|
||||
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||
IGNORE_RETURN_VALUE(specified_graph_outputs.erase(output_def->Name()));
|
||||
output_name_to_node_arg.insert({ output_def->Name(), output_def });
|
||||
}
|
||||
}
|
||||
|
||||
// add any outputs using initializer
|
||||
if (specified_graph_outputs.size() > 0) {
|
||||
for (const auto& name : specified_initializers) {
|
||||
IGNORE_RETURN_VALUE(specified_graph_outputs.erase(name));
|
||||
output_name_to_node_arg.insert({ name, FindNodeArg(name) });
|
||||
}
|
||||
}
|
||||
|
||||
if (!specified_graph_outputs.empty()) {
|
||||
std::string missing_list;
|
||||
for (auto& name : specified_graph_outputs)
|
||||
missing_list += name + " ";
|
||||
return Status(LOTUS, FAIL, "Some graph outputs do not exist in the graph. (" + missing_list + ")");
|
||||
}
|
||||
|
||||
// preserve order of outputs
|
||||
for (auto& graph_output : graph_proto_->output()) {
|
||||
graph_outputs.push_back(output_name_to_node_arg.at(graph_output.name()));
|
||||
}
|
||||
|
||||
for (const auto& node : Nodes()) {
|
||||
// Go thru all node's inputs.
|
||||
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||
if (!input_arg->Exists()) {
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
|
||||
if (specified_graph_inputs.end() != specified_graph_inputs.find(input_arg->Name())) {
|
||||
if (added_input_names.insert(input_arg->Name()).second) {
|
||||
// The node input is specified as graph input.
|
||||
graph_inputs.push_back(input_arg);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
|
||||
if (output_name_to_node_arg.end() == output_arg_iter &&
|
||||
specified_initializers.end() == specified_initializers.find(input_arg->Name())) {
|
||||
// The node input is not specified as graph input,
|
||||
// and it's not fed by another node neither.
|
||||
return Status(LOTUS, FAIL, "Node input (" + input_arg->Name() + ") should be a graph input or initializer.");
|
||||
}
|
||||
|
||||
if (specified_graph_value_info.erase(input_arg->Name()) >= 1) {
|
||||
value_info_.push_back(input_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||
std::vector<std::string> ordered_output_names;
|
||||
|
||||
for (const auto& node : Nodes()) {
|
||||
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||
if (output_def->Exists()) {
|
||||
output_name_to_node_arg.insert({ output_def->Name(), output_def });
|
||||
ordered_output_names.push_back(output_def->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Init graph output args with copy of all node output args.
|
||||
auto graph_output_args = output_name_to_node_arg;
|
||||
|
||||
std::unordered_set<Node*> inner_nodes;
|
||||
for (const auto& node : Nodes()) {
|
||||
// Go thru all node's inputs.
|
||||
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||
if (!input_arg->Exists()) {
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
|
||||
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
|
||||
if (output_name_to_node_arg.end() == output_arg_iter) {
|
||||
// This input arg should be fed when running evaluation.
|
||||
// it should be a graph input.
|
||||
const std::string& name = input_arg->Name();
|
||||
if (added_input_names.end() == added_input_names.find(name)) {
|
||||
// This graph input has not been added into <graph_inputs_>.
|
||||
if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) {
|
||||
graph_inputs.push_back(input_arg);
|
||||
}
|
||||
|
||||
added_input_names.insert(input_arg->Name());
|
||||
}
|
||||
}
|
||||
else if (graph_output_args.erase(output_arg_iter->first) >= 1) {
|
||||
// Remove the output arg name from graph outputs since it's
|
||||
// the input of this node, which we call it intermediate result
|
||||
// and store it in <m_valueinfo>.
|
||||
value_info_.push_back(input_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set graph outputs
|
||||
for (auto& name : ordered_output_names) {
|
||||
auto end = graph_output_args.end();
|
||||
auto graph_output = graph_output_args.find(name);
|
||||
if (graph_output != end) {
|
||||
graph_outputs.push_back(graph_output->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool GraphBase::IsSourceNode(NodeIndex index) const noexcept {
|
||||
|
|
|
@ -99,6 +99,9 @@ class Graph : public GraphBase {
|
|||
::onnxruntime::common::Status VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topological_order,
|
||||
const std::unordered_map<std::string, Node*>& output_args);
|
||||
|
||||
void ComputeGraphInputsOutputsAndResetValues(std::vector<const NodeArg*> &new_graph_inputs,
|
||||
std::vector<const NodeArg*> &new_graph_outputs);
|
||||
|
||||
// Set graph inputs/outputs when resolving a graph..
|
||||
::onnxruntime::common::Status SetGraphInputsOutputs();
|
||||
|
||||
|
|
|
@ -2796,11 +2796,14 @@ public:
|
|||
|
||||
const auto& inputLayout = Input(DATA)->GetSampleLayout();
|
||||
|
||||
// running statistics inputs must be learnable parameters, since we update them directly here
|
||||
for (size_t i = RUN_MEAN; i < GetNumInputs(); i++)
|
||||
//if (!Input(i)->Is<LearnableParameter<ElemType>>()) // somehow this does not compile on gcc (works on VS)
|
||||
if (!dynamic_cast<LearnableParameter<StatType>*>(this->template TypedInput<StatType>(i).get()))
|
||||
InvalidArgument("%ls: Inputs [%d..%d] must be learnable parameters.", NodeDescription().c_str(), (int)RUN_MEAN, (int)GetNumInputs());
|
||||
if (Environment().IsTraining())
|
||||
{
|
||||
// running statistics inputs must be learnable parameters, since we update them directly here
|
||||
for (size_t i = RUN_MEAN; i < GetNumInputs(); i++)
|
||||
//if (!Input(i)->Is<LearnableParameter<ElemType>>()) // somehow this does not compile on gcc (works on VS)
|
||||
if (!dynamic_cast<LearnableParameter<StatType>*>(this->template TypedInput<StatType>(i).get()))
|
||||
InvalidArgument("%ls: Inputs [%d..%d] must be learnable parameters.", NodeDescription().c_str(), (int)RUN_MEAN, (int)GetNumInputs());
|
||||
}
|
||||
|
||||
// infer dimensions of learnable parameters
|
||||
// BUGBUG: Parameter dimensions are totally wrong. E.g. a valid spatial bias for [15 x 15 x 32] is currently [32 x 1].
|
||||
|
|
|
@ -143,6 +143,18 @@ def verify_one_input(model, data, tmpdir, name, device=None, loaded_model=None,
|
|||
verify_node_names(model, loaded_model)
|
||||
return loaded_model
|
||||
|
||||
def run_model(model, data, device=None):
|
||||
feed = {}
|
||||
if len(model.arguments) == 1:
|
||||
feed[model.arguments[0]] = data
|
||||
elif len(model.arguments) > 1:
|
||||
assert len(model.arguments) == len(data)
|
||||
for i in range(len(model.arguments)):
|
||||
feed[model.arguments[i]] = data[i]
|
||||
|
||||
o = model.eval(feed, device=device)
|
||||
return o
|
||||
|
||||
def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=None):
|
||||
# data here is reference to the outside data object. create deepcopy to avoid changing the outside data since it might get reused.
|
||||
data = deepcopy(data)
|
||||
|
@ -152,17 +164,20 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
|
|||
if is_list_of_sparse(data):
|
||||
dataOnnx = transpose_dynamic_axis(sparse_to_dense(data))
|
||||
else:
|
||||
dataOnnx = transpose_dynamic_axis(data)
|
||||
if (type(data) == list):
|
||||
dataOnnx = []
|
||||
for i in range(0, len(data)):
|
||||
if (model.arguments[i].has_sequence_axis()):
|
||||
dataOnnx.append(transpose_dynamic_axis(data[i]))
|
||||
else:
|
||||
dataOnnx.append(data[i])
|
||||
else:
|
||||
dataOnnx = transpose_dynamic_axis(data)
|
||||
|
||||
loaded_model, onnx_model, test_model_path, test_data_path = create_and_populate_onnx_test_case_with_model_conversion(model, tmpdir, name, loaded_model)
|
||||
|
||||
if device:
|
||||
o0 = model.eval({model.arguments[0]:data}, device=device)
|
||||
o1 = loaded_model.eval({loaded_model.arguments[0]:dataOnnx}, device=device)
|
||||
else:
|
||||
o0 = model.eval({model.arguments[0]:data})
|
||||
o1 = loaded_model.eval({loaded_model.arguments[0]:dataOnnx})
|
||||
|
||||
o0 = run_model(model, data, device=device)
|
||||
o1 = run_model(loaded_model, dataOnnx, device=device)
|
||||
|
||||
## if there is a sequence axis in the output, it must be swapped with batch axis
|
||||
## to match the original CNTK model's output
|
||||
|
@ -391,16 +406,23 @@ def verify_BN(x, init_scale, init_bias, mean, var, epsilon, spatial, tmpdir, dty
|
|||
epsilon=epsilon)
|
||||
|
||||
loaded_model = None
|
||||
test_base_name = 'Spatial' if spatial else ''
|
||||
test_base_name = test_base_name + ('BatchNormalization_float16' if dtype==np.float16 else 'BatchNormalization_float32')
|
||||
|
||||
for i in range(len(x)):
|
||||
if dtype==np.float16:
|
||||
loaded_model = verify_one_input(op_node, x[i], tmpdir, 'BatchNormalization_float16' + str(i), loaded_model=loaded_model, rtol = 1e-03, atol = 1e-03)
|
||||
loaded_model = verify_one_input(op_node, x[i], tmpdir, test_base_name + str(i), loaded_model=loaded_model, rtol = 1e-03, atol = 1e-03)
|
||||
else:
|
||||
loaded_model = verify_one_input(op_node, x[i], tmpdir, 'BatchNormalization_float32' + str(i), loaded_model=loaded_model)
|
||||
|
||||
loaded_model = verify_one_input(op_node, x[i], tmpdir, test_base_name + str(i), loaded_model=loaded_model)
|
||||
|
||||
non_spatial_float16_skip_message = str('Test is skipped with float16 data because CNTK ONNX importer in float16 case assumes mean/var inputs being constant.'
|
||||
'this is not always true because in CNTK non-spatial case mean/var may need to be reshaped before pass to the BN function.'
|
||||
'In general import of BatchNormalization(float16) need to be fixed to take any input as mean/var, etc.')
|
||||
# Case 1 - Non-Spatial BN with More > 1 batches
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_BatchNormalization(tmpdir, dtype):
|
||||
if dtype == np.float16:
|
||||
pytest.skip(non_spatial_float16_skip_message)
|
||||
sample = [ # 5 samples having 4 classes
|
||||
[1, 1, 2, 3],
|
||||
[0, 0, 0, 0],
|
||||
|
@ -421,6 +443,7 @@ def test_BatchNormalization(tmpdir, dtype):
|
|||
# Case 2 - Spatial BN with More > 1 batches
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_SpatialBatchNormalization(tmpdir, dtype):
|
||||
np.random.seed(0)
|
||||
x = np.random.randn(2, 3, 4, 5).astype(dtype)
|
||||
scale = np.random.randn(3).astype(np.float32)
|
||||
bias = np.random.randn(3).astype(np.float32)
|
||||
|
@ -1131,6 +1154,103 @@ def test_MatMul_nd_2inputs_2(tmpdir, dtype):
|
|||
model = C.times(x, y)
|
||||
verify_two_input(model, data0, data1, tmpdir, 'MatMul_n_3')
|
||||
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_CNTK_Times_To_ONNX_MatMul(tmpdir, dtype):
|
||||
def generate_matmul_data(input_variable, batch_size, sequence_size):
|
||||
np.random.seed(0)
|
||||
data_shape = ()
|
||||
if input_variable.has_batch_axis():
|
||||
data_shape = (*data_shape, batch_size)
|
||||
if input_variable.has_sequence_axis():
|
||||
data_shape = (*data_shape, sequence_size)
|
||||
data_shape = (*data_shape, *input_variable.shape)
|
||||
print(data_shape)
|
||||
data = np.random.randn(*data_shape).astype(np.float32)
|
||||
return data
|
||||
|
||||
batch_size = 1
|
||||
sequence_length = 3
|
||||
input1_shape = (2, 3, 4)
|
||||
input2_shape = (3, 4, 5, 6)
|
||||
output_rank = 2
|
||||
|
||||
## data_x_data
|
||||
x = C.input_variable(input1_shape, dynamic_axes = [])
|
||||
y = C.input_variable(input2_shape, dynamic_axes = [])
|
||||
model = C.times(x, y, output_rank = output_rank)
|
||||
data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
verify_two_input(model, data0, data1, tmpdir, 'times_data_x_data')
|
||||
|
||||
###batch_x_data
|
||||
x = C.input_variable(input1_shape, name = "x")
|
||||
y = C.input_variable(input2_shape, dynamic_axes = [], name = "y")
|
||||
model = C.times(x, y, output_rank = output_rank)
|
||||
data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
verify_two_input(model, data0, data1, tmpdir, 'batch_x_data')
|
||||
|
||||
## data_x_batch
|
||||
x = C.input_variable(input1_shape, dynamic_axes = [])
|
||||
y = C.input_variable(input2_shape)
|
||||
model = C.times(x, y, output_rank = output_rank)
|
||||
data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
verify_two_input(model, data0, data1, tmpdir, 'data_x_batch')
|
||||
|
||||
## batch_x_batch
|
||||
x = C.input_variable(input1_shape)
|
||||
y = C.input_variable(input2_shape)
|
||||
model = C.times(x, y, output_rank = output_rank)
|
||||
data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
verify_two_input(model, data0, data1, tmpdir, 'batch_x_batch')
|
||||
|
||||
### sequence_x_data
|
||||
# TODO: ONNX importer cannot handle sequence and batch axes both being free diemention static axis
|
||||
#x = C.sequence.input_variable(input1_shape)
|
||||
#y = C.input_variable(input2_shape, dynamic_axes = [])
|
||||
#model = C.times(x, y, output_rank = output_rank)
|
||||
#data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
#data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
#verify_sequence_model(model, [data0, data1], tmpdir, 'sequence_x_data')
|
||||
|
||||
### data_x_sequence
|
||||
#TODO: ONNX importer cannot handle sequence and batch axes both being free diemention static axis
|
||||
#x = C.input_variable(input1_shape, dynamic_axes = [])
|
||||
#y = C.sequence.input_variable(input2_shape)
|
||||
#model = C.times(x, y, output_rank = output_rank)
|
||||
#data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
#data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
#verify_sequence_model(model, [data0, data1], tmpdir, 'data_x_sequence')
|
||||
|
||||
## sequence_x_sequence
|
||||
# TODO: ONNX importer cannot handle sequence and batch axes both being free diemention static axis
|
||||
#x = C.sequence.input_variable(input1_shape)
|
||||
#y = C.sequence.input_variable(input2_shape)
|
||||
#model = C.times(x, y, output_rank = output_rank)
|
||||
#data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
#data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
#verify_sequence_model(model, [data0, data1], tmpdir, 'sequence_x_sequence')
|
||||
|
||||
## sequence_x_batch
|
||||
# TODO: ONNX importer cannot handle sequence and batch axes both being free diemention static axis
|
||||
#x = C.sequence.input_variable(input1_shape)
|
||||
#y = C.input_variable(input2_shape)
|
||||
#model = C.times(x, y, output_rank = output_rank)
|
||||
#data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
#data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
#verify_sequence_model(model, [data0, data1], tmpdir, 'sequence_x_batch')
|
||||
|
||||
## batch_x_sequence
|
||||
# TODO: ONNX importer cannot handle sequence and batch axes both being free diemention static axis
|
||||
#x = C.input_variable(input1_shape)
|
||||
#y = C.sequence.input_variable(input2_shape)
|
||||
#model = C.times(x, y, output_rank = output_rank)
|
||||
#data0 = generate_matmul_data(x, batch_size, sequence_length)
|
||||
#data1 = generate_matmul_data(y, batch_size, sequence_length)
|
||||
#verify_sequence_model(model, [data0, data1], tmpdir, 'batch_x_sequence')
|
||||
|
||||
#Max
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_Max(tmpdir, dtype):
|
||||
|
|
|
@ -81,6 +81,8 @@ def save_cntk_data_as_onnx_tensor(file_path, variable, data, onnx_value_info_pro
|
|||
# swith to onnx shape: (sequence, batch, ...)
|
||||
if is_list_of_sparse(data):
|
||||
data = sparse_to_dense(data)
|
||||
elif type(data)==scipy.sparse.csr.csr_matrix:
|
||||
data = data.todense()
|
||||
|
||||
# compare free_dim indices between variable with onnx_value_info_proto
|
||||
# they are at index 0 and 1.
|
||||
|
|
Загрузка…
Ссылка в новой задаче