User/chrila/enable graph compile (#606)

* Add Graph option

* Update optional tensor logic

* Move json parser logic for DmlCompileType

* Update version

* Update DmlCompileType namespace, json def, and updated Guid.md

* update spacing

---------

Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com>
This commit is contained in:
Christian Larson 2024-07-12 15:40:57 -07:00 коммит произвёл GitHub
Родитель 72ad224f0b
Коммит 61a1a5085a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
8 изменённых файлов: 246 добавлений и 27 удалений

Просмотреть файл

@ -468,6 +468,28 @@ Take note of the few odd cases that don't follow the usual rule exactly:
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.
### DirectML Compile Op vs Graph (dmlCompileType)
Enum dmlCompileType configures whether a defined DirectML operator uses IDMLDevice::CompileOperator or the operator is inserted into DML_GRAPH_DESC and compiled using IDMLDevice1::CompileGraph.
| Enums for dmlCompileType | Description |
| ------------------------------------------------ | ------------------------------------------------------------------------- |
| <b><i>DmlCompileGraph</b></i> (Default behavior) | Uses IDMLDevice::CompileOperator for defined operator |
| <b><i>DmlCompileGraph</b></i> (Default behavior) | Inserts Operator into a DML_GRAPH_DESC and uses IDMLDevice1::CompileGraph |
Syntax:
```json
"dmlOperator":
{
"type": "DML_OPERATOR_*",
"dmlCompileType": "DmlCompileGraph",
"Desc": { ... }
}
```
See full example in [dml_gemm_graph.json](../models/dml_gemm_graph.json).
### DML_TENSOR_DESC
Since tensor descs are so common, the JSON parser provides default values for most fields.

Просмотреть файл

@ -76,6 +76,15 @@
]
},
"dmlCompileType":
{
"enum":
[
"DmlCompileOp",
"DmlCompileGraph"
]
},
"arrayInitializer":
{
"type": "array",

Просмотреть файл

@ -0,0 +1,57 @@
{
"$schema": "./_schema.json",
"resources":
{
"A": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
},
"B": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
},
"output": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
}
},
"dispatchables":
{
"gemm":
{
"type": "DML_OPERATOR_GEMM",
"desc":
{
"ATensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
"BTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32], "Flags": "DML_TENSOR_FLAG_OWNED_BY_DML" },
"OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
"TransA": "DML_MATRIX_TRANSFORM_NONE",
"TransB": "DML_MATRIX_TRANSFORM_NONE",
"Alpha": 1.0,
"Beta": 1.0
},
"dmlCompileType": "DmlCompileGraph",
"executionFlags": "DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION",
"bindings":
{
"BTensor": "B"
}
}
},
"commands":
[
{
"type": "dispatch",
"dispatchable": "gemm",
"bindings":
{
"ATensor": "A",
"OutputTensor": "output"
}
},
{ "type": "print", "resource": "output" }
]
}

Просмотреть файл

@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
std::string_view name,
std::shared_ptr<Device> device,
const Model::DmlDispatchableDesc& desc,
const Dispatchable::Bindings& initBindings
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings))
const Dispatchable::Bindings& initBindings,
IDxDispatchLogger* logger
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
{
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
}
@ -28,7 +29,8 @@ void FillBindingData(
const Dispatchable::Bindings* initializeBindings,
const Dispatchable::Bindings* executeBindings,
BindingData& bindingData,
bool bindingForInitialization = false)
bool bindingForInitialization,
Model::DmlDispatchableDesc::DmlCompileType compileType)
{
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
@ -47,22 +49,23 @@ void FillBindingData(
if (bindingIterator == bindings.end())
{
if (bindPoints[i].required && !bindingForInitialization)
{
if (!initializeBindings || initializeBindings->find(bindPointName) == initializeBindings->end())
{
throw std::invalid_argument(fmt::format("Nothing bound for required tensor '{}'.", bindPointName));
}
}
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
{
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
bindingData.bufferBindings[bufferIndex].Offset = 0;
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
bufferIndex++;
if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i].requiredBinding)
{
// Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
bindingData.bindingDescs.pop_back();
bindingData.bufferBindings.pop_back();
}
else
{
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
bindingData.bufferBindings[bufferIndex].Offset = 0;
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
bufferIndex++;
}
}
}
else
@ -103,11 +106,82 @@ void FillBindingData(
void DmlDispatchable::Initialize()
{
THROW_IF_FAILED(m_device->DML()->CompileOperator(
m_operator.Get(),
m_desc.executionFlags,
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
if(m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
{
m_logger->LogInfo("Compile Op");
THROW_IF_FAILED(m_device->DML()->CompileOperator(
m_operator.Get(),
m_desc.executionFlags,
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
}
else
{
m_logger->LogInfo("Compiling op using IDMLDevice1::CompileGraph");
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};
nodeDesc.Operator = m_operator.Get();
nodeDesc.Name = m_name.c_str();
{
dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
dmlGraphNodeDesc.Desc = &nodeDesc;
}
dmlInputGraphEdges.resize(m_desc.bindPoints.inputs.size());
for( size_t i = 0; i < m_desc.bindPoints.inputs.size(); i++)
{
if (m_desc.bindPoints.inputs[i].requiredBinding)
{
DML_INPUT_GRAPH_EDGE_DESC desc = {};
desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
desc.ToNodeIndex = 0;
desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
desc.Name = m_desc.bindPoints.inputs[i].name.c_str();
dmlInputGraphEdges[i] = desc;
dmlInputEdges.push_back({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
}
}
dmlOutputGraphEdges.resize(m_desc.bindPoints.outputs.size());
for( size_t i = 0; i < m_desc.bindPoints.outputs.size(); i++)
{
if (m_desc.bindPoints.outputs[i].requiredBinding)
{
DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
desc.FromNodeIndex = 0;
desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
desc.Name = m_desc.bindPoints.outputs[i].name.c_str();
dmlOutputGraphEdges[i] = desc;
dmlOutputEdges.push_back({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
}
}
dmlGraphDesc.InputCount = static_cast<uint32_t>(dmlInputEdges.size());
dmlGraphDesc.InputEdges = dmlInputEdges.data();
dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount;
dmlGraphDesc.OutputCount = static_cast<uint32_t>(dmlOutputEdges.size());
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount;
dmlGraphDesc.IntermediateEdgeCount = 0;
dmlGraphDesc.IntermediateEdges = nullptr;
dmlGraphDesc.NodeCount = 1;
dmlGraphDesc.Nodes = &dmlGraphNodeDesc;
THROW_IF_FAILED(m_device->DML()->CompileGraph(&dmlGraphDesc, m_desc.executionFlags, IID_PPV_ARGS(&m_operatorCompiled)));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(fmt::format("Graph_{}", m_name)).data());
}
ComPtr<IDMLOperatorInitializer> initializer;
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
// be bound using a separate buffer array binding.
BindingData inputBindingData = {};
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true);
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true, m_desc.compileType);
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
auto bindingProps = m_operatorCompiled->GetBindingProperties();
BindingData inputBindingData = {};
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData);
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData, false, m_desc.compileType);
BindingData outputBindingData = {};
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData);
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData, false, m_desc.compileType);
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;

Просмотреть файл

@ -7,7 +7,8 @@ public:
std::string_view name,
std::shared_ptr<Device> device,
const Model::DmlDispatchableDesc& desc,
const Dispatchable::Bindings& initBindings);
const Dispatchable::Bindings& initBindings,
IDxDispatchLogger* logger);
void Initialize() final;
void Bind(const Bindings& bindings, uint32_t iteration) final;
@ -23,4 +24,5 @@ private:
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
Microsoft::WRL::ComPtr<IDxDispatchLogger> m_logger;
};

