User/chrila/enable graph compile (#606)

* Add Graph option

* Update optional tensor logic

* Move json parser logic for DmlCompileType

* Update version

* Update DmlCompileType namespace, json def, and updated Guid.md

* update spacing

---------

Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com>
This commit is contained in:
Christian Larson 2024-07-12 15:40:57 -07:00 коммит произвёл GitHub
Родитель 72ad224f0b
Коммит 61a1a5085a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
8 изменённых файлов: 246 добавлений и 27 удалений

Просмотреть файл

@ -468,6 +468,28 @@ Take note of the few odd cases that don't follow the usual rule exactly:
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`. - Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`. - Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.
### DirectML Compile Op vs Graph (dmlCompileType)
The `dmlCompileType` enum configures how a defined DirectML operator is compiled: either directly with IDMLDevice::CompileOperator, or by inserting the operator into a DML_GRAPH_DESC and compiling it with IDMLDevice1::CompileGraph.
| Enums for dmlCompileType | Description |
| ------------------------------------------------ | ------------------------------------------------------------------------- |
| <b><i>DmlCompileOp</i></b> (Default behavior) | Uses IDMLDevice::CompileOperator for the defined operator |
| <b><i>DmlCompileGraph</i></b> | Inserts the operator into a DML_GRAPH_DESC and uses IDMLDevice1::CompileGraph |
Syntax:
```json
"dmlOperator":
{
"type": "DML_OPERATOR_*",
"dmlCompileType": "DmlCompileGraph",
"Desc": { ... }
}
```
See full example in [dml_gemm_graph.json](../models/dml_gemm_graph.json).
### DML_TENSOR_DESC ### DML_TENSOR_DESC
Since tensor descs are so common, the JSON parser provides default values for most fields. Since tensor descs are so common, the JSON parser provides default values for most fields.

Просмотреть файл

@ -76,6 +76,15 @@
] ]
}, },
"dmlCompileType":
{
"enum":
[
"DmlCompileOp",
"DmlCompileGraph"
]
},
"arrayInitializer": "arrayInitializer":
{ {
"type": "array", "type": "array",

Просмотреть файл

@ -0,0 +1,57 @@
{
"$schema": "./_schema.json",
"resources":
{
"A": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
},
"B": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
},
"output": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
}
},
"dispatchables":
{
"gemm":
{
"type": "DML_OPERATOR_GEMM",
"desc":
{
"ATensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
"BTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32], "Flags": "DML_TENSOR_FLAG_OWNED_BY_DML" },
"OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
"TransA": "DML_MATRIX_TRANSFORM_NONE",
"TransB": "DML_MATRIX_TRANSFORM_NONE",
"Alpha": 1.0,
"Beta": 1.0
},
"dmlCompileType": "DmlCompileGraph",
"executionFlags": "DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION",
"bindings":
{
"BTensor": "B"
}
}
},
"commands":
[
{
"type": "dispatch",
"dispatchable": "gemm",
"bindings":
{
"ATensor": "A",
"OutputTensor": "output"
}
},
{ "type": "print", "resource": "output" }
]
}

Просмотреть файл

