Support for node names in DMLX graphs (#389)
This commit is contained in:
Родитель
e491f5bcb1
Коммит
b95f3cf77b
|
@ -22,6 +22,7 @@
|
|||
#include <utility>
|
||||
#include <type_traits>
|
||||
#include <functional>
|
||||
#include <stack>
|
||||
|
||||
#include <wrl/client.h> // For Microsoft::WRL::ComPtr
|
||||
|
||||
|
@ -37,6 +38,14 @@
|
|||
#define DMLX_OPTIONAL_EXTENDED
|
||||
#endif
|
||||
|
||||
#if __cpp_exceptions
|
||||
#include <stdexcept>
|
||||
#endif
|
||||
|
||||
#if __cplusplus >= 201703L && __has_include(<string_view>)
|
||||
#include <string_view>
|
||||
#endif
|
||||
|
||||
/** Calculates the minimum number of bytes required to store a buffer tensor with the specified type, sizes, and
|
||||
strides. The formula can be expressed as the following:
|
||||
|
||||
|
@ -239,6 +248,12 @@ namespace dml
|
|||
using std::make_unique;
|
||||
#endif
|
||||
|
||||
#if __cplusplus >= 201703L && __has_include(<string_view>)
|
||||
using StringView = std::string_view;
|
||||
#else
|
||||
using StringView = const std::string&;
|
||||
#endif
|
||||
|
||||
#if __cpp_exceptions
|
||||
#if DMLX_USE_WIL
|
||||
#define DMLX_THROW_IF_FAILED(_hr) THROW_IF_FAILED(_hr)
|
||||
|
@ -514,6 +529,8 @@ namespace dml
|
|||
|
||||
// The inputs to this node
|
||||
std::vector<NodeOutput*> inputs;
|
||||
|
||||
std::string name;
|
||||
};
|
||||
|
||||
// Used for representing reshapes and type punning
|
||||
|
@ -589,6 +606,25 @@ namespace dml
|
|||
return m_device.Get();
|
||||
}
|
||||
|
||||
void PushName(StringView name)
|
||||
{
|
||||
m_nameSubLengths.push(m_name.size());
|
||||
if (!m_name.empty())
|
||||
{
|
||||
m_name += "_";
|
||||
}
|
||||
m_name += name;
|
||||
}
|
||||
|
||||
void PopName()
|
||||
{
|
||||
if (!m_nameSubLengths.empty())
|
||||
{
|
||||
m_name.resize(m_nameSubLengths.top());
|
||||
m_nameSubLengths.pop();
|
||||
}
|
||||
}
|
||||
|
||||
void SetTensorPolicy(TensorPolicy policy) { m_tensorPolicy = std::move(policy); }
|
||||
const TensorPolicy& GetTensorPolicy() const { return m_tensorPolicy; }
|
||||
TensorPolicy& GetTensorPolicy() { return m_tensorPolicy; }
|
||||
|
@ -608,6 +644,9 @@ namespace dml
|
|||
std::vector<OperatorNode> m_operatorNodes;
|
||||
std::vector<ReinterpretNode> m_reinterpretNodes;
|
||||
std::deque<NodeOutput> m_nodeOutputs; // deque doesn't invalidate references to elements when it resizes
|
||||
|
||||
std::string m_name;
|
||||
std::stack<size_t> m_nameSubLengths;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
@ -635,6 +674,22 @@ namespace dml
|
|||
detail::NodeOutput* m_nodeOutput; // weak; this is owned by the GraphBuilder
|
||||
};
|
||||
|
||||
class NameScope
|
||||
{
|
||||
public:
|
||||
detail::GraphBuilder* m_builder = nullptr;
|
||||
|
||||
NameScope(detail::GraphBuilder* builder, StringView name) : m_builder(builder)
|
||||
{
|
||||
if (m_builder) m_builder->PushName(name);
|
||||
}
|
||||
|
||||
~NameScope()
|
||||
{
|
||||
if (m_builder) m_builder->PopName();
|
||||
}
|
||||
};
|
||||
|
||||
class Graph
|
||||
{
|
||||
public:
|
||||
|
@ -651,6 +706,11 @@ namespace dml
|
|||
const TensorPolicy& GetTensorPolicy() const { return m_graphBuilder->GetTensorPolicy(); }
|
||||
TensorPolicy& GetTensorPolicy() { return m_graphBuilder->GetTensorPolicy(); }
|
||||
|
||||
NameScope CreateNameScope(StringView name) { return NameScope(m_graphBuilder.get(), name); }
|
||||
|
||||
void PushName(StringView name) { m_graphBuilder->PushName(name); }
|
||||
void PopName() { m_graphBuilder->PopName(); }
|
||||
|
||||
Microsoft::WRL::ComPtr<IDMLCompiledOperator> Compile(
|
||||
DML_EXECUTION_FLAGS flags,
|
||||
Span<const Expression> outputs,
|
||||
|
@ -4114,6 +4174,10 @@ namespace dml
|
|||
OperatorNode node = {};
|
||||
node.op = std::move(op);
|
||||
node.inputs.assign(inputs.begin(), inputs.end());
|
||||
if (!m_name.empty())
|
||||
{
|
||||
node.name = m_name;
|
||||
}
|
||||
|
||||
uint32_t index = static_cast<uint32_t>(m_operatorNodes.size());
|
||||
m_operatorNodes.push_back(std::move(node));
|
||||
|
@ -4152,7 +4216,8 @@ namespace dml
|
|||
for (const OperatorNode& node : m_operatorNodes)
|
||||
{
|
||||
uint32_t nodeIndex = static_cast<uint32_t>(desc.nodes.size());
|
||||
desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get() });
|
||||
|
||||
desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get(), (!node.name.empty() ? node.name.c_str() : nullptr) });
|
||||
|
||||
// Walk through each of this node's inputs and add it as an edge
|
||||
const uint32_t inputCount = static_cast<uint32_t>(node.inputs.size());
|
||||
|
|
Загрузка…
Ссылка в новой задаче