Просмотреть файл

@ -162,7 +162,7 @@ Executor::Executor(Model& model, std::shared_ptr<Device> device, const CommandLi
return;
}
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings);
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings, m_logger.Get());
}
}
catch(const std::exception& e)

Просмотреть файл

@ -1402,11 +1402,58 @@ std::vector<Model::BufferBindingSource> ParseBindingSource(const rapidjson::Valu
return sourceResources;
}
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileType(const rapidjson::Value& value)
{
if (value.GetType() != rapidjson::Type::kStringType)
{
throw std::invalid_argument("Expected a string.");
}
auto valueString = value.GetString();
if (!strcmp(valueString, "DmlCompileOp")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp; }
if (!strcmp(valueString, "DmlCompileGraph")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph; }
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DmlCompileType.", valueString));
}
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, Model::DmlDispatchableDesc::DmlCompileType defaultValue)
{
return ParseFieldHelper<Model::DmlDispatchableDesc::DmlCompileType>(object, fieldName, required, defaultValue, [](auto& value) {
return ParseDmlCompileType(value);
});
}
Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
{
Model::DmlDispatchableDesc desc;
desc.desc = ParseDmlOperatorDesc(object, false, allocator);
desc.bindPoints = GetBindPoints(*desc.desc);
// DirectML requires optional bindings if DML_OPERATOR_DESC declares that binding for optional operator tensors.
// Logic is based on the Model directml Operator the tensors declared in "desc".
auto UpdateBindingPoints = [](const rapidjson::Value& object, std::vector<Model::DmlDispatchableDesc::BindPoint>& bindPoints) {
for (auto& bindPoint : bindPoints)
{
if (bindPoint.required || object.HasMember(bindPoint.name.c_str()))
{
bindPoint.requiredBinding = true;
}
else
{
bindPoint.requiredBinding = false;
}
}};
auto descMember = object.FindMember("Desc");
if (descMember == object.MemberEnd())
{
descMember = object.FindMember("desc");
}
if (descMember != object.MemberEnd())
{
UpdateBindingPoints(descMember->value, desc.bindPoints.inputs);
UpdateBindingPoints(descMember->value, desc.bindPoints.outputs);
}
desc.compileType = ParseDmlCompileTypeField(object, "dmlCompileType", false, Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp);
desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);
auto bindingsField = object.FindMember("bindings");

Просмотреть файл

@ -10,6 +10,7 @@
#include <DirectML.h>
#include "BucketAllocator.h"
class Model
{
public:
@ -58,11 +59,17 @@ public:
struct DmlDispatchableDesc
{
enum class DmlCompileType
{
DmlCompileOp,
DmlCompileGraph
};
struct BindPoint
{
std::string name;
uint32_t resourceCount;
bool required;
bool requiredBinding;
};
struct BindPoints
@ -74,6 +81,7 @@ public:
DML_OPERATOR_DESC* desc;
BindPoints bindPoints;
DML_EXECUTION_FLAGS executionFlags;
DmlCompileType compileType;
Bindings initBindings;
};