@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
std::string_view name, std::string_view name,
std::shared_ptr<Device> device, std::shared_ptr<Device> device,
const Model::DmlDispatchableDesc& desc, const Model::DmlDispatchableDesc& desc,
const Dispatchable::Bindings& initBindings const Dispatchable::Bindings& initBindings,
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)) IDxDispatchLogger* logger
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
{ {
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator))); THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
} }
@ -28,7 +29,8 @@ void FillBindingData(
const Dispatchable::Bindings* initializeBindings, const Dispatchable::Bindings* initializeBindings,
const Dispatchable::Bindings* executeBindings, const Dispatchable::Bindings* executeBindings,
BindingData& bindingData, BindingData& bindingData,
bool bindingForInitialization = false) bool bindingForInitialization,
Model::DmlDispatchableDesc::DmlCompileType compileType)
{ {
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings; const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
@ -47,22 +49,23 @@ void FillBindingData(
if (bindingIterator == bindings.end()) if (bindingIterator == bindings.end())
{ {
if (bindPoints[i].required && !bindingForInitialization)
{
if (!initializeBindings || initializeBindings->find(bindPointName) == initializeBindings->end())
{
throw std::invalid_argument(fmt::format("Nothing bound for required tensor '{}'.", bindPointName));
}
}
for (size_t j = 0; j < bindPoints[i].resourceCount; j++) for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
{ {
bindingData.bufferBindings[bufferIndex].Buffer = nullptr; if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i].requiredBinding)
bindingData.bufferBindings[bufferIndex].Offset = 0; {
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0; // Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE; bindingData.bindingDescs.pop_back();
bindingData.bindingDescs[bufferIndex].Desc = nullptr; bindingData.bufferBindings.pop_back();
bufferIndex++; }
else
{
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
bindingData.bufferBindings[bufferIndex].Offset = 0;
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
bufferIndex++;
}
} }
} }
else else
@ -103,11 +106,82 @@ void FillBindingData(
void DmlDispatchable::Initialize() void DmlDispatchable::Initialize()
{ {
THROW_IF_FAILED(m_device->DML()->CompileOperator( if(m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
m_operator.Get(), {
m_desc.executionFlags, m_logger->LogInfo("Compile Op");
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf()))); THROW_IF_FAILED(m_device->DML()->CompileOperator(
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data()); m_operator.Get(),
m_desc.executionFlags,
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
}
else
{
m_logger->LogInfo("Compiling op using IDMLDevice1::CompileGraph");
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};
nodeDesc.Operator = m_operator.Get();
nodeDesc.Name = m_name.c_str();
{
dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
dmlGraphNodeDesc.Desc = &nodeDesc;
}
dmlInputGraphEdges.resize(m_desc.bindPoints.inputs.size());
for( size_t i = 0; i < m_desc.bindPoints.inputs.size(); i++)
{
if (m_desc.bindPoints.inputs[i].requiredBinding)
{
DML_INPUT_GRAPH_EDGE_DESC desc = {};
desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
desc.ToNodeIndex = 0;
desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
desc.Name = m_desc.bindPoints.inputs[i].name.c_str();
dmlInputGraphEdges[i] = desc;
dmlInputEdges.push_back({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
}
}
dmlOutputGraphEdges.resize(m_desc.bindPoints.outputs.size());
for( size_t i = 0; i < m_desc.bindPoints.outputs.size(); i++)
{
if (m_desc.bindPoints.outputs[i].requiredBinding)
{
DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
desc.FromNodeIndex = 0;
desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
desc.Name = m_desc.bindPoints.outputs[i].name.c_str();
dmlOutputGraphEdges[i] = desc;
dmlOutputEdges.push_back({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
}
}
dmlGraphDesc.InputCount = static_cast<uint32_t>(dmlInputEdges.size());
dmlGraphDesc.InputEdges = dmlInputEdges.data();
dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount;
dmlGraphDesc.OutputCount = static_cast<uint32_t>(dmlOutputEdges.size());
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount;
dmlGraphDesc.IntermediateEdgeCount = 0;
dmlGraphDesc.IntermediateEdges = nullptr;
dmlGraphDesc.NodeCount = 1;
dmlGraphDesc.Nodes = &dmlGraphNodeDesc;
THROW_IF_FAILED(m_device->DML()->CompileGraph(&dmlGraphDesc, m_desc.executionFlags, IID_PPV_ARGS(&m_operatorCompiled)));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(fmt::format("Graph_{}", m_name)).data());
}
ComPtr<IDMLOperatorInitializer> initializer; ComPtr<IDMLOperatorInitializer> initializer;
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() }; IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must // Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
// be bound using a separate buffer array binding. // be bound using a separate buffer array binding.
BindingData inputBindingData = {}; BindingData inputBindingData = {};
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true); FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true, m_desc.compileType);
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {}; DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max()) if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
auto bindingProps = m_operatorCompiled->GetBindingProperties(); auto bindingProps = m_operatorCompiled->GetBindingProperties();
BindingData inputBindingData = {}; BindingData inputBindingData = {};
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData); FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData, false, m_desc.compileType);
BindingData outputBindingData = {}; BindingData outputBindingData = {};
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData); FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData, false, m_desc.compileType);
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {}; D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;

Просмотреть файл

@ -7,7 +7,8 @@ public:
std::string_view name, std::string_view name,
std::shared_ptr<Device> device, std::shared_ptr<Device> device,
const Model::DmlDispatchableDesc& desc, const Model::DmlDispatchableDesc& desc,
const Dispatchable::Bindings& initBindings); const Dispatchable::Bindings& initBindings,
IDxDispatchLogger* logger);
void Initialize() final; void Initialize() final;
void Bind(const Bindings& bindings, uint32_t iteration) final; void Bind(const Bindings& bindings, uint32_t iteration) final;
@ -23,4 +24,5 @@ private:
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer; Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable; Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap; Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
Microsoft::WRL::ComPtr<IDxDispatchLogger> m_logger;
}; };

Просмотреть файл

