User/chrila/enable graph compile (#606)
* Add Graph option * Update optional tensor logic * Move json parser logic for DmlCompileType * Update version * Update DmlCompileType namespace, json def, and updated Guid.md * update spacing --------- Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com>
This commit is contained in:
Родитель
72ad224f0b
Коммит
61a1a5085a
|
@ -468,6 +468,28 @@ Take note of the few odd cases that don't follow the usual rule exactly:
|
||||||
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
|
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
|
||||||
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.
|
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.
|
||||||
|
|
||||||
|
### DirectML Compile Op vs Graph (dmlCompileType)
|
||||||
|
Enum dmlCompileType configures whether a defined DirectML operator uses IDMLDevice::CompileOperator or the operator is inserted into DML_GRAPH_DESC and compiled using IDMLDevice1::CompileGraph.
|
||||||
|
|
||||||
|
| Enums for dmlCompileType | Description |
|
||||||
|
| ------------------------------------------------ | ------------------------------------------------------------------------- |
|
||||||
|
| <b><i>DmlCompileGraph</b></i> (Default behavior) | Uses IDMLDevice::CompileOperator for defined operator |
|
||||||
|
| <b><i>DmlCompileGraph</b></i> (Default behavior) | Inserts Operator into a DML_GRAPH_DESC and uses IDMLDevice1::CompileGraph |
|
||||||
|
|
||||||
|
Syntax:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"dmlOperator":
|
||||||
|
{
|
||||||
|
"type": "DML_OPERATOR_*",
|
||||||
|
"dmlCompileType": "DmlCompileGraph",
|
||||||
|
"Desc": { ... }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
See full example in [dml_gemm_graph.json](../models/dml_gemm_graph.json).
|
||||||
|
|
||||||
|
|
||||||
### DML_TENSOR_DESC
|
### DML_TENSOR_DESC
|
||||||
|
|
||||||
Since tensor descs are so common, the JSON parser provides default values for most fields.
|
Since tensor descs are so common, the JSON parser provides default values for most fields.
|
||||||
|
|
|
@ -76,6 +76,15 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"dmlCompileType":
|
||||||
|
{
|
||||||
|
"enum":
|
||||||
|
[
|
||||||
|
"DmlCompileOp",
|
||||||
|
"DmlCompileGraph"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
|
||||||
"arrayInitializer":
|
"arrayInitializer":
|
||||||
{
|
{
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
{
|
||||||
|
"$schema": "./_schema.json",
|
||||||
|
|
||||||
|
"resources":
|
||||||
|
{
|
||||||
|
"A": {
|
||||||
|
"initialValuesDataType": "FLOAT32",
|
||||||
|
"initialValues": { "valueCount": 1024, "value": 1 }
|
||||||
|
},
|
||||||
|
"B": {
|
||||||
|
"initialValuesDataType": "FLOAT32",
|
||||||
|
"initialValues": { "valueCount": 1024, "value": 1 }
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"initialValuesDataType": "FLOAT32",
|
||||||
|
"initialValues": { "valueCount": 1024, "value": 1 }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"dispatchables":
|
||||||
|
{
|
||||||
|
"gemm":
|
||||||
|
{
|
||||||
|
"type": "DML_OPERATOR_GEMM",
|
||||||
|
"desc":
|
||||||
|
{
|
||||||
|
"ATensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
|
||||||
|
"BTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32], "Flags": "DML_TENSOR_FLAG_OWNED_BY_DML" },
|
||||||
|
"OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
|
||||||
|
"TransA": "DML_MATRIX_TRANSFORM_NONE",
|
||||||
|
"TransB": "DML_MATRIX_TRANSFORM_NONE",
|
||||||
|
"Alpha": 1.0,
|
||||||
|
"Beta": 1.0
|
||||||
|
},
|
||||||
|
"dmlCompileType": "DmlCompileGraph",
|
||||||
|
"executionFlags": "DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION",
|
||||||
|
"bindings":
|
||||||
|
{
|
||||||
|
"BTensor": "B"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"commands":
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "dispatch",
|
||||||
|
"dispatchable": "gemm",
|
||||||
|
"bindings":
|
||||||
|
{
|
||||||
|
"ATensor": "A",
|
||||||
|
"OutputTensor": "output"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ "type": "print", "resource": "output" }
|
||||||
|
]
|
||||||
|
}
|
|
@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
|
||||||
std::string_view name,
|
std::string_view name,
|
||||||
std::shared_ptr<Device> device,
|
std::shared_ptr<Device> device,
|
||||||
const Model::DmlDispatchableDesc& desc,
|
const Model::DmlDispatchableDesc& desc,
|
||||||
const Dispatchable::Bindings& initBindings
|
const Dispatchable::Bindings& initBindings,
|
||||||
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings))
|
IDxDispatchLogger* logger
|
||||||
|
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
|
||||||
{
|
{
|
||||||
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
|
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
|
||||||
}
|
}
|
||||||
|
@ -28,7 +29,8 @@ void FillBindingData(
|
||||||
const Dispatchable::Bindings* initializeBindings,
|
const Dispatchable::Bindings* initializeBindings,
|
||||||
const Dispatchable::Bindings* executeBindings,
|
const Dispatchable::Bindings* executeBindings,
|
||||||
BindingData& bindingData,
|
BindingData& bindingData,
|
||||||
bool bindingForInitialization = false)
|
bool bindingForInitialization,
|
||||||
|
Model::DmlDispatchableDesc::DmlCompileType compileType)
|
||||||
{
|
{
|
||||||
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
|
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
|
||||||
|
|
||||||
|
@ -47,22 +49,23 @@ void FillBindingData(
|
||||||
|
|
||||||
if (bindingIterator == bindings.end())
|
if (bindingIterator == bindings.end())
|
||||||
{
|
{
|
||||||
if (bindPoints[i].required && !bindingForInitialization)
|
|
||||||
{
|
|
||||||
if (!initializeBindings || initializeBindings->find(bindPointName) == initializeBindings->end())
|
|
||||||
{
|
|
||||||
throw std::invalid_argument(fmt::format("Nothing bound for required tensor '{}'.", bindPointName));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
|
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
|
||||||
{
|
{
|
||||||
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
|
if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i].requiredBinding)
|
||||||
bindingData.bufferBindings[bufferIndex].Offset = 0;
|
{
|
||||||
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
|
// Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
|
||||||
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
|
bindingData.bindingDescs.pop_back();
|
||||||
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
|
bindingData.bufferBindings.pop_back();
|
||||||
bufferIndex++;
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
|
||||||
|
bindingData.bufferBindings[bufferIndex].Offset = 0;
|
||||||
|
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
|
||||||
|
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
|
||||||
|
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
|
||||||
|
bufferIndex++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -103,11 +106,82 @@ void FillBindingData(
|
||||||
|
|
||||||
void DmlDispatchable::Initialize()
|
void DmlDispatchable::Initialize()
|
||||||
{
|
{
|
||||||
THROW_IF_FAILED(m_device->DML()->CompileOperator(
|
if(m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
|
||||||
m_operator.Get(),
|
{
|
||||||
m_desc.executionFlags,
|
m_logger->LogInfo("Compile Op");
|
||||||
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
|
THROW_IF_FAILED(m_device->DML()->CompileOperator(
|
||||||
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
|
m_operator.Get(),
|
||||||
|
m_desc.executionFlags,
|
||||||
|
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
|
||||||
|
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
m_logger->LogInfo("Compiling op using IDMLDevice1::CompileGraph");
|
||||||
|
DML_GRAPH_DESC dmlGraphDesc = {};
|
||||||
|
std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
|
||||||
|
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
|
||||||
|
|
||||||
|
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
|
||||||
|
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
|
||||||
|
DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
|
||||||
|
DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};
|
||||||
|
|
||||||
|
nodeDesc.Operator = m_operator.Get();
|
||||||
|
nodeDesc.Name = m_name.c_str();
|
||||||
|
|
||||||
|
{
|
||||||
|
dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
|
||||||
|
dmlGraphNodeDesc.Desc = &nodeDesc;
|
||||||
|
}
|
||||||
|
|
||||||
|
dmlInputGraphEdges.resize(m_desc.bindPoints.inputs.size());
|
||||||
|
for( size_t i = 0; i < m_desc.bindPoints.inputs.size(); i++)
|
||||||
|
{
|
||||||
|
if (m_desc.bindPoints.inputs[i].requiredBinding)
|
||||||
|
{
|
||||||
|
DML_INPUT_GRAPH_EDGE_DESC desc = {};
|
||||||
|
desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
|
||||||
|
desc.ToNodeIndex = 0;
|
||||||
|
desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
|
||||||
|
desc.Name = m_desc.bindPoints.inputs[i].name.c_str();
|
||||||
|
dmlInputGraphEdges[i] = desc;
|
||||||
|
dmlInputEdges.push_back({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dmlOutputGraphEdges.resize(m_desc.bindPoints.outputs.size());
|
||||||
|
for( size_t i = 0; i < m_desc.bindPoints.outputs.size(); i++)
|
||||||
|
{
|
||||||
|
if (m_desc.bindPoints.outputs[i].requiredBinding)
|
||||||
|
{
|
||||||
|
DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
|
||||||
|
desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
|
||||||
|
desc.FromNodeIndex = 0;
|
||||||
|
desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
|
||||||
|
desc.Name = m_desc.bindPoints.outputs[i].name.c_str();
|
||||||
|
dmlOutputGraphEdges[i] = desc;
|
||||||
|
dmlOutputEdges.push_back({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dmlGraphDesc.InputCount = static_cast<uint32_t>(dmlInputEdges.size());
|
||||||
|
dmlGraphDesc.InputEdges = dmlInputEdges.data();
|
||||||
|
dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount;
|
||||||
|
|
||||||
|
dmlGraphDesc.OutputCount = static_cast<uint32_t>(dmlOutputEdges.size());
|
||||||
|
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
|
||||||
|
dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount;
|
||||||
|
|
||||||
|
dmlGraphDesc.IntermediateEdgeCount = 0;
|
||||||
|
dmlGraphDesc.IntermediateEdges = nullptr;
|
||||||
|
|
||||||
|
dmlGraphDesc.NodeCount = 1;
|
||||||
|
dmlGraphDesc.Nodes = &dmlGraphNodeDesc;
|
||||||
|
|
||||||
|
THROW_IF_FAILED(m_device->DML()->CompileGraph(&dmlGraphDesc, m_desc.executionFlags, IID_PPV_ARGS(&m_operatorCompiled)));
|
||||||
|
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(fmt::format("Graph_{}", m_name)).data());
|
||||||
|
}
|
||||||
|
|
||||||
ComPtr<IDMLOperatorInitializer> initializer;
|
ComPtr<IDMLOperatorInitializer> initializer;
|
||||||
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
|
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
|
||||||
|
@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
|
||||||
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
|
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
|
||||||
// be bound using a separate buffer array binding.
|
// be bound using a separate buffer array binding.
|
||||||
BindingData inputBindingData = {};
|
BindingData inputBindingData = {};
|
||||||
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true);
|
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true, m_desc.compileType);
|
||||||
|
|
||||||
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
|
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
|
||||||
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
|
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
|
||||||
|
@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
|
||||||
auto bindingProps = m_operatorCompiled->GetBindingProperties();
|
auto bindingProps = m_operatorCompiled->GetBindingProperties();
|
||||||
|
|
||||||
BindingData inputBindingData = {};
|
BindingData inputBindingData = {};
|
||||||
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData);
|
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData, false, m_desc.compileType);
|
||||||
|
|
||||||
BindingData outputBindingData = {};
|
BindingData outputBindingData = {};
|
||||||
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData);
|
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData, false, m_desc.compileType);
|
||||||
|
|
||||||
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
|
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
|
||||||
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
|
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
|
||||||
|
|
|
@ -7,7 +7,8 @@ public:
|
||||||
std::string_view name,
|
std::string_view name,
|
||||||
std::shared_ptr<Device> device,
|
std::shared_ptr<Device> device,
|
||||||
const Model::DmlDispatchableDesc& desc,
|
const Model::DmlDispatchableDesc& desc,
|
||||||
const Dispatchable::Bindings& initBindings);
|
const Dispatchable::Bindings& initBindings,
|
||||||
|
IDxDispatchLogger* logger);
|
||||||
|
|
||||||
void Initialize() final;
|
void Initialize() final;
|
||||||
void Bind(const Bindings& bindings, uint32_t iteration) final;
|
void Bind(const Bindings& bindings, uint32_t iteration) final;
|
||||||
|
@ -23,4 +24,5 @@ private:
|
||||||
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
|
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
|
||||||
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
|
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
|
||||||
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
|
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
|
||||||
|
Microsoft::WRL::ComPtr<IDxDispatchLogger> m_logger;
|
||||||
};
|
};
|
|
@ -162,7 +162,7 @@ Executor::Executor(Model& model, std::shared_ptr<Device> device, const CommandLi
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings);
|
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings, m_logger.Get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch(const std::exception& e)
|
catch(const std::exception& e)
|
||||||
|
|
|
@ -1402,11 +1402,58 @@ std::vector<Model::BufferBindingSource> ParseBindingSource(const rapidjson::Valu
|
||||||
return sourceResources;
|
return sourceResources;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileType(const rapidjson::Value& value)
|
||||||
|
{
|
||||||
|
if (value.GetType() != rapidjson::Type::kStringType)
|
||||||
|
{
|
||||||
|
throw std::invalid_argument("Expected a string.");
|
||||||
|
}
|
||||||
|
auto valueString = value.GetString();
|
||||||
|
if (!strcmp(valueString, "DmlCompileOp")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp; }
|
||||||
|
if (!strcmp(valueString, "DmlCompileGraph")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph; }
|
||||||
|
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DmlCompileType.", valueString));
|
||||||
|
}
|
||||||
|
|
||||||
|
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, Model::DmlDispatchableDesc::DmlCompileType defaultValue)
|
||||||
|
{
|
||||||
|
return ParseFieldHelper<Model::DmlDispatchableDesc::DmlCompileType>(object, fieldName, required, defaultValue, [](auto& value) {
|
||||||
|
return ParseDmlCompileType(value);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
|
Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
|
||||||
{
|
{
|
||||||
Model::DmlDispatchableDesc desc;
|
Model::DmlDispatchableDesc desc;
|
||||||
desc.desc = ParseDmlOperatorDesc(object, false, allocator);
|
desc.desc = ParseDmlOperatorDesc(object, false, allocator);
|
||||||
desc.bindPoints = GetBindPoints(*desc.desc);
|
desc.bindPoints = GetBindPoints(*desc.desc);
|
||||||
|
|
||||||
|
// DirectML requires optional bindings if DML_OPERATOR_DESC declares that binding for optional operator tensors.
|
||||||
|
// Logic is based on the Model directml Operator the tensors declared in "desc".
|
||||||
|
auto UpdateBindingPoints = [](const rapidjson::Value& object, std::vector<Model::DmlDispatchableDesc::BindPoint>& bindPoints) {
|
||||||
|
for (auto& bindPoint : bindPoints)
|
||||||
|
{
|
||||||
|
if (bindPoint.required || object.HasMember(bindPoint.name.c_str()))
|
||||||
|
{
|
||||||
|
bindPoint.requiredBinding = true;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
bindPoint.requiredBinding = false;
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
|
||||||
|
auto descMember = object.FindMember("Desc");
|
||||||
|
if (descMember == object.MemberEnd())
|
||||||
|
{
|
||||||
|
descMember = object.FindMember("desc");
|
||||||
|
}
|
||||||
|
if (descMember != object.MemberEnd())
|
||||||
|
{
|
||||||
|
UpdateBindingPoints(descMember->value, desc.bindPoints.inputs);
|
||||||
|
UpdateBindingPoints(descMember->value, desc.bindPoints.outputs);
|
||||||
|
}
|
||||||
|
desc.compileType = ParseDmlCompileTypeField(object, "dmlCompileType", false, Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp);
|
||||||
|
|
||||||
desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);
|
desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);
|
||||||
|
|
||||||
auto bindingsField = object.FindMember("bindings");
|
auto bindingsField = object.FindMember("bindings");
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include <DirectML.h>
|
#include <DirectML.h>
|
||||||
#include "BucketAllocator.h"
|
#include "BucketAllocator.h"
|
||||||
|
|
||||||
|
|
||||||
class Model
|
class Model
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
@ -58,11 +59,17 @@ public:
|
||||||
|
|
||||||
struct DmlDispatchableDesc
|
struct DmlDispatchableDesc
|
||||||
{
|
{
|
||||||
|
enum class DmlCompileType
|
||||||
|
{
|
||||||
|
DmlCompileOp,
|
||||||
|
DmlCompileGraph
|
||||||
|
};
|
||||||
struct BindPoint
|
struct BindPoint
|
||||||
{
|
{
|
||||||
std::string name;
|
std::string name;
|
||||||
uint32_t resourceCount;
|
uint32_t resourceCount;
|
||||||
bool required;
|
bool required;
|
||||||
|
bool requiredBinding;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct BindPoints
|
struct BindPoints
|
||||||
|
@ -74,6 +81,7 @@ public:
|
||||||
DML_OPERATOR_DESC* desc;
|
DML_OPERATOR_DESC* desc;
|
||||||
BindPoints bindPoints;
|
BindPoints bindPoints;
|
||||||
DML_EXECUTION_FLAGS executionFlags;
|
DML_EXECUTION_FLAGS executionFlags;
|
||||||
|
DmlCompileType compileType;
|
||||||
Bindings initBindings;
|
Bindings initBindings;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче