CNTK/Source/CNTKv2LibraryDll/BlockFunction.h

185 lines
8.9 KiB
C++

//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "PrimitiveFunction.h"
#include "Utils.h"
#include "Variable.h"
namespace CNTK
{
// BlockFunction encapsulates a composite Function (the "composite") behind a single primitive
// Function of type PrimitiveOpType::Block.
//
// Every argument (placeholder) of the composite is mapped to a Variable outside the block; that
// mapping is recorded on the composite argument itself (Variable::BlockFunctionVariableMapping()).
// Every output of the block likewise records the composite output it stands for.
class BlockFunction final : public PrimitiveFunction
{
public:
    // composite    - the composite Function implementing the block; the pointer is moved into the block.
    // argumentsMap - (composite argument placeholder, Variable it is bound to) pairs. The order of this
    //                list determines the order of the mapped Variables in the block's inputs.
    // blockOpName  - the value returned by OpName().
    // attributes   - forwarded to PrimitiveFunction.
    // blockName    - the block's name; also used in error messages.
    // uid          - unique id of the block.
    // Fails via InvalidArgument if an argument is mapped more than once, if the composite has an
    // input that is neither a Constant, a Parameter nor a placeholder, or if a composite argument is unmapped.
    BlockFunction(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockOpName, Dictionary&& attributes, const std::wstring& blockName = L"", const std::wstring& uid = GenerateUid(PrimitiveOpType::Block))
        : PrimitiveFunction(PrimitiveOpType::Block, DetermineInputs(composite, argumentsMap, blockName), std::move(attributes), blockName, uid),
          // Base classes are initialized before members, so 'composite' has already been consumed by
          // DetermineInputs above and can safely be moved from here.
          m_composite(std::move(composite)), m_blockOpName(blockOpName)
    {
    }

    virtual const std::wstring& OpName() const override { return m_blockOpName; }

    // The composite Function underlying this block.
    const FunctionPtr& Composite() const { return m_composite; }

    // Mapping from each argument of the composite underlying the block to the corresponding Variable it is mapped to.
    // The pairs are ordered by the position of the mapped Variable in this block's Inputs().
    // Fails via LogicError if some composite argument has no mapping.
    std::vector<std::pair<Variable, Variable>> CompositeArgumentsMap() const
    {
        const auto arguments = m_composite->Arguments();
        std::vector<std::pair<Variable, Variable>> argumentsMap;
        argumentsMap.reserve(arguments.size());
        for (const auto& argument : arguments)
        {
            auto mappedVariable = argument.BlockFunctionVariableMapping();
            if (mappedVariable == Variable())
                LogicError("BlockFunction '%S' with OpName '%S' does not have a mapping for argument '%S'.", AsString().c_str(), OpName().c_str(), argument.AsString().c_str());
            argumentsMap.push_back({ argument, std::move(mappedVariable) });
        }

        // Now sort the mapping by the order of occurrence of the argument mapping in the block's inputs.
        // If a Variable occurs more than once among the inputs, insert() keeps its first position.
        const auto blockInputs = Inputs();
        std::unordered_map<Variable, size_t> inputIndices;
        inputIndices.reserve(blockInputs.size());
        for (size_t i = 0; i < blockInputs.size(); ++i)
            inputIndices.insert({ blockInputs[i], i });

        std::stable_sort(argumentsMap.begin(), argumentsMap.end(), [&inputIndices](const std::pair<Variable, Variable>& first, const std::pair<Variable, Variable>& second) {
            return inputIndices.at(first.second) < inputIndices.at(second.second);
        });
        return argumentsMap;
    }

    // Mapping from each output of the block to the corresponding output of the underlying composite.
    // Fails via LogicError if some block output has no mapping.
    std::unordered_map<Variable, Variable> CompositeOutputsMap() const
    {
        const auto& outputs = RawOutputs();
        std::unordered_map<Variable, Variable> outputsMap;
        outputsMap.reserve(outputs.size());
        for (const auto& output : outputs)
        {
            auto mappedVariable = output.BlockFunctionVariableMapping();
            if (mappedVariable == Variable())
                LogicError("BlockFunction '%S' with OpName '%S' does not have a mapping for output '%S'", AsString().c_str(), OpName().c_str(), output.AsString().c_str());
            outputsMap[output] = std::move(mappedVariable);
        }
        return outputsMap;
    }

protected:
    // Called after placeholders of this block have been replaced. For each composite argument whose
    // mapped Variable is in 'replacedPlaceholders':
    //  - if the replacement satisfies IsArgument, the composite argument is re-mapped to the replacement;
    //  - otherwise the composite argument itself is replaced by the replacement inside the composite.
    // Afterwards the block's inputs are recomputed from the (possibly modified) composite.
    virtual void OnPlaceholdersReplaced(const std::unordered_map<Variable, Variable>& placeholderReplacements,
                                        std::unordered_set<Variable>& replacedPlaceholders) override
    {
        // Substitute any placeholder replacements in the arguments map
        const auto arguments = m_composite->Arguments();
        std::unordered_map<Variable, Variable> blockCompositePlaceholderReplacements;
        for (const auto& argument : arguments)
        {
            const auto mappedVariable = argument.BlockFunctionVariableMapping();
            if (replacedPlaceholders.find(mappedVariable) == replacedPlaceholders.end())
                continue;

            const auto& replacement = placeholderReplacements.at(mappedVariable);
            // m_dataFields is shared between copies of a Variable, so this updates the composite's argument.
            if (IsArgument(replacement))
                argument.m_dataFields->m_blockFunctionVariableMapping = replacement;
            else
                blockCompositePlaceholderReplacements.insert({ argument, replacement });
        }
        m_composite->ReplacePlaceholders(blockCompositePlaceholderReplacements);

        // Because some placeholders were replaced in the composite, the inputs of the block became stale,
        // so we need to update them to match the underlying composite function.
        m_inputs = DetermineInputs(m_composite, CompositeArgumentsMap(), Name());
    }

private:
    // Computes the inputs of a block wrapping 'composite' with the given argument mapping: the composite's
    // Constant and Parameter inputs (in the order composite->Inputs() reports them), followed by the mapped
    // Variables in 'argumentsMap' order (not the order the arguments appear in the composite).
    // Side effect: records each mapping on its composite argument (m_blockFunctionVariableMapping).
    // Fails via InvalidArgument on duplicate mappings, on composite inputs that are neither Constant,
    // Parameter nor placeholder, and on composite arguments that have no mapping.
    //
    // This is static on purpose: the constructor calls it from the PrimitiveFunction base-class initializer,
    // and calling a non-static member function before all base classes are initialized is undefined behavior.
    static std::vector<Variable> DetermineInputs(const FunctionPtr& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockName)
    {
        std::unordered_set<Variable> mappedArguments;
        mappedArguments.reserve(argumentsMap.size());
        for (const auto& argumentMapping : argumentsMap)
        {
            if (!mappedArguments.insert(argumentMapping.first).second)
                InvalidArgument("Multiple mappings provided for argument '%S' of the Block composite '%S'", argumentMapping.first.AsString().c_str(), composite->AsString().c_str());
        }

        std::vector<Variable> blockFunctionInputs;
        const auto compositeInputs = composite->Inputs();
        std::vector<Variable> unmappedArguments;
        for (const auto& compositeInput : compositeInputs)
        {
            assert(!compositeInput.IsOutput());
            if (compositeInput.IsConstant() || compositeInput.IsParameter())
            {
                blockFunctionInputs.push_back(compositeInput);
                continue;
            }

            if (!compositeInput.IsPlaceholder())
            {
                InvalidArgument("The composite implementing Block '%S' has an argument '%S' which is not a placeholder. "
                                "All arguments of the composite underlying a Block must be placeholders",
                                blockName.c_str(), compositeInput.AsString().c_str());
            }

            // Verify that a mapping was provided for each argument of the composite
            if (mappedArguments.find(compositeInput) == mappedArguments.end())
                unmappedArguments.push_back(compositeInput);
        }

        if (!unmappedArguments.empty())
        {
            InvalidArgument("%zu of the arguments '%S' of the underlying composite Function of Block '%S' have not been mapped when encapsulating the composite as a Block.",
                            unmappedArguments.size(), NamedListString(unmappedArguments).c_str(), blockName.c_str());
        }

        // We now append the mapped arguments of the composite to the block inputs in the order of the map
        // instead of the original order they appear in the composite itself
        blockFunctionInputs.reserve(blockFunctionInputs.size() + argumentsMap.size());
        for (const auto& argumentMapping : argumentsMap)
        {
            // m_dataFields is shared between copies of a Variable, so this updates the composite's argument.
            argumentMapping.first.m_dataFields->m_blockFunctionVariableMapping = argumentMapping.second;
            blockFunctionInputs.push_back(argumentMapping.second);
        }
        return blockFunctionInputs;
    }

    // Infers the block's outputs. Each composite argument is first replaced by a fresh placeholder created
    // with PlaceholderLike from the Variable it is mapped to, carrying the same mapping. This lets the composite
    // pick up updated shape etc. information. Then one block output is created per composite output, with
    // the same shape, data type, dynamic axes and NeedsGradient, each recording that composite output as its mapping.
    void InferOutputs(std::vector<Variable>& outputs) override
    {
        const auto currentArguments = m_composite->Arguments();
        std::unordered_map<Variable, Variable> replacementMap;
        replacementMap.reserve(currentArguments.size());
        for (const auto& currentArgument : currentArguments)
        {
            auto currentArgumentMapping = currentArgument.BlockFunctionVariableMapping();
            auto newArgument = PlaceholderLike(currentArgumentMapping);
            newArgument.m_dataFields->m_blockFunctionVariableMapping = currentArgumentMapping;
            replacementMap.insert({ currentArgument, newArgument });
        }
        m_composite->ReplacePlaceholders(replacementMap);

        const auto& compositeOutputs = m_composite->RawOutputs();
        outputs.reserve(outputs.size() + compositeOutputs.size());
        for (const auto& compositeOutput : compositeOutputs)
        {
            auto output = OutputVariable(compositeOutput.Shape(), compositeOutput.GetDataType(), compositeOutput.DynamicAxes(), compositeOutput.NeedsGradient(), Name());
            output.m_dataFields->m_blockFunctionVariableMapping = compositeOutput;
            outputs.push_back(std::move(output));
        }
    }

private:
    FunctionPtr m_composite;
    std::wstring m_blockOpName;

    // Increasing s_serializationVersion every time we add more ops allows us to print
    // a more meaningful message when trying to load a new model with a stale binary.
    static const size_t s_serializationVersion = 1;
};
}