@ -162,7 +162,7 @@ Executor::Executor(Model& model, std::shared_ptr<Device> device, const CommandLi
return; return;
} }
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings); m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings, m_logger.Get());
} }
} }
catch(const std::exception& e) catch(const std::exception& e)

Просмотреть файл

@ -1402,11 +1402,58 @@ std::vector<Model::BufferBindingSource> ParseBindingSource(const rapidjson::Valu
return sourceResources; return sourceResources;
} }
// Converts a JSON string value into the corresponding DmlCompileType.
// Accepted spellings are "DmlCompileOp" and "DmlCompileGraph".
// Throws std::invalid_argument when the value is not a JSON string or when
// the string does not name a known compile type.
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileType(const rapidjson::Value& value)
{
    using CompileType = Model::DmlDispatchableDesc::DmlCompileType;

    if (value.GetType() != rapidjson::Type::kStringType)
    {
        throw std::invalid_argument("Expected a string.");
    }

    const std::string_view text = value.GetString();

    if (text == "DmlCompileOp")
    {
        return CompileType::DmlCompileOp;
    }

    if (text == "DmlCompileGraph")
    {
        return CompileType::DmlCompileGraph;
    }

    throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DmlCompileType.", text));
}
// Reads the field `fieldName` from `object` and parses it as a DmlCompileType.
// `required` and `defaultValue` are forwarded unchanged to ParseFieldHelper,
// which decides what happens when the field is absent.
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, Model::DmlDispatchableDesc::DmlCompileType defaultValue)
{
    using CompileType = Model::DmlDispatchableDesc::DmlCompileType;

    auto parseValue = [](auto& fieldValue) {
        return ParseDmlCompileType(fieldValue);
    };

    return ParseFieldHelper<CompileType>(object, fieldName, required, defaultValue, parseValue);
}
Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator) Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
{ {
Model::DmlDispatchableDesc desc; Model::DmlDispatchableDesc desc;
desc.desc = ParseDmlOperatorDesc(object, false, allocator); desc.desc = ParseDmlOperatorDesc(object, false, allocator);
desc.bindPoints = GetBindPoints(*desc.desc); desc.bindPoints = GetBindPoints(*desc.desc);
// DirectML requires a binding for an optional operator tensor whenever the DML_OPERATOR_DESC declares that tensor.
// A bind point therefore needs a binding if the operator requires it or if the tensor is declared in the model's "desc" object.
auto UpdateBindingPoints = [](const rapidjson::Value& object, std::vector<Model::DmlDispatchableDesc::BindPoint>& bindPoints) {
for (auto& bindPoint : bindPoints)
{
if (bindPoint.required || object.HasMember(bindPoint.name.c_str()))
{
bindPoint.requiredBinding = true;
}
else
{
bindPoint.requiredBinding = false;
}
}};
auto descMember = object.FindMember("Desc");
if (descMember == object.MemberEnd())
{
descMember = object.FindMember("desc");
}
if (descMember != object.MemberEnd())
{
UpdateBindingPoints(descMember->value, desc.bindPoints.inputs);
UpdateBindingPoints(descMember->value, desc.bindPoints.outputs);
}
desc.compileType = ParseDmlCompileTypeField(object, "dmlCompileType", false, Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp);
desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE); desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);
auto bindingsField = object.FindMember("bindings"); auto bindingsField = object.FindMember("bindings");

Просмотреть файл

@ -10,6 +10,7 @@
#include <DirectML.h> #include <DirectML.h>
#include "BucketAllocator.h" #include "BucketAllocator.h"
class Model class Model
{ {
public: public:
@ -58,11 +59,17 @@ public:
struct DmlDispatchableDesc struct DmlDispatchableDesc
{ {
// Selects how the DML dispatchable's operator gets compiled.
enum class DmlCompileType
{
// Compile the operator directly with IDMLDevice::CompileOperator.
DmlCompileOp,
// Wrap the operator in a single-node DML_GRAPH_DESC and compile it with IDMLDevice1::CompileGraph.
DmlCompileGraph
};
struct BindPoint struct BindPoint
{ {
std::string name; std::string name;
uint32_t resourceCount; uint32_t resourceCount;
bool required; bool required;
bool requiredBinding;
}; };
struct BindPoints struct BindPoints
@ -74,6 +81,7 @@ public:
DML_OPERATOR_DESC* desc; DML_OPERATOR_DESC* desc;
BindPoints bindPoints; BindPoints bindPoints;
DML_EXECUTION_FLAGS executionFlags; DML_EXECUTION_FLAGS executionFlags;
DmlCompileType compileType;
Bindings initBindings; Bindings initBindings;
}; };