support folding onnx>2GB with onnxruntime (#528)
* add Round op
* avoid creating a large constant when converting ONNX models; some other fixes
* port the nofuse flag; fix the weight path of the optimized ONNX model
* fix
Parent: 35e1a76f25
Commit: 56f3ab5c4b

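Note: the first hunk below is what enables folding models larger than 2 GB — when ONNX Runtime saves the optimized model, initializers above a size threshold are redirected into an external ".data" file next to the output, keeping the serialized protobuf under its 2 GB limit. A minimal standalone sketch of the same session setup (file names are hypothetical, not from this commit):

    import os
    import onnxruntime as ort

    model_path = "big_model.onnx"    # hypothetical input
    opt_path = "big_model.opt.onnx"  # hypothetical output

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    so.optimized_model_filepath = opt_path
    # Redirect initializers >= 100 bytes into "<output>.data" so the saved
    # .onnx file itself stays under protobuf's 2 GB serialization limit.
    so.add_session_config_entry(
        "session.optimized_model_external_initializers_file_name",
        os.path.basename(opt_path) + ".data")
    so.add_session_config_entry(
        "session.optimized_model_external_initializers_min_size_in_bytes", "100")

    # Building the session runs graph optimization and writes opt_path (+ .data).
    sess = ort.InferenceSession(model_path, so, providers=["CPUExecutionProvider"])
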
@@ -93,6 +93,10 @@ elif args.graph_optimization_level == 'ORT_ENABLE_EXTENDED':
     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
 if args.optimized_model_filepath != '':
     sess_options.optimized_model_filepath = args.optimized_model_filepath
+    sess_options.add_session_config_entry(
+        "session.optimized_model_external_initializers_file_name", os.path.basename(args.optimized_model_filepath) + ".data"
+    )
+    sess_options.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100")
 
 for k, v in args.symbolic_dims.items():
     sess_options.add_free_dimension_override_by_name(k, int(v))

@@ -84,6 +84,7 @@
 #include "nnfusion/core/operators/op_define/result.hpp"
 #include "nnfusion/core/operators/op_define/reverse.hpp"
 #include "nnfusion/core/operators/op_define/reverse_sequence.hpp"
+#include "nnfusion/core/operators/op_define/round.hpp"
 #include "nnfusion/core/operators/op_define/rsqrt.hpp"
 #include "nnfusion/core/operators/op_define/select.hpp"
 #include "nnfusion/core/operators/op_define/select_and_scatter.hpp"

@@ -81,7 +81,7 @@ LanguageUnit_p cuda::Gather1D::emit_function_body()
     }
 
     lu << "int64_t gather_i = __ldg(indices + indices_i);\n";
-    lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n";
+    lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n";
     lu << "if (gather_i >= " << gather_dim_size << ")\n"
        << "   out[i] = 0;\n"
        << "else\n";

@@ -194,7 +194,7 @@ LanguageUnit_p cuda::Gather1DGrad::emit_function_body()
     }
 
     lu << "int64_t gather_i = __ldg(indices + indices_i);\n";
-    lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n";
+    lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n";
     lu << "if (gather_i < " << gather_dim_size << ")\n";
     lu.block_begin();
     {

@@ -68,6 +68,7 @@ set(SRC
     op_define/result.cpp
     op_define/reverse_sequence.cpp
     op_define/reverse.cpp
+    op_define/round.cpp
     op_define/rsqrt.cpp
     op_define/select_and_scatter.cpp
     op_define/select.cpp

@@ -313,10 +313,15 @@ namespace nnfusion
         {
             config[alias_name + "_dtype"] = "int64";
         }
+        else if (d_type == element::u8)
+        {
+            // hack!!!
+            config[alias_name + "_dtype"] = "int8";
+        }
         else
         {
             NNFUSION_CHECK_FAIL()
-                << "Unhandled type: " << d_type
+                << "Unhandled type for " << input_name << ": " << d_type
                 << ", antares currently supports int8/16/32/64, float16/32/64";
         }
         auto shape = tensor->get_shape();

@@ -31,6 +31,7 @@ static const std::unordered_map<std::string, element_op> ElementOpMap = {
     {"Sin", element_op("sin", "")},
     {"Sinh", element_op("sinh", "")},
     {"Sqrt", element_op("sqrt", "")},
+    {"Round", element_op("round", "x0.call(`round`)")},
     {"Rsqrt", element_op("rsqrt", "")},
     {"Tan", element_op("tan", "")},
     {"Tanh", element_op("tanh", "")},

@@ -196,6 +197,7 @@ REGISTER_ELEM_OP(Relu)
 REGISTER_ELEM_OP(Relu6)
 REGISTER_ELEM_OP(ReluBackprop)
 REGISTER_ELEM_OP(Relu6Backprop)
+REGISTER_ELEM_OP(Round)
 REGISTER_ELEM_OP(Sigmoid)
 REGISTER_ELEM_OP(SigmoidBackprop)
 REGISTER_ELEM_OP(Equal)

@@ -8,17 +8,18 @@
 
 REGISTER_OP(Trilu)
     .infershape([](std::shared_ptr<graph::GNode> curr) -> void {
-        curr->set_output_type_and_shape(0, curr->get_input_element_type(0), curr->get_input_shape(0));
-    })
+        curr->set_output_type_and_shape(
+            0, curr->get_input_element_type(0), curr->get_input_shape(0));
+    })
     .translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
         auto input_shape_0 = curr->get_input_shape(0);
         assert(input_shape_0.size() >= 2);
         std::string k_str = "";
-        if(curr->get_input_size() == 2)
-            k_str = "+ input1[0]";
+        if (curr->get_input_size() == 2)
+            k_str = "+ input1[0]";
         auto op = static_pointer_cast<nnfusion::op::GenericOp>(curr->get_op_ptr());
         auto& cfg = op->localOpConfig.getRoot();
-        bool upper = cfg["upper"].is_null()?true:int64_t(cfg["upper"])!=0;
+        bool upper = cfg["upper"].is_null() ? true : int64_t(cfg["upper"]) != 0;
         auto input_layout = op::create_layout_from_dims(input_shape_0);
         auto dim_a = input_layout[input_layout.size() - 2];
         auto dim_b = input_layout[input_layout.size() - 1];

@@ -28,13 +29,11 @@ REGISTER_OP(Trilu)
             element::Type::nnfusion_element_type_to_dtype_string(curr->get_element_type(), dtype);
         NNFUSION_CHECK(ret);
 
-        std::string condition = upper?dim_b+">="+dim_a+k_str:dim_a+k_str+">="+dim_b;
+        std::string condition = upper ? dim_b + ">=" + dim_a + k_str : dim_a + k_str + ">=" + dim_b;
 
         auto expression = op::create_code_from_template(
-            "@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, const(0).cast(`@dtype@`));", {
-                {"input_layout", join(input_layout)},
-                {"condition", condition},
-                {"dtype", dtype}
-            });
+            "@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, "
+            "const(0).cast(`@dtype@`));",
+            {{"input_layout", join(input_layout)}, {"condition", condition}, {"dtype", dtype}});
         return expression;
     });

@@ -32,7 +32,7 @@ namespace nnfusion
             std::shared_ptr<graph::GNode> fused_node);
             std::string get_fused_ir2() { return fused_op_ir2; };
             std::string get_plan_rule();
-
+            bool get_is_memcpy() { return is_memcpy; }
         protected:
             void assemble_inputs_and_outputs();
 
@@ -41,4 +41,4 @@ namespace nnfusion
             bool is_memcpy;
         };
     }
-}
+}

@@ -0,0 +1,26 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+// Microsoft (c) 2020, NNFusion Team
+
+#include "round.hpp"
+
+using namespace nnfusion::op;
+
+Round::Round()
+    : ElementwiseArithmetic("Round")
+{
+}

@@ -0,0 +1,35 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+// Microsoft (c) 2020, NNFusion Team
+
+#pragma once
+
+#include "nnfusion/core/operators/util/elementwise_arithmetic.hpp"
+
+namespace nnfusion
+{
+    namespace op
+    {
+        /// \brief Elementwise round operation.
+        class Round : public ElementwiseArithmetic
+        {
+        public:
+            /// \brief Constructs a round operation.
+            Round();
+        };
+    }
+}

@@ -879,8 +879,8 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru
                 lu << "Debug(\"" << node_name << ", " << out_name << member_name << "_f32\", "
                    << "fp32tensors, \"" << join(kernel->m_context->input_names) << "\", "
                    << kernel->m_context->outputs[i]->size(false) << ");\n";
-                lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, "
-                   << max_tensor_size <<"));\n";
+                lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, " << max_tensor_size
+                   << "));\n";
             }
             else if (element::get_backend_cstring(
                          kernel->m_context->outputs[i]->get_element_type()) == "float")

@@ -19,6 +19,7 @@ using namespace nnfusion::kernels;
 
 DEFINE_string(ftune_output_file, "", "the output json file path");
 DEFINE_string(ftune_input_file, "", "the input json file path");
+DEFINE_bool(fnofuse, false, "Disable element-wise fusion");
 DEFINE_string(ffusion_skiplist, "", "List of op types that skips in fusion");
 DECLARE_string(fdefault_device);
 
@@ -84,6 +85,14 @@ namespace
         });
         return nodes;
     }
+
+    string ir_add_tag(const string& ir, const string& tag)
+    {
+        if (ir.find("## @:") != string::npos)
+            return ir + "|" + tag;
+        else
+            return ir + "## @: " + tag;
+    }
 }
 
 class RegisterFusionOptimizer

@@ -138,11 +147,13 @@ public:
                 fuse_from_node(tnode, true);
             }
         }
+        inline_lightweighted_ops();
         auto groups = extract_fusion_group();
-        for (auto group : groups)
-        {
-            insert_fuse_group(group);
-        }
+        if (!FLAGS_fnofuse)
+            for (auto group : groups)
+            {
+                insert_fuse_group(group);
+            }
         auto nodes = nlohmann::json().array();
         for (auto& node : find_topo_sort_priority(m_graph))
         {

@@ -151,10 +162,12 @@ public:
             auto str = nnfusion::op::get_translation_v2(node);
             if (skip_ops.count(node->get_op_type()))
             {
-                if (str.find("## @:") != string::npos)
-                    str += "|skip";
-                else
-                    str += "## @: skip";
+                str = ir_add_tag(str, "skip");
             }
+            if (node->get_op_type() == "Fused" &&
+                std::dynamic_pointer_cast<op::Fused>(node->get_op_ptr())->get_is_memcpy())
+            {
+                str = ir_add_tag(str, "memcpy");
+            }
             auto edge = nlohmann::json().array();
             for (auto& e : node->get_in_edges())

@@ -173,7 +186,7 @@ public:
     }
 
 private:
-    vector<shared_ptr<FuseGroup>> extract_fusion_group()
+    vector<shared_ptr<FuseGroup>> extract_fusion_group() const
     {
         unordered_map<int, shared_ptr<FuseGroup>> groups;
         vector<shared_ptr<FuseGroup>> result;

@@ -195,6 +208,85 @@ private:
         return result;
     }
 
+    bool is_lightweighted_op(const shared_ptr<GNode>& node)
+    {
+        auto type = node->get_op_type();
+        if (type == "Slice" || type == "Broadcast")
+            return true;
+        if (type == "Reshape")
+        {
+            auto op = std::dynamic_pointer_cast<op::Reshape>(node->get_op_ptr());
+            auto order = op->get_input_order();
+            if (order.empty())
+                return true;
+
+            bool is_lower_dim_kept = order.back() == order.size() - 1;
+            return is_lower_dim_kept;
+        }
+        return false;
+    }
+
+    void inline_lightweighted_ops()
+    {
+        // Iterate over all independent groups and inline the first group
+        // into the second if:
+        // 1. the first group has exactly one output
+        // 2. the first group contains only lightweight ops
+        // 3. none of the ops is in the skip list
+        unordered_map<int, shared_ptr<FuseGroup>> map;
+        vector<shared_ptr<FuseGroup>> groups;
+        for (auto& tnode : node_list_)
+        {
+            if (tnode->node_->get_op_ptr()->is_tensor_op())
+                continue;
+            if (tnode->group_id_ < 0)
+            {
+                auto f = make_shared<FuseGroup>();
+                f->nodes.insert(tnode->node_);
+                groups.push_back(f);
+            }
+            else
+            {
+                if (!map.count(tnode->group_id_))
+                {
+                    map[tnode->group_id_] = make_shared<FuseGroup>();
+                }
+                map[tnode->group_id_]->nodes.insert(tnode->node_);
+            }
+        }
+        for (auto& kv : map)
+            groups.push_back(kv.second);
+
+        for (auto& group : groups)
+        {
+            bool group_is_lightweighted = true;
+            unordered_set<shared_ptr<GNode>> group_outputs;
+            for (auto& node : group->nodes)
+            {
+                group_is_lightweighted &= is_lightweighted_op(node);
+                for (auto& edge : node->get_out_edges())
+                {
+                    if (!group->nodes.count(edge->get_dst()))
+                        group_outputs.insert(edge->get_dst());
+                }
+            }
+            if (group_outputs.size() == 0)
+                continue;
+            auto& output_node = *group_outputs.begin();
+            auto& tag_output_node = node_map_[output_node];
+            bool op_skip = skip_ops.count(output_node->get_op_type());
+            for (auto& node : group->nodes)
+                op_skip |= skip_ops.count(node->get_op_type());
+            if (group_is_lightweighted && !op_skip && group_outputs.size() == 1)
+            {
+                if (tag_output_node->group_id_ < 0)
+                    tag_output_node->group_id_ = cur_group_++;
+                for (auto& node : group->nodes)
+                    node_map_[node]->group_id_ = tag_output_node->group_id_;
+            }
+        }
+    }
+
     void insert_fuse_group(shared_ptr<FuseGroup> group)
     {
         // get a meaningful name

@@ -453,4 +545,4 @@ bool RegisterFusionPass::run_on_graph(std::shared_ptr<Graph>& graph)
     applier.apply(FLAGS_ftune_input_file);
     NNFUSION_LOG(INFO) << "RegisterFusionPass Done";
     return true;
-}
+}

@@ -68,10 +68,9 @@ namespace nnfusion
                         "(models/pytorch2onnx/increase_precision.py)";
                 string script_path =
                     nnfusion::codegen::get_file_from_templates("onnx/increase_precision.py");
-                string cmd = "python3 " + script_path +
-                             " --file " +
-                             m_path + " --mp_file " + mp_filename;
-
+                string cmd =
+                    "python3 " + script_path + " --file " + m_path + " --mp_file " + mp_filename;
 
                 int sys_ret = system(cmd.c_str());
                 // NNFUSION_LOG(INFO) << "mix precision model path: " << mp_filename;
                 opt_fin = std::ifstream(mp_filename.c_str());

@@ -86,7 +85,7 @@ namespace nnfusion
                         "check error messages reported by the tool, fallback";
                 }
             }
-
+
             string optimized_filename = string(nnfusion::tmpnam(nullptr));
             if (FLAGS_fort_folding)
             {

@@ -112,6 +111,7 @@ namespace nnfusion
                     dim_params_str += "}\'";
                     cmd += dim_params_str;
                 }
+                NNFUSION_LOG(INFO) << "Executing: " << cmd;
                 int sys_ret = system(cmd.c_str());
                 opt_fin = std::ifstream(optimized_filename.c_str());
                 if (sys_ret == 0 && opt_fin.is_open())

@@ -128,11 +128,11 @@ namespace nnfusion
             std::ifstream ifs{m_path, std::ios::in | std::ios::binary};
             NNFUSION_CHECK(ifs.is_open()) << "failure opening file:" + path;
             string model_dir = "";
-            string weight_path = FLAGS_fincrease_precision ? m_path : path;
-            auto pos = weight_path.rfind("/");
+            // string weight_path = FLAGS_fincrease_precision ? m_path : path;
+            auto pos = m_path.rfind("/");
             if (pos != std::string::npos)
             {
-                model_dir = weight_path.substr(0, pos);
+                model_dir = m_path.substr(0, pos);
             }
 
             auto graph = load_onnx_model(ifs, model_dir, dim_params);

@@ -141,7 +141,10 @@ namespace nnfusion
                 {
                     remove(optimized_filename.c_str());
                 }
-
+                if (std::ifstream((optimized_filename + ".data").c_str()).good())
+                {
+                    remove((optimized_filename + ".data").c_str());
+                }
                 return graph;
             }
         } // namespace frontend

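The extra cleanup above follows from the external-initializers change: the temporary optimized model may now come with a companion ".data" sidecar holding the spilled weights, so both files must be removed. For reference, a sketch of how such a pair is consumed from Python (file name hypothetical) — onnx resolves external tensor data relative to the model's own directory, which is presumably also why model_dir is now derived from m_path above:

    import onnx

    # Loads "optimized.onnx" and transparently pulls tensor data from
    # "optimized.onnx.data" sitting in the same directory.
    model = onnx.load("optimized.onnx")
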
@@ -52,19 +52,27 @@ namespace nnfusion
                         NNFUSION_CHECK(nnfusion::shape_size(value.get_shape()) == 1);
                         const_op = make_constant_op(
                             value.get_ng_type(),
-                            Shape(std::begin(output_shape), std::end(output_shape)),
+                            Shape{1},
                             value);
+                        // const_op = make_constant_op(
+                        //     value.get_ng_type(),
+                        //     Shape(std::begin(output_shape), std::end(output_shape)),
+                        //     value);
                     }
                     else
                     {
                         auto vec = std::vector<float>{0};
                         const_op = std::make_shared<op::Constant>(element::f32, Shape{1}, vec);
                     }
 
-                    const_op->set_name(node_proto.output(0));
-                    const_op->set_global_consistent_name(node_proto.output(0));
+                    // const_op->set_name(node_proto.output(0));
+                    // const_op->set_global_consistent_name(node_proto.output(0));
                     auto const_gnode = m_graph->add_node_and_edge(const_op, graph::GNodeVector({}));
 
+                    const_gnode = make_broadcast_node(const_gnode, Shape(std::begin(output_shape), std::end(output_shape)), m_graph);
+                    const_gnode->get_op_ptr()->set_name(node_proto.output(0));
+                    const_gnode->get_op_ptr()->set_global_consistent_name(node_proto.output(0));
+
                     return {{node_proto.output(0), const_gnode}};
                 }

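This hunk is the "avoid creating a large constant" part of the commit: instead of materializing the constant at its full output shape during ONNX import (presumably a ConstantOfShape-style converter), a one-element constant is emitted and a Broadcast node expands it inside the graph. A rough numpy analogy of the memory effect (shapes and values hypothetical, not NNFusion API):

    import numpy as np

    fill_value = 0.0
    output_shape = (8192, 8192)  # hypothetical full output shape

    # Old behaviour: materialize the whole tensor at conversion time.
    dense = np.full(output_shape, fill_value, dtype=np.float32)  # ~256 MiB

    # New behaviour: keep a one-element constant and broadcast lazily.
    scalar = np.array([fill_value], dtype=np.float32)            # 4 bytes
    view = np.broadcast_to(scalar, output_shape)                 # zero-copy view

    assert view.shape == dense.shape
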
@@ -89,6 +89,7 @@ namespace nnfusion
                     else
                         NNFUSION_CHECK_FAIL() << "non-supported data type for Range op: "
                                               << element_type.c_type_string();
+                    return {};
                 }
 
             } // namespace set_11

@@ -65,6 +65,12 @@ namespace nnfusion
             {
                 using set_1::TranslateUnaryOp;
             }
 
+            namespace set_11
+            {
+                using set_1::TranslateUnaryOp;
+            }
+
             namespace set_13
             {
                 using set_1::TranslateUnaryOp;

@@ -433,6 +433,7 @@ namespace nnfusion
             REGISTER_OPERATOR("Relu", 1, TranslateUnaryOp<op::Relu>);
             REGISTER_OPERATOR("Reshape", 1, TranslateReshapeOp);
             REGISTER_OPERATOR("ReshapeGrad", 1, TranslateReshapeGradOp);
+            REGISTER_OPERATOR("Round", 11, TranslateUnaryOp<op::Round>);
             //REGISTER_OPERATOR("Selu", 1, selu);
             REGISTER_OPERATOR("Shape", 1, TranslateShapeOp);
             REGISTER_OPERATOR("Shape", 15, TranslateShapeOp);