Support for node names in DMLX graphs (#389)

This commit is contained in:
Justin Stoecker 2023-02-15 13:58:26 -08:00 коммит произвёл GitHub
Родитель e491f5bcb1
Коммит b95f3cf77b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 66 добавлений и 1 удалений

Просмотреть файл

@ -22,6 +22,7 @@
#include <utility>
#include <type_traits>
#include <functional>
#include <stack>
#include <wrl/client.h> // For Microsoft::WRL::ComPtr
@ -37,6 +38,14 @@
#define DMLX_OPTIONAL_EXTENDED
#endif
#if __cpp_exceptions
#include <stdexcept>
#endif
#if __cplusplus >= 201703L && __has_include(<string_view>)
#include <string_view>
#endif
/** Calculates the minimum number of bytes required to store a buffer tensor with the specified type, sizes, and
strides. The formula can be expressed as the following:
@ -239,6 +248,12 @@ namespace dml
using std::make_unique;
#endif
#if __cplusplus >= 201703L && __has_include(<string_view>)
using StringView = std::string_view;
#else
using StringView = const std::string&;
#endif
#if __cpp_exceptions
#if DMLX_USE_WIL
#define DMLX_THROW_IF_FAILED(_hr) THROW_IF_FAILED(_hr)
@ -514,6 +529,8 @@ namespace dml
// The inputs to this node
std::vector<NodeOutput*> inputs;
std::string name;
};
// Used for representing reshapes and type punning
@ -589,6 +606,25 @@ namespace dml
return m_device.Get();
}
void PushName(StringView name)
{
m_nameSubLengths.push(m_name.size());
if (!m_name.empty())
{
m_name += "_";
}
m_name += name;
}
void PopName()
{
if (!m_nameSubLengths.empty())
{
m_name.resize(m_nameSubLengths.top());
m_nameSubLengths.pop();
}
}
void SetTensorPolicy(TensorPolicy policy) { m_tensorPolicy = std::move(policy); }
const TensorPolicy& GetTensorPolicy() const { return m_tensorPolicy; }
TensorPolicy& GetTensorPolicy() { return m_tensorPolicy; }
@ -608,6 +644,9 @@ namespace dml
std::vector<OperatorNode> m_operatorNodes;
std::vector<ReinterpretNode> m_reinterpretNodes;
std::deque<NodeOutput> m_nodeOutputs; // deque doesn't invalidate references to elements when it resizes
std::string m_name;
std::stack<size_t> m_nameSubLengths;
};
} // namespace detail
@ -635,6 +674,22 @@ namespace dml
detail::NodeOutput* m_nodeOutput; // weak; this is owned by the GraphBuilder
};
class NameScope
{
public:
detail::GraphBuilder* m_builder = nullptr;
NameScope(detail::GraphBuilder* builder, StringView name) : m_builder(builder)
{
if (m_builder) m_builder->PushName(name);
}
~NameScope()
{
if (m_builder) m_builder->PopName();
}
};
class Graph
{
public:
@ -651,6 +706,11 @@ namespace dml
const TensorPolicy& GetTensorPolicy() const { return m_graphBuilder->GetTensorPolicy(); }
TensorPolicy& GetTensorPolicy() { return m_graphBuilder->GetTensorPolicy(); }
NameScope CreateNameScope(StringView name) { return NameScope(m_graphBuilder.get(), name); }
void PushName(StringView name) { m_graphBuilder->PushName(name); }
void PopName() { m_graphBuilder->PopName(); }
Microsoft::WRL::ComPtr<IDMLCompiledOperator> Compile(
DML_EXECUTION_FLAGS flags,
Span<const Expression> outputs,
@ -4114,6 +4174,10 @@ namespace dml
OperatorNode node = {};
node.op = std::move(op);
node.inputs.assign(inputs.begin(), inputs.end());
if (!m_name.empty())
{
node.name = m_name;
}
uint32_t index = static_cast<uint32_t>(m_operatorNodes.size());
m_operatorNodes.push_back(std::move(node));
@ -4152,7 +4216,8 @@ namespace dml
for (const OperatorNode& node : m_operatorNodes)
{
uint32_t nodeIndex = static_cast<uint32_t>(desc.nodes.size());
desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get() });
desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get(), (!node.name.empty() ? node.name.c_str() : nullptr) });
// Walk through each of this node's inputs and add it as an edge
const uint32_t inputCount = static_cast<uint32_t>(node.inputs.size());