User/chrila/enable graph compile (#606)
* Add Graph option * Update optional tensor logic * Move json parser logic for DmlCompileType * Update version * Update DmlCompileType namespace, json def, and updated Guid.md * update spacing --------- Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com>
This commit is contained in:
Родитель
72ad224f0b
Коммит
61a1a5085a
|
@ -468,6 +468,28 @@ Take note of the few odd cases that don't follow the usual rule exactly:
|
|||
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
|
||||
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.
|
||||
|
||||
### DirectML Compile Op vs Graph (dmlCompileType)
|
||||
Enum dmlCompileType configures whether a defined DirectML operator uses IDMLDevice::CompileOperator or the operator is inserted into DML_GRAPH_DESC and compiled using IDMLDevice1::CompileGraph.
|
||||
|
||||
| Enums for dmlCompileType | Description |
|
||||
| ------------------------------------------------ | ------------------------------------------------------------------------- |
|
||||
| <b><i>DmlCompileGraph</b></i> (Default behavior) | Uses IDMLDevice::CompileOperator for defined operator |
|
||||
| <b><i>DmlCompileGraph</b></i> (Default behavior) | Inserts Operator into a DML_GRAPH_DESC and uses IDMLDevice1::CompileGraph |
|
||||
|
||||
Syntax:
|
||||
|
||||
```json
|
||||
"dmlOperator":
|
||||
{
|
||||
"type": "DML_OPERATOR_*",
|
||||
"dmlCompileType": "DmlCompileGraph",
|
||||
"Desc": { ... }
|
||||
}
|
||||
```
|
||||
|
||||
See full example in [dml_gemm_graph.json](../models/dml_gemm_graph.json).
|
||||
|
||||
|
||||
### DML_TENSOR_DESC
|
||||
|
||||
Since tensor descs are so common, the JSON parser provides default values for most fields.
|
||||
|
|
|
@ -76,6 +76,15 @@
|
|||
]
|
||||
},
|
||||
|
||||
"dmlCompileType":
|
||||
{
|
||||
"enum":
|
||||
[
|
||||
"DmlCompileOp",
|
||||
"DmlCompileGraph"
|
||||
]
|
||||
},
|
||||
|
||||
"arrayInitializer":
|
||||
{
|
||||
"type": "array",
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
{
|
||||
"$schema": "./_schema.json",
|
||||
|
||||
"resources":
|
||||
{
|
||||
"A": {
|
||||
"initialValuesDataType": "FLOAT32",
|
||||
"initialValues": { "valueCount": 1024, "value": 1 }
|
||||
},
|
||||
"B": {
|
||||
"initialValuesDataType": "FLOAT32",
|
||||
"initialValues": { "valueCount": 1024, "value": 1 }
|
||||
},
|
||||
"output": {
|
||||
"initialValuesDataType": "FLOAT32",
|
||||
"initialValues": { "valueCount": 1024, "value": 1 }
|
||||
}
|
||||
},
|
||||
|
||||
"dispatchables":
|
||||
{
|
||||
"gemm":
|
||||
{
|
||||
"type": "DML_OPERATOR_GEMM",
|
||||
"desc":
|
||||
{
|
||||
"ATensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
|
||||
"BTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32], "Flags": "DML_TENSOR_FLAG_OWNED_BY_DML" },
|
||||
"OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
|
||||
"TransA": "DML_MATRIX_TRANSFORM_NONE",
|
||||
"TransB": "DML_MATRIX_TRANSFORM_NONE",
|
||||
"Alpha": 1.0,
|
||||
"Beta": 1.0
|
||||
},
|
||||
"dmlCompileType": "DmlCompileGraph",
|
||||
"executionFlags": "DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION",
|
||||
"bindings":
|
||||
{
|
||||
"BTensor": "B"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
"commands":
|
||||
[
|
||||
{
|
||||
"type": "dispatch",
|
||||
"dispatchable": "gemm",
|
||||
"bindings":
|
||||
{
|
||||
"ATensor": "A",
|
||||
"OutputTensor": "output"
|
||||
}
|
||||
},
|
||||
{ "type": "print", "resource": "output" }
|
||||
]
|
||||
}
|
|
@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
|
|||
std::string_view name,
|
||||
std::shared_ptr<Device> device,
|
||||
const Model::DmlDispatchableDesc& desc,
|
||||
const Dispatchable::Bindings& initBindings
|
||||
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings))
|
||||
const Dispatchable::Bindings& initBindings,
|
||||
IDxDispatchLogger* logger
|
||||
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
|
||||
{
|
||||
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
|
||||
}
|
||||
|
@ -28,7 +29,8 @@ void FillBindingData(
|
|||
const Dispatchable::Bindings* initializeBindings,
|
||||
const Dispatchable::Bindings* executeBindings,
|
||||
BindingData& bindingData,
|
||||
bool bindingForInitialization = false)
|
||||
bool bindingForInitialization,
|
||||
Model::DmlDispatchableDesc::DmlCompileType compileType)
|
||||
{
|
||||
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
|
||||
|
||||
|
@ -47,22 +49,23 @@ void FillBindingData(
|
|||
|
||||
if (bindingIterator == bindings.end())
|
||||
{
|
||||
if (bindPoints[i].required && !bindingForInitialization)
|
||||
{
|
||||
if (!initializeBindings || initializeBindings->find(bindPointName) == initializeBindings->end())
|
||||
{
|
||||
throw std::invalid_argument(fmt::format("Nothing bound for required tensor '{}'.", bindPointName));
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
|
||||
{
|
||||
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
|
||||
bindingData.bufferBindings[bufferIndex].Offset = 0;
|
||||
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
|
||||
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
|
||||
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
|
||||
bufferIndex++;
|
||||
if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i].requiredBinding)
|
||||
{
|
||||
// Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
|
||||
bindingData.bindingDescs.pop_back();
|
||||
bindingData.bufferBindings.pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
|
||||
bindingData.bufferBindings[bufferIndex].Offset = 0;
|
||||
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
|
||||
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
|
||||
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
|
||||
bufferIndex++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -103,11 +106,82 @@ void FillBindingData(
|
|||
|
||||
void DmlDispatchable::Initialize()
|
||||
{
|
||||
THROW_IF_FAILED(m_device->DML()->CompileOperator(
|
||||
m_operator.Get(),
|
||||
m_desc.executionFlags,
|
||||
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
|
||||
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
|
||||
if(m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
|
||||
{
|
||||
m_logger->LogInfo("Compile Op");
|
||||
THROW_IF_FAILED(m_device->DML()->CompileOperator(
|
||||
m_operator.Get(),
|
||||
m_desc.executionFlags,
|
||||
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
|
||||
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
|
||||
}
|
||||
else
|
||||
{
|
||||
m_logger->LogInfo("Compiling op using IDMLDevice1::CompileGraph");
|
||||
DML_GRAPH_DESC dmlGraphDesc = {};
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
|
||||
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
|
||||
DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
|
||||
DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};
|
||||
|
||||
nodeDesc.Operator = m_operator.Get();
|
||||
nodeDesc.Name = m_name.c_str();
|
||||
|
||||
{
|
||||
dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
|
||||
dmlGraphNodeDesc.Desc = &nodeDesc;
|
||||
}
|
||||
|
||||
dmlInputGraphEdges.resize(m_desc.bindPoints.inputs.size());
|
||||
for( size_t i = 0; i < m_desc.bindPoints.inputs.size(); i++)
|
||||
{
|
||||
if (m_desc.bindPoints.inputs[i].requiredBinding)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC desc = {};
|
||||
desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
|
||||
desc.ToNodeIndex = 0;
|
||||
desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
|
||||
desc.Name = m_desc.bindPoints.inputs[i].name.c_str();
|
||||
dmlInputGraphEdges[i] = desc;
|
||||
dmlInputEdges.push_back({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
|
||||
}
|
||||
}
|
||||
|
||||
dmlOutputGraphEdges.resize(m_desc.bindPoints.outputs.size());
|
||||
for( size_t i = 0; i < m_desc.bindPoints.outputs.size(); i++)
|
||||
{
|
||||
if (m_desc.bindPoints.outputs[i].requiredBinding)
|
||||
{
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
|
||||
desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
|
||||
desc.FromNodeIndex = 0;
|
||||
desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
|
||||
desc.Name = m_desc.bindPoints.outputs[i].name.c_str();
|
||||
dmlOutputGraphEdges[i] = desc;
|
||||
dmlOutputEdges.push_back({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
|
||||
}
|
||||
}
|
||||
|
||||
dmlGraphDesc.InputCount = static_cast<uint32_t>(dmlInputEdges.size());
|
||||
dmlGraphDesc.InputEdges = dmlInputEdges.data();
|
||||
dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount;
|
||||
|
||||
dmlGraphDesc.OutputCount = static_cast<uint32_t>(dmlOutputEdges.size());
|
||||
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
|
||||
dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount;
|
||||
|
||||
dmlGraphDesc.IntermediateEdgeCount = 0;
|
||||
dmlGraphDesc.IntermediateEdges = nullptr;
|
||||
|
||||
dmlGraphDesc.NodeCount = 1;
|
||||
dmlGraphDesc.Nodes = &dmlGraphNodeDesc;
|
||||
|
||||
THROW_IF_FAILED(m_device->DML()->CompileGraph(&dmlGraphDesc, m_desc.executionFlags, IID_PPV_ARGS(&m_operatorCompiled)));
|
||||
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(fmt::format("Graph_{}", m_name)).data());
|
||||
}
|
||||
|
||||
ComPtr<IDMLOperatorInitializer> initializer;
|
||||
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
|
||||
|
@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
|
|||
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
|
||||
// be bound using a separate buffer array binding.
|
||||
BindingData inputBindingData = {};
|
||||
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true);
|
||||
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true, m_desc.compileType);
|
||||
|
||||
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
|
||||
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
|
||||
|
@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
|
|||
auto bindingProps = m_operatorCompiled->GetBindingProperties();
|
||||
|
||||
BindingData inputBindingData = {};
|
||||
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData);
|
||||
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData, false, m_desc.compileType);
|
||||
|
||||
BindingData outputBindingData = {};
|
||||
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData);
|
||||
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData, false, m_desc.compileType);
|
||||
|
||||
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
|
||||
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
|
||||
|
|
|
@ -7,7 +7,8 @@ public:
|
|||
std::string_view name,
|
||||
std::shared_ptr<Device> device,
|
||||
const Model::DmlDispatchableDesc& desc,
|
||||
const Dispatchable::Bindings& initBindings);
|
||||
const Dispatchable::Bindings& initBindings,
|
||||
IDxDispatchLogger* logger);
|
||||
|
||||
void Initialize() final;
|
||||
void Bind(const Bindings& bindings, uint32_t iteration) final;
|
||||
|
@ -23,4 +24,5 @@ private:
|
|||
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
|
||||
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
|
||||
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
|
||||
Microsoft::WRL::ComPtr<IDxDispatchLogger> m_logger;
|
||||
};
|
|
@ -162,7 +162,7 @@ Executor::Executor(Model& model, std::shared_ptr<Device> device, const CommandLi
|
|||
return;
|
||||
}
|
||||
|
||||
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings);
|
||||
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings, m_logger.Get());
|
||||
}
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
|
|
|
@ -1402,11 +1402,58 @@ std::vector<Model::BufferBindingSource> ParseBindingSource(const rapidjson::Valu
|
|||
return sourceResources;
|
||||
}
|
||||
|
||||
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileType(const rapidjson::Value& value)
|
||||
{
|
||||
if (value.GetType() != rapidjson::Type::kStringType)
|
||||
{
|
||||
throw std::invalid_argument("Expected a string.");
|
||||
}
|
||||
auto valueString = value.GetString();
|
||||
if (!strcmp(valueString, "DmlCompileOp")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp; }
|
||||
if (!strcmp(valueString, "DmlCompileGraph")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph; }
|
||||
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DmlCompileType.", valueString));
|
||||
}
|
||||
|
||||
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, Model::DmlDispatchableDesc::DmlCompileType defaultValue)
|
||||
{
|
||||
return ParseFieldHelper<Model::DmlDispatchableDesc::DmlCompileType>(object, fieldName, required, defaultValue, [](auto& value) {
|
||||
return ParseDmlCompileType(value);
|
||||
});
|
||||
}
|
||||
|
||||
Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
|
||||
{
|
||||
Model::DmlDispatchableDesc desc;
|
||||
desc.desc = ParseDmlOperatorDesc(object, false, allocator);
|
||||
desc.bindPoints = GetBindPoints(*desc.desc);
|
||||
|
||||
// DirectML requires optional bindings if DML_OPERATOR_DESC declares that binding for optional operator tensors.
|
||||
// Logic is based on the Model directml Operator the tensors declared in "desc".
|
||||
auto UpdateBindingPoints = [](const rapidjson::Value& object, std::vector<Model::DmlDispatchableDesc::BindPoint>& bindPoints) {
|
||||
for (auto& bindPoint : bindPoints)
|
||||
{
|
||||
if (bindPoint.required || object.HasMember(bindPoint.name.c_str()))
|
||||
{
|
||||
bindPoint.requiredBinding = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
bindPoint.requiredBinding = false;
|
||||
}
|
||||
}};
|
||||
|
||||
auto descMember = object.FindMember("Desc");
|
||||
if (descMember == object.MemberEnd())
|
||||
{
|
||||
descMember = object.FindMember("desc");
|
||||
}
|
||||
if (descMember != object.MemberEnd())
|
||||
{
|
||||
UpdateBindingPoints(descMember->value, desc.bindPoints.inputs);
|
||||
UpdateBindingPoints(descMember->value, desc.bindPoints.outputs);
|
||||
}
|
||||
desc.compileType = ParseDmlCompileTypeField(object, "dmlCompileType", false, Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp);
|
||||
|
||||
desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);
|
||||
|
||||
auto bindingsField = object.FindMember("bindings");
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <DirectML.h>
|
||||
#include "BucketAllocator.h"
|
||||
|
||||
|
||||
class Model
|
||||
{
|
||||
public:
|
||||
|
@ -58,11 +59,17 @@ public:
|
|||
|
||||
struct DmlDispatchableDesc
|
||||
{
|
||||
enum class DmlCompileType
|
||||
{
|
||||
DmlCompileOp,
|
||||
DmlCompileGraph
|
||||
};
|
||||
struct BindPoint
|
||||
{
|
||||
std::string name;
|
||||
uint32_t resourceCount;
|
||||
bool required;
|
||||
bool requiredBinding;
|
||||
};
|
||||
|
||||
struct BindPoints
|
||||
|
@ -74,6 +81,7 @@ public:
|
|||
DML_OPERATOR_DESC* desc;
|
||||
BindPoints bindPoints;
|
||||
DML_EXECUTION_FLAGS executionFlags;
|
||||
DmlCompileType compileType;
|
||||
Bindings initBindings;
|
||||
};
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче