Integrate alrezni/v2_dropout into master
This commit is contained in:
Коммит
94993f3c81
|
@ -1531,7 +1531,15 @@ namespace CNTK
|
|||
DictionaryIterator end() const { return m_dictionaryData->end(); }
|
||||
ConstDictionaryIterator cend() const { return m_dictionaryData->cend(); }
|
||||
|
||||
size_t Size() { return m_dictionaryData->size(); }
|
||||
size_t Size() const { return m_dictionaryData->size(); }
|
||||
|
||||
std::unordered_set<std::wstring> Keys()
|
||||
{
|
||||
std::unordered_set<std::wstring> keys;
|
||||
for (const auto& kv : *m_dictionaryData)
|
||||
keys.insert(kv.first);
|
||||
return keys;
|
||||
}
|
||||
|
||||
friend CNTK_API std::istream& operator>>(std::istream& stream, Dictionary& us);
|
||||
friend CNTK_API std::ostream& operator<<(std::ostream& stream, const Dictionary& us);
|
||||
|
@ -3391,20 +3399,17 @@ namespace CNTK
|
|||
///
|
||||
/// Create an instance of the random_sample operation on specified sampling weights input vector
|
||||
///
|
||||
// TODO: The initial random seed should be specifiable
|
||||
CNTK_API FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name /*= L""*/);
|
||||
CNTK_API FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the random_sample_inclusion_frequency operation on specified sampling weights input vector
|
||||
///
|
||||
// TODO: The initial random seed should be specifiable
|
||||
CNTK_API FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name /*= L""*/);
|
||||
CNTK_API FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the dropout operation on specified tensor input operand
|
||||
///
|
||||
// TODO: The initial random seed should be specifiable
|
||||
CNTK_API FunctionPtr Dropout(const Variable& operand, double dropoutRate, const std::wstring& name = L"");
|
||||
CNTK_API FunctionPtr Dropout(const Variable& operand, double dropoutRate, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the reshape operation on specified tensor input operand
|
||||
|
@ -4605,7 +4610,8 @@ namespace CNTK
|
|||
bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
|
||||
bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
|
||||
|
||||
void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState);
|
||||
void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState,
|
||||
const Dictionary& externalState, const Dictionary& distributedState = {});
|
||||
|
||||
void UpdateTrainingProgress(size_t numSamples, const ValuePtr& loss, const ValuePtr& evalCriterion, const DeviceDescriptor& computeDevice);
|
||||
void AddProgressWriters(const std::vector<ProgressWriterPtr>& progressWriters);
|
||||
|
|
|
@ -239,6 +239,8 @@ namespace CNTK
|
|||
|
||||
CNTK_API size_t NewUniqueId();
|
||||
|
||||
CNTK_API size_t GenerateRandomSeed();
|
||||
|
||||
// Internal hooks for testing and higher-level bindings
|
||||
// These should not be directly called by C++ API users
|
||||
CNTK_API void EnableReversingTensorShapesInErrorMessages();
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
#include "PrimitiveFunction.h"
|
||||
#include "Utils.h"
|
||||
#include "Variable.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
|
|
|
@ -193,4 +193,4 @@
|
|||
<Warning Condition="!$(HasProtobuf)" Text="CNTKv2LibraryDll requires Protocol Buffers to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#protobuf for installation instructions." />
|
||||
<Error Condition="!$(HasBoost)" Text="CNTKv2LibraryDll requires the Boost library to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#boost for installation instructions." />
|
||||
</Target>
|
||||
</Project>
|
||||
</Project>
|
|
@ -28,12 +28,36 @@ namespace CNTK
|
|||
{
|
||||
namespace Internal
|
||||
{
|
||||
static std::atomic<unsigned long long> s_nextUniqueId(0);
|
||||
static std::atomic_ullong s_nextUniqueId = ATOMIC_VAR_INIT(0);
|
||||
size_t NewUniqueId()
|
||||
{
|
||||
return s_nextUniqueId++;
|
||||
}
|
||||
|
||||
static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0);
|
||||
|
||||
// This is used to generate a default seed for stateful nodes (dropout, and both
|
||||
// flavors of random sample). As a result, in distributed environment, each worker
|
||||
// ends up having a different seed.
|
||||
|
||||
size_t GenerateRandomSeed()
|
||||
{
|
||||
static size_t numWorkers = 1, rank = 0;
|
||||
static bool initialized = false;
|
||||
if (MPIWrapper::GetTotalNumberOfMPINodes() != 0 && !initialized)
|
||||
{
|
||||
DistributedCommunicatorPtr communicator = MPICommunicator();
|
||||
numWorkers = communicator->Workers().size();
|
||||
rank = communicator->CurrentWorker().m_globalRank;
|
||||
|
||||
if (numWorkers < 1)
|
||||
numWorkers = 1;
|
||||
}
|
||||
|
||||
initialized = true;
|
||||
return (numWorkers * s_currentRandomSeed++) + rank;
|
||||
}
|
||||
|
||||
std::atomic<bool> s_reverseTensorShapesInErrorMessages(false);
|
||||
void EnableReversingTensorShapesInErrorMessages()
|
||||
{
|
||||
|
|
|
@ -46,8 +46,58 @@ namespace CNTK
|
|||
return dict;
|
||||
}
|
||||
|
||||
// Copy the internal state from the network into the function graph,
|
||||
// specifically from RngUser nodes into the attributes dictionaries of
|
||||
// the corresponding stateful primitive functions.
|
||||
void CompositeFunction::UpdateInternalState() const
|
||||
{
|
||||
if (!m_computationNetwork)
|
||||
return;
|
||||
|
||||
for (auto& function : m_allPrimitiveFunctions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
// TODO: same for BatchNorm
|
||||
|
||||
auto& outputs = primitiveFunction->RawOutputs();
|
||||
if (outputs.size() != 1)
|
||||
LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());
|
||||
|
||||
const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();
|
||||
|
||||
Dictionary state;
|
||||
state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
|
||||
state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
|
||||
primitiveFunction->SetState(state);
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a dictionary representing the internal (local) state of the function graph.
|
||||
Dictionary CompositeFunction::GetInternalState() const
|
||||
{
|
||||
UpdateInternalState();
|
||||
|
||||
Dictionary stateDictionary;
|
||||
for (auto& function : m_allPrimitiveFunctions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
// TODO: same for BatchNorm
|
||||
|
||||
stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
|
||||
}
|
||||
return stateDictionary;
|
||||
}
|
||||
|
||||
/*virtual*/ Dictionary CompositeFunction::Serialize() const
|
||||
{
|
||||
UpdateInternalState();
|
||||
|
||||
Dictionary dict = SerializeBlockComposite();
|
||||
|
||||
// Find cycles in the graph and "break" them by inserting placeholders.
|
||||
|
@ -129,29 +179,6 @@ namespace CNTK
|
|||
}
|
||||
|
||||
dict[functionsKey] = std::move(functionDictionaries);
|
||||
|
||||
// Now, collect and store the internal state for all non-pure (stateful) functions in the graph
|
||||
// (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
|
||||
Dictionary stateDictionary;
|
||||
for (const auto& kv : m_variableToNodeMap)
|
||||
{
|
||||
if (kv.second->Is<RngUser>() && kv.first.IsOutput())
|
||||
{
|
||||
// The RNG state should be associated with the actual function that the computation node
|
||||
// corresponds to, and not the block primitives that wrap the actual function
|
||||
auto ownerFunction = kv.first.Owner().get();
|
||||
if (!ownerFunction->IsBlock())
|
||||
{
|
||||
auto rng = kv.second->As<RngUser>();
|
||||
Dictionary state;
|
||||
state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
|
||||
state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
|
||||
stateDictionary[ownerFunction->Uid()] = state;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dict[stateKey] = std::move(stateDictionary);
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
@ -217,10 +244,6 @@ namespace CNTK
|
|||
uidToInputMap[inputVar.Uid()] = inputVar;
|
||||
}
|
||||
|
||||
Dictionary stateDictionary;
|
||||
if (dict.Contains(stateKey))
|
||||
stateDictionary = dict[stateKey].Value<Dictionary>();
|
||||
|
||||
const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();
|
||||
|
||||
std::unordered_map<Variable, Variable> allPlaceholderReplacements;
|
||||
|
@ -238,25 +261,6 @@ namespace CNTK
|
|||
if (opType == PrimitiveOpType::Combine)
|
||||
continue;
|
||||
|
||||
if (primitiveFunction->IsStateful())
|
||||
{
|
||||
if (stateDictionary.Contains(primitiveFunction->Uid()))
|
||||
{
|
||||
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
|
||||
auto seed = state[rngSeedKey].Value<size_t>();
|
||||
auto offset = state[rngOffsetKey].Value<size_t>();
|
||||
primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
|
||||
primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
|
||||
}
|
||||
else if (Internal::GetComputationNetworkTraceLevel() > 0)
|
||||
{
|
||||
// TODO: all logging functionality should be refactored to live in a logging utility class.
|
||||
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
|
||||
"when deserializing from a dictionary (version=%zu). "
|
||||
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& output : root->RawOutputs())
|
||||
{
|
||||
const auto& it = uidToInputMap.find(output.Uid());
|
||||
|
@ -276,63 +280,122 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// starting with the serialization version = 3, the state is preserved inside the attribute dictionaries of the
|
||||
// corresponding primitive functions. Earlier versions have a dedicated key-value pair in the composite function dict.
|
||||
if (version < 3)
|
||||
RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);
|
||||
|
||||
return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
|
||||
}
|
||||
|
||||
void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
|
||||
{
|
||||
Dictionary stateDictionary;
|
||||
if (dict.Contains(stateKey))
|
||||
stateDictionary = dict[stateKey].Value<Dictionary>();
|
||||
|
||||
for (auto& function : functions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
if (stateDictionary.Contains(primitiveFunction->Uid()))
|
||||
{
|
||||
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
|
||||
// Add key-value pairs expected by the SetState method to the state dictionary.
|
||||
state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
|
||||
state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
|
||||
primitiveFunction->SetState(state);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (Internal::GetComputationNetworkTraceLevel() > 0) {
|
||||
// TODO: all logging functionality should be refactored to live in a logging utility class.
|
||||
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
|
||||
"when deserializing from a dictionary (version=%zu). "
|
||||
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
|
||||
}
|
||||
|
||||
// Create state from scratch, so that function attributes contain all the required key-value pairs.
|
||||
Dictionary state;
|
||||
state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed();
|
||||
state[PrimitiveFunction::AttributeNameRngOffset] = 0;
|
||||
primitiveFunction->SetState(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CompositeFunction::CopyState(const CompositeFunction& source)
|
||||
{
|
||||
// Create a map with all non-pure (stateful) functions in the function graph.
|
||||
auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
|
||||
std::map<std::wstring, FunctionPtr> functionMap;
|
||||
for (auto funcPtr : allPrimitiveFunctions)
|
||||
{
|
||||
// Collect a vector of stateful funciton uids using a pre-order traversal of a function graphs.
|
||||
auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
|
||||
vector<wstring> uids;
|
||||
PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
|
||||
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
|
||||
if (primitiveFunction->IsStateful())
|
||||
if (primitiveFunction->IsStateful())
|
||||
{
|
||||
functionMap[primitiveFunction->Uid()] = funcPtr;
|
||||
uids.push_back(funcPtr->Uid());
|
||||
}
|
||||
}
|
||||
return functionMap;
|
||||
}, true);
|
||||
|
||||
return uids;
|
||||
};
|
||||
|
||||
std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
|
||||
std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);
|
||||
auto theirUIDs = collectStatefulFunctionUIDs(source);
|
||||
auto ourUIDs = collectStatefulFunctionUIDs(*this);
|
||||
|
||||
assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
|
||||
if (statefulFunctionsFrom.size() == 0)
|
||||
if (theirUIDs.size() != ourUIDs.size())
|
||||
CNTK::LogicError("Cannot copy internal state, the source and the destination contain different number of stateful functions.");
|
||||
|
||||
auto state = source.GetInternalState();
|
||||
|
||||
if (theirUIDs == ourUIDs)
|
||||
{
|
||||
// uids are identialy, no need to remap.
|
||||
SetInternalState(state);
|
||||
return;
|
||||
}
|
||||
|
||||
// build a map of souce funtion to the destination (this) function UIDs.
|
||||
map<wstring, wstring> uidMap;
|
||||
for (auto i = 0; i < theirUIDs.size(); i++)
|
||||
uidMap[theirUIDs[i]] = ourUIDs[i];
|
||||
|
||||
Dictionary remappedState;
|
||||
for (auto& kv : state)
|
||||
remappedState[uidMap[kv.first]] = kv.second;
|
||||
|
||||
// Copy state captured in the attributes dictionaries.
|
||||
for (const auto& kv : statefulFunctionsFrom)
|
||||
{
|
||||
statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
|
||||
}
|
||||
|
||||
UpdateInternalNetworkState();
|
||||
SetInternalState(remappedState);
|
||||
}
|
||||
|
||||
void CompositeFunction::UpdateInternalNetworkState()
|
||||
void CompositeFunction::SetInternalState(const Dictionary& state)
|
||||
{
|
||||
if (!m_computationNetwork)
|
||||
{
|
||||
if (state.Size() == 0)
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto& function : m_allPrimitiveFunctions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
|
||||
if (primitiveFunction->IsStateful())
|
||||
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
|
||||
|
||||
primitiveFunction->SetState(functionState);
|
||||
|
||||
if (!m_computationNetwork)
|
||||
continue;
|
||||
|
||||
auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
|
||||
auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
|
||||
|
||||
// copy the state directly into the network
|
||||
for (const auto& output : function->RawOutputs())
|
||||
{
|
||||
for (const auto& output : function->RawOutputs())
|
||||
{
|
||||
auto node = m_variableToNodeMap.at(output);
|
||||
auto attributes = function->Attributes();
|
||||
auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
|
||||
auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
|
||||
node->As<RngUser>()->SetRngState(seed, offset);
|
||||
}
|
||||
auto node = m_variableToNodeMap.at(output);
|
||||
node->As<RngUser>()->SetRngState(seed, offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -895,16 +958,9 @@ namespace CNTK
|
|||
|
||||
if (computationNodePtr->Is<RngUser>())
|
||||
{
|
||||
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
|
||||
{
|
||||
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
|
||||
uint64_t offset = 0;
|
||||
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
|
||||
{
|
||||
offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
|
||||
}
|
||||
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
|
||||
}
|
||||
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
|
||||
auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
|
||||
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
|
|
@ -110,7 +110,7 @@ namespace CNTK
|
|||
Dictionary SerializeBlockComposite() const;
|
||||
|
||||
virtual Dictionary Serialize() const override;
|
||||
|
||||
|
||||
virtual size_t CurrentVersion() const override { return s_serializationVersion; }
|
||||
|
||||
static FunctionPtr DeserializeBlockComposite(const Dictionary& dict,
|
||||
|
@ -238,12 +238,26 @@ namespace CNTK
|
|||
return inputs;
|
||||
}
|
||||
|
||||
// If the network is already created, copy internal state over from the functions in the graph into the underlying network.
|
||||
void UpdateInternalNetworkState();
|
||||
|
||||
// Copy state info from source function graph into' this' function graph.
|
||||
// Copy the internal state from the network into the function graph.
|
||||
void UpdateInternalState() const;
|
||||
|
||||
// Generate a dictionary representing the internal (local) state of the function graph.
|
||||
Dictionary GetInternalState() const;
|
||||
|
||||
// Update the internal state using the provided dictionary.
|
||||
// If the network is already created, directly update its state. Otherwise, copy the state from the
|
||||
// dictionary into the function graph.
|
||||
void SetInternalState(const Dictionary& state);
|
||||
|
||||
// Copy state info from source function graph into 'this' function graph.
|
||||
// Both graphs must be equivalent.
|
||||
void CopyState(const CompositeFunction& source);
|
||||
|
||||
// This function is only needed for backwards compatibility to support deserializing composite funcitions that
|
||||
// stored the internal state inside a dedicated value in the dictionary.
|
||||
static void RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> PrimitiveFunctions);
|
||||
|
||||
static Variable GetMappingForNoOpOutput(const Variable& variable, bool recursive = false);
|
||||
static Variable GetMappingVariable(const Variable& variable, bool recursive = false);
|
||||
|
||||
|
@ -328,6 +342,7 @@ namespace CNTK
|
|||
// Version history:
|
||||
// 1 -- initial version.
|
||||
// 2 -- add support for stateful functions (with corresponding nodes inheriting from RngUser).
|
||||
static const size_t s_serializationVersion = 2;
|
||||
// 3 -- store internal function state directly in the attributes dictionary.
|
||||
static const size_t s_serializationVersion = 3;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include "PrimitiveFunction.h"
|
||||
#include "CompositeFunction.h"
|
||||
#include "BlockFunction.h"
|
||||
#include "Utils.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -1037,29 +1038,47 @@ namespace CNTK
|
|||
LogicError("Slice: Invalid axis argument provided. Slice along the dynamic batch axis is currently unsupported. To slice a sequence along its ordered dynamic axis use Sequence::Slice.");
|
||||
}
|
||||
|
||||
FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
|
||||
FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
|
||||
|
||||
if (seed == SentinelValueForAutoSelectRandomSeed)
|
||||
seed = Internal::GenerateRandomSeed();
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
|
||||
|
||||
return UnaryOp(PrimitiveOpType::RandomSample, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
|
||||
FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
|
||||
|
||||
if (seed == SentinelValueForAutoSelectRandomSeed)
|
||||
seed = Internal::GenerateRandomSeed();
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
|
||||
|
||||
return UnaryOp(PrimitiveOpType::RandomSampleInclusionFrequency, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
FunctionPtr Dropout(const Variable& operand, double dropoutRate, const std::wstring& name)
|
||||
FunctionPtr Dropout(const Variable& operand, double dropoutRate, unsigned long seed, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate;
|
||||
|
||||
if (seed == SentinelValueForAutoSelectRandomSeed)
|
||||
seed = Internal::GenerateRandomSeed();
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
|
||||
|
||||
return UnaryOp(PrimitiveOpType::Dropout, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,9 @@
|
|||
#include "BlockFunction.h"
|
||||
#include "CompositeFunction.h"
|
||||
#include "SpecialPurposeNodes.h"
|
||||
#include "ConvolveGeometry.h"
|
||||
#include "ConvolutionalNodes.h"
|
||||
#include "Variable.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -37,7 +40,7 @@ namespace CNTK
|
|||
|
||||
// Names of the various attributes of CNTK primitive Functions
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis = L"axis";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxisVec = L"axisVec";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxisVec = L"axisVec";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis1 = L"axis1";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis2 = L"axis2";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAllowDuplicates = L"allowDuplicates";
|
||||
|
@ -976,4 +979,118 @@ namespace CNTK
|
|||
return std::shared_ptr<PrimitiveFunction>(new PrimitiveFunction(op, inputs, std::move(attributes), name, uid),
|
||||
[](PrimitiveFunction* ptr) { delete ptr; });
|
||||
}
|
||||
|
||||
static const vector<wstring> s_stateAttributes = { PrimitiveFunction::AttributeNameRngSeed, PrimitiveFunction::AttributeNameRngOffset };
|
||||
|
||||
Dictionary PrimitiveFunction::GetState() const
|
||||
{
|
||||
if (!IsStateful())
|
||||
LogicError("Function '%S' is not stateful.", AsString().c_str());
|
||||
|
||||
Dictionary state;
|
||||
for (auto& key : s_stateAttributes)
|
||||
{
|
||||
state[key] = m_attributes[key];
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
void PrimitiveFunction::SetState(const Dictionary& state)
|
||||
{
|
||||
if (!IsStateful())
|
||||
LogicError("Function '%S' is not stateful.", AsString().c_str());
|
||||
|
||||
for (auto& key : s_stateAttributes)
|
||||
{
|
||||
m_attributes[key] = state[key];
|
||||
}
|
||||
}
|
||||
|
||||
/*static*/ void PrimitiveFunction::FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from/* = NDShape()*/)
|
||||
{
|
||||
auto dims = shape.Dimensions();
|
||||
Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(filterRank, inputRank, dims, deflt, from.Dimensions());
|
||||
shape = NDShape(dims);
|
||||
}
|
||||
|
||||
NDShape PrimitiveFunction::ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
|
||||
std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
|
||||
bool transpose, bool inferDimensions, bool ceilOutputDim/* = false*/) const
|
||||
{
|
||||
if (inferDimensions)
|
||||
{
|
||||
size_t inputRank = operandShape.Rank();
|
||||
|
||||
// Unknown kernel shape valid only for pooling, however, the shape should have expanded before
|
||||
// this call.
|
||||
if (kernelShape == NDShape::Unknown)
|
||||
RuntimeError("Convolution: Kernel shape can't be Unknown.");
|
||||
|
||||
// infer reduction dimensions if not given
|
||||
// If kernel has a lower rank than the input then the remaining dimensions are to be reduced over.
|
||||
size_t filterRank = kernelShape.Rank();
|
||||
|
||||
// If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
|
||||
// picking the corresponding operand shape dimensionality
|
||||
// This is done by shrinking the filter rank and let the dimensions be inferred from the operand's shape
|
||||
// TODO: Should we do this for all of the axes in kernelShape that have a dimensionailty of NDShape::InferredDimension?
|
||||
if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
|
||||
{
|
||||
filterRank--;
|
||||
kernelShape = kernelShape.SubShape(0, filterRank);
|
||||
}
|
||||
|
||||
NDShape fromShape;
|
||||
if (op == PrimitiveOpType::Convolution)
|
||||
fromShape = operandShape;
|
||||
|
||||
size_t fillRank = (!transpose) ? filterRank : filterRank - 1;
|
||||
FixNDShape(fillRank, inputRank, kernelShape, 1, fromShape); // convolve over red dim; pool over 1
|
||||
FixNDShape(fillRank, inputRank, strides, 1, fromShape); // stride for reduction dims is red dim or 1
|
||||
FixNDShape(fillRank, inputRank, lowerPad, 0);
|
||||
FixNDShape(fillRank, inputRank, upperPad, 0);
|
||||
Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, sharing, true);
|
||||
Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, autoPad, false); // no padding for reduction dims
|
||||
}
|
||||
|
||||
decltype(&Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape) computeOutputShapeFunc;
|
||||
if (!transpose)
|
||||
computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape;
|
||||
else
|
||||
computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
|
||||
|
||||
return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutputDim));
|
||||
}
|
||||
|
||||
/*static*/ bool PrimitiveFunction::UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes)
|
||||
{
|
||||
bool anyParameterOperandDimsInferred = false;
|
||||
auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
|
||||
if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
|
||||
{
|
||||
operand.m_dataFields->m_shape = newOperandShape;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
for (auto& newOperandShapePair : newOperandShapes)
|
||||
anyParameterOperandDimsInferred = updateOperandShapeFunc(newOperandShapePair.first, newOperandShapePair.second) || anyParameterOperandDimsInferred;
|
||||
|
||||
return anyParameterOperandDimsInferred;
|
||||
}
|
||||
|
||||
NDShape PrimitiveFunction::NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const
|
||||
{
|
||||
assert(operands.size() > 1);
|
||||
|
||||
// TODO: Is this logic of transitively constructing the output shape from the operands correct?
|
||||
Variable dummyOutputVariable = PlaceholderVariable(NDShape());
|
||||
for (auto& operand : operands)
|
||||
dummyOutputVariable.m_dataFields->m_shape = BinaryElementwiseOpOutputShape(op, dummyOutputVariable, operand, broadcastAllowed, inferInputDimensions);
|
||||
|
||||
return dummyOutputVariable.Shape();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,10 +8,6 @@
|
|||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
#include "PrimitiveOpType.h"
|
||||
#include "Utils.h"
|
||||
#include "ConvolveGeometry.h"
|
||||
#include "ConvolutionalNodes.h"
|
||||
#include "Variable.h"
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
@ -301,6 +297,10 @@ namespace CNTK
|
|||
(OpType() == PrimitiveOpType::RandomSampleInclusionFrequency);
|
||||
}
|
||||
|
||||
Dictionary GetState() const;
|
||||
|
||||
void SetState(const Dictionary& state);
|
||||
|
||||
private:
|
||||
|
||||
// The following helper functions are used to determine the output shape for different
|
||||
|
@ -425,24 +425,7 @@ namespace CNTK
|
|||
}
|
||||
|
||||
// Returns a boolean indicating if any operand shape was updated
|
||||
static bool UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes)
|
||||
{
|
||||
bool anyParameterOperandDimsInferred = false;
|
||||
auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
|
||||
if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
|
||||
{
|
||||
operand.m_dataFields->m_shape = newOperandShape;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
for (auto& newOperandShapePair : newOperandShapes)
|
||||
anyParameterOperandDimsInferred = updateOperandShapeFunc(newOperandShapePair.first, newOperandShapePair.second) || anyParameterOperandDimsInferred;
|
||||
|
||||
return anyParameterOperandDimsInferred;
|
||||
}
|
||||
static bool UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes);
|
||||
|
||||
// Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
|
||||
/*static*/ NDShape BinaryElementwiseOpOutputShape(PrimitiveOpType op, Variable& leftOperand, Variable& rightOperand, bool broadcastAllowed, bool inferInputDimensions) const
|
||||
|
@ -493,6 +476,10 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
UNUSED(broadcastAllowed);
|
||||
// BUGBUG: if (broadcastAllowed) is missing here?
|
||||
|
||||
// Broadcast in remaining axes
|
||||
for (size_t i = shapeWithSmallerNumAxes.Rank(); i < numOutputAxes; ++i)
|
||||
outputDims[i] = shapeWithLargerNumAxes[i];
|
||||
|
@ -507,17 +494,7 @@ namespace CNTK
|
|||
return NDShape(std::move(outputDims));
|
||||
}
|
||||
|
||||
/*static*/ NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const
|
||||
{
|
||||
assert(operands.size() > 1);
|
||||
|
||||
// TODO: Is this logic of transitively constructing the output shape from the operands correct?
|
||||
Variable dummyOutputVariable = PlaceholderVariable(NDShape());
|
||||
for (auto& operand : operands)
|
||||
dummyOutputVariable.m_dataFields->m_shape = BinaryElementwiseOpOutputShape(op, dummyOutputVariable, operand, broadcastAllowed, inferInputDimensions);
|
||||
|
||||
return dummyOutputVariable.Shape();
|
||||
}
|
||||
/*static*/ NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const;
|
||||
|
||||
// Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
|
||||
/*static*/ NDShape TimesOpOutputShape(Variable& leftOperand, Variable& rightOperand, size_t outputRank, int inferInputRankToMap, bool inferInputDimensions) const
|
||||
|
@ -643,61 +620,11 @@ namespace CNTK
|
|||
return NDShape(std::move(outputDims));
|
||||
}
|
||||
|
||||
static void FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from = NDShape())
|
||||
{
|
||||
auto dims = shape.Dimensions();
|
||||
Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(filterRank, inputRank, dims, deflt, from.Dimensions());
|
||||
shape = NDShape(dims);
|
||||
}
|
||||
static void FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from = NDShape());
|
||||
|
||||
/*static*/ NDShape ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
|
||||
std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
|
||||
bool transpose, bool inferDimensions, bool ceilOutputDim = false) const
|
||||
{
|
||||
if (inferDimensions)
|
||||
{
|
||||
size_t inputRank = operandShape.Rank();
|
||||
|
||||
// Unknown kernel shape valid only for pooling, however, the shape should have expanded before
|
||||
// this call.
|
||||
if (kernelShape == NDShape::Unknown)
|
||||
RuntimeError("Convolution: Kernel shape can't be Unknown.");
|
||||
|
||||
// infer reduction dimensions if not given
|
||||
// If kernel has a lower rank than the input then the remaining dimensions are to be reduced over.
|
||||
size_t filterRank = kernelShape.Rank();
|
||||
|
||||
// If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
|
||||
// picking the corresponding operand shape dimensionality
|
||||
// This is done by shrinking the filter rank and let the dimensions be inferred from the operand's shape
|
||||
// TODO: Should we do this for all of the axes in kernelShape that have a dimensionailty of NDShape::InferredDimension?
|
||||
if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
|
||||
{
|
||||
filterRank--;
|
||||
kernelShape = kernelShape.SubShape(0, filterRank);
|
||||
}
|
||||
|
||||
NDShape fromShape;
|
||||
if (op == PrimitiveOpType::Convolution)
|
||||
fromShape = operandShape;
|
||||
|
||||
size_t fillRank = (!transpose)? filterRank : filterRank - 1;
|
||||
FixNDShape(fillRank, inputRank, kernelShape, 1, fromShape); // convolve over red dim; pool over 1
|
||||
FixNDShape(fillRank, inputRank, strides, 1, fromShape); // stride for reduction dims is red dim or 1
|
||||
FixNDShape(fillRank, inputRank, lowerPad, 0);
|
||||
FixNDShape(fillRank, inputRank, upperPad, 0);
|
||||
Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, sharing, true);
|
||||
Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, autoPad, false); // no padding for reduction dims
|
||||
}
|
||||
|
||||
decltype(&Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape) computeOutputShapeFunc;
|
||||
if (!transpose)
|
||||
computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape;
|
||||
else
|
||||
computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
|
||||
|
||||
return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutputDim));
|
||||
}
|
||||
std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
|
||||
bool transpose, bool inferDimensions, bool ceilOutputDim = false) const;
|
||||
|
||||
/*static*/ NDShape BatchNormalizationOutputShape(std::vector<Variable>& operands, bool spatial, bool inferDimensions) const
|
||||
{
|
||||
|
|
|
@ -41,6 +41,8 @@ namespace CNTK
|
|||
const std::wstring blockFunctionOpNameKey = L"block_function_op_name";
|
||||
const std::wstring blockFunctionCompositeArgumentsMapKeysKey = L"block_function_composite_arguments_map_keys";
|
||||
const std::wstring blockFunctionCompositeArgumentsMapValuesKey = L"block_function_composite_arguments_map_values";
|
||||
const std::wstring internalWorkerStateKey = L"internal_worker_state";
|
||||
const std::wstring externalWorkerStateKey = L"external_worker_state";
|
||||
|
||||
template <typename T>
|
||||
inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion)
|
||||
|
|
|
@ -8,11 +8,21 @@
|
|||
#include "Utils.h"
|
||||
#include "Learner.h"
|
||||
#include "PerformanceProfiler.h"
|
||||
#include "CompositeFunction.h"
|
||||
#include "Serialization.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
const std::wstring versionPropertyName = L"Version";
|
||||
const std::wstring learnersPropertyName = L"Learners";
|
||||
const std::wstring externalStatePropertyName = L"ExternalState";
|
||||
const std::wstring distributedStatePropertyName = L"DistributedState";
|
||||
|
||||
// Version history:
|
||||
// 0 -- a version number before the versioning was introduced for the trainer's checkpoints.
|
||||
// 1 -- initial version: added a key-value pair for the checkpoint version info, added
|
||||
// distributed state key to save all local state collected from distributed workers.
|
||||
static const size_t trainerCheckpointVersion = 1;
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
|
@ -307,15 +317,22 @@ namespace CNTK
|
|||
void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState)
|
||||
{
|
||||
auto learnersState = m_parameterLearners->CreateCheckpoint();
|
||||
|
||||
if (!m_distributed)
|
||||
return Save(modelFilePath, learnersState, externalState);
|
||||
|
||||
auto compositeFunction = dynamic_cast<CompositeFunction*>(m_combinedTrainingFunction.get());
|
||||
|
||||
Dictionary state;
|
||||
state[internalWorkerStateKey] = compositeFunction->GetInternalState(); // this is the local worker's state.
|
||||
state[externalWorkerStateKey] = externalState;
|
||||
|
||||
// Collect distrbuted external state.
|
||||
DistributedCommunicatorPtr communicator = MPICommunicator();
|
||||
communicator->Barrier();
|
||||
|
||||
std::vector<DictionaryPtr> remoteState;
|
||||
communicator->Gather(externalState, remoteState, communicator->Workers());
|
||||
communicator->Gather(state, remoteState, communicator->Workers());
|
||||
|
||||
Dictionary aggregatedState;
|
||||
for (const auto& w : communicator->Workers())
|
||||
|
@ -324,19 +341,21 @@ namespace CNTK
|
|||
}
|
||||
|
||||
if (communicator->CurrentWorker().IsMain())
|
||||
Save(modelFilePath, learnersState, aggregatedState);
|
||||
Save(modelFilePath, learnersState, externalState, aggregatedState);
|
||||
|
||||
// all workers need to sync up after saving model to avoid read-after-write hazard
|
||||
// i.e. one worker is in the middle of write while another tries to read
|
||||
communicator->Barrier();
|
||||
}
|
||||
|
||||
void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState)
|
||||
void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState, const Dictionary& distributedState)
|
||||
{
|
||||
std::wstring tempModelFile = modelFilePath + L".tmp";
|
||||
Dictionary state;
|
||||
state[versionPropertyName] = trainerCheckpointVersion;
|
||||
state[learnersPropertyName] = learnerState;
|
||||
state[externalStatePropertyName] = externalState;
|
||||
state[distributedStatePropertyName] = distributedState;
|
||||
|
||||
m_combinedTrainingFunction->SaveModel(tempModelFile);
|
||||
std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
|
||||
|
@ -359,25 +378,57 @@ namespace CNTK
|
|||
|
||||
Dictionary checkpoint = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));
|
||||
|
||||
size_t version = 0;
|
||||
|
||||
if (checkpoint.Contains(versionPropertyName))
|
||||
version = checkpoint[versionPropertyName].Value<size_t>();
|
||||
|
||||
auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
|
||||
auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();
|
||||
|
||||
m_parameterLearners->RestoreFromCheckpoint(learnerState);
|
||||
|
||||
if (!m_distributed)
|
||||
{
|
||||
m_parameterLearners->RestoreFromCheckpoint(learnerState);
|
||||
return externalState;
|
||||
}
|
||||
|
||||
m_parameterLearners->RestoreFromCheckpoint(learnerState);
|
||||
// this ensures that nobody will start writing to the model/checkpoint files, until
|
||||
// everybody is done reading them.
|
||||
DistributedCommunicatorPtr communicator = MPICommunicator();
|
||||
communicator->Barrier();
|
||||
|
||||
auto key = std::to_wstring(communicator->CurrentWorker().m_globalRank);
|
||||
auto mainWorkerId = std::to_wstring(0);
|
||||
auto localWorkerId = std::to_wstring(communicator->CurrentWorker().m_globalRank);
|
||||
|
||||
if (externalState.Contains(key))
|
||||
// before version 1, there was no distributed state per se. Instead, the external state
|
||||
// contained a dictionary of worker-specific external states.
|
||||
if (version == 0)
|
||||
{
|
||||
auto key = externalState.Contains(localWorkerId) ? localWorkerId : mainWorkerId;
|
||||
return externalState[key].Value<Dictionary>();
|
||||
else
|
||||
return externalState[std::to_wstring(0)].Value<Dictionary>();
|
||||
}
|
||||
|
||||
Dictionary distributedState = checkpoint[distributedStatePropertyName].Value<Dictionary>();
|
||||
|
||||
if (communicator->CurrentWorker().IsMain() || !distributedState.Contains(localWorkerId))
|
||||
{
|
||||
return externalState;
|
||||
}
|
||||
|
||||
// the checkpoint contains internal state for this worker.
|
||||
Dictionary localState = distributedState[localWorkerId].Value<Dictionary>();
|
||||
|
||||
auto internalState = localState[internalWorkerStateKey].Value<Dictionary>();
|
||||
auto compositeFunction = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
|
||||
if (compositeFunction == nullptr)
|
||||
RuntimeError("Combined training function is not a CompositeFunction.");
|
||||
|
||||
// this assumes the compositeFunction (restored form a checkpoint made by the main node) and
|
||||
// the internal worker state both have identical UIDs.
|
||||
compositeFunction->SetInternalState(internalState);
|
||||
|
||||
return localState[externalWorkerStateKey].Value<Dictionary>();
|
||||
}
|
||||
|
||||
double Trainer::PreviousMinibatchLossAverage() const
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
#include <fstream>
|
||||
#include "Utils.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
|
|
|
@ -11,10 +11,10 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
CPURNGHandle::CPURNGHandle(int deviceId, uint64_t seed, uint64_t offset)
|
||||
: RNGHandle(deviceId)
|
||||
: RNGHandle(deviceId),
|
||||
m_generator(seed)
|
||||
{
|
||||
m_generator.reset(new std::mt19937_64(seed));
|
||||
m_generator->discard(offset);
|
||||
m_generator.discard(offset);
|
||||
}
|
||||
|
||||
}}}
|
||||
|
|
|
@ -20,12 +20,11 @@ public:
|
|||
|
||||
std::mt19937_64& Generator()
|
||||
{
|
||||
return *m_generator;
|
||||
return m_generator;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<std::mt19937_64> m_generator;
|
||||
// TODO: why is this a ptr?
|
||||
std::mt19937_64 m_generator;
|
||||
};
|
||||
|
||||
}}}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "Common.h"
|
||||
|
||||
using namespace CNTK;
|
||||
using namespace std;
|
||||
using namespace std::placeholders;
|
||||
|
||||
extern bool Is1bitSGDAvailable();
|
||||
|
@ -35,14 +36,23 @@ namespace
|
|||
const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
|
||||
const size_t totalNumberOfSamples = numSamplesPerSweep * numSweepsToTrainWith;
|
||||
|
||||
|
||||
const std::wstring g_attributeNameRngSeed = L"rngSeed";
|
||||
const std::wstring g_attributeNameRngOffset = L"rngOffset";
|
||||
|
||||
inline MinibatchSourcePtr GetMinibatchSource(const FeedForwardClassifier& classifier)
|
||||
{
|
||||
return TextFormatMinibatchSource(g_inputFile,
|
||||
{ { g_featureStreamName, classifier.inputDim },
|
||||
{ g_labelsStreamName, classifier.ouputDim } },
|
||||
totalNumberOfSamples, true);
|
||||
}
|
||||
|
||||
void LoopBasedOnSamples(const std::wstring& name, const DeviceDescriptor& device, std::function<DistributedLearnerPtr(LearnerPtr)> factory, const FeedForwardClassifier& classifier)
|
||||
{
|
||||
printf("Training loop thru samples with %ls.\n", name.c_str());
|
||||
|
||||
auto minibatchSource = TextFormatMinibatchSource(g_inputFile,
|
||||
{ { g_featureStreamName, classifier.inputDim }, { g_labelsStreamName, classifier.ouputDim } },
|
||||
totalNumberOfSamples,
|
||||
true);
|
||||
auto minibatchSource = GetMinibatchSource(classifier);
|
||||
|
||||
auto featureStreamInfo = minibatchSource->StreamInfo(g_featureStreamName);
|
||||
auto labelStreamInfo = minibatchSource->StreamInfo(g_labelsStreamName);
|
||||
|
@ -135,3 +145,92 @@ void TestFrameMode()
|
|||
}
|
||||
sync->Barrier();
|
||||
}
|
||||
|
||||
|
||||
void TestDistributedCheckpointing()
|
||||
{
|
||||
std::vector<DeviceDescriptor> devices;
|
||||
if (ShouldRunOnCpu())
|
||||
devices.push_back(DeviceDescriptor::CPUDevice());
|
||||
if (ShouldRunOnGpu())
|
||||
devices.push_back(DeviceDescriptor::GPUDevice(0));
|
||||
|
||||
auto sync = MPICommunicator();
|
||||
|
||||
auto numWorkers = sync->Workers().size();
|
||||
auto workerRank = sync->CurrentWorker().m_globalRank;
|
||||
|
||||
for (auto device : devices)
|
||||
{
|
||||
|
||||
auto ff = BuildFeedForwardClassifier(device);
|
||||
ff.output = Dropout(ff.output, 0.5);
|
||||
ff.trainingLoss = CNTK::CrossEntropyWithSoftmax(ff.output, ff.labels, L"lossFunction");
|
||||
ff.prediction = CNTK::ClassificationError(ff.output, ff.labels, L"classificationError");
|
||||
|
||||
{
|
||||
auto& attributes = ff.output->RootFunction()->Attributes();
|
||||
size_t seed = attributes[g_attributeNameRngSeed].Value<size_t>();
|
||||
// Check that (1) the seed is in the attributes dictionary and
|
||||
// (2) the auto-generated seed value reflects the workerRank.
|
||||
if (numWorkers > 1 && seed % numWorkers != workerRank)
|
||||
ReportFailure("Unexpected seed value");
|
||||
}
|
||||
|
||||
auto learner = SGDLearner(ff.output->Parameters(), LearningRatePerSampleSchedule(0.02));
|
||||
auto distributedLearner = CreateDataParallelDistributedLearner(MPICommunicator(), learner, 0);
|
||||
auto trainer = CreateTrainer(ff.output, ff.trainingLoss, ff.prediction, { distributedLearner });
|
||||
|
||||
auto minibatchSource = GetMinibatchSource(ff);
|
||||
|
||||
auto featureStreamInfo = minibatchSource->StreamInfo(g_featureStreamName);
|
||||
auto labelStreamInfo = minibatchSource->StreamInfo(g_labelsStreamName);
|
||||
|
||||
vector<double> expectedLoss(100);
|
||||
for (int i = 0; i < 100; i++)
|
||||
{
|
||||
if (i % 10 == 0)
|
||||
{
|
||||
auto checkpoint = minibatchSource->GetCheckpointState();
|
||||
trainer->SaveCheckpoint(L"distributed_checkpoint_test." + to_wstring(i), checkpoint);
|
||||
}
|
||||
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
unordered_map<Variable, MinibatchData> minibatch = { { ff.features, minibatchData[featureStreamInfo] },{ ff.labels, minibatchData[labelStreamInfo] } };
|
||||
|
||||
trainer->TrainMinibatch(minibatch, device);
|
||||
expectedLoss[i] = trainer->PreviousMinibatchLossAverage();
|
||||
}
|
||||
|
||||
for (int i = 0; i < 100; i++)
|
||||
{
|
||||
if (i % 10 == 0)
|
||||
{
|
||||
auto checkpoint = trainer->RestoreFromCheckpoint(L"distributed_checkpoint_test." + to_wstring(i));
|
||||
minibatchSource->RestoreFromCheckpoint(checkpoint);
|
||||
|
||||
auto& attributes = ff.output->RootFunction()->Attributes();
|
||||
size_t seed = attributes[g_attributeNameRngSeed].Value<size_t>();
|
||||
size_t offset = attributes[g_attributeNameRngOffset].Value<size_t>();
|
||||
|
||||
// Check that the worker-specific seed value was properly restored from the checkpoint.
|
||||
if (numWorkers > 1 && seed % numWorkers != workerRank)
|
||||
ReportFailure("Unexpected seed value");
|
||||
// Check the offset and verify that it changes depending on the number of processed minibatches.
|
||||
if (offset != i * minibatchSize * ff.inputDim)
|
||||
ReportFailure("Unexpected seed value");
|
||||
}
|
||||
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
unordered_map<Variable, MinibatchData> minibatch = { { ff.features, minibatchData[featureStreamInfo] },{ ff.labels, minibatchData[labelStreamInfo] } };
|
||||
|
||||
trainer->TrainMinibatch(minibatch, device);
|
||||
auto loss = trainer->PreviousMinibatchLossAverage();
|
||||
|
||||
FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sync->Barrier();
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ void MNISTClassifierTests();
|
|||
void TrainSequenceToSequenceTranslator();
|
||||
void TrainTruncatedLSTMAcousticModelClassifier();
|
||||
void TestFrameMode();
|
||||
void TestDistributedCheckpointing();
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
|
@ -58,6 +59,8 @@ int main(int argc, char *argv[])
|
|||
|
||||
TestFrameMode();
|
||||
|
||||
TestDistributedCheckpointing();
|
||||
|
||||
std::string testsPassedMsg = "\nCNTKv2Library-Distribution tests: Passed\n";
|
||||
|
||||
printf("%s", testsPassedMsg.c_str());
|
||||
|
|
|
@ -438,6 +438,10 @@ Test module "V2LibraryTests" has passed with:
|
|||
|
||||
Test case "SerializationSuite/CheckpointingWithStatefulNodesInGPU" has passed
|
||||
|
||||
Test case "SerializationSuite/CheckpointingWithStatefulNodesAndExplicitSeedsOnCPU" has passed
|
||||
|
||||
Test case "SerializationSuite/CheckpointingWithStatefulNodesAndExplicitSeedsOnGPU" has passed
|
||||
|
||||
Test suite "FeedForwardSuite" has passed with:
|
||||
6 test cases out of 6 passed
|
||||
8 assertions out of 8 passed
|
||||
|
|
|
@ -815,6 +815,74 @@ void TestCheckpointingWithStatefulNodes(const DeviceDescriptor& device)
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void TestCheckpointingWithStatefulNodesAndExplicitSeeds(const DeviceDescriptor& device)
|
||||
{
|
||||
auto featureStreamName = L"features";
|
||||
auto labelsStreamName = L"labels";
|
||||
|
||||
size_t inputDim = 784;
|
||||
size_t numOutputClasses = 10;
|
||||
auto features = InputVariable({ inputDim }, false /*isSparse*/, DataType::Float, featureStreamName);
|
||||
auto labels = InputVariable({ numOutputClasses }, DataType::Float, labelsStreamName);
|
||||
|
||||
auto net1 = BuildFFClassifierNet(features, numOutputClasses, device, 1);
|
||||
auto net2 = net1->Clone(ParameterCloningMethod::Clone, { { features , features } });
|
||||
auto net3 = net1->Clone(ParameterCloningMethod::Clone, { { features , features } });
|
||||
|
||||
auto trainer1 = BuildTrainer(Dropout(net1, 0.5, 123), labels);
|
||||
auto trainer2 = BuildTrainer(Dropout(net2, 0.5, 123), labels);
|
||||
auto trainer3 = BuildTrainer(Dropout(net3, 0.5, 321), labels);
|
||||
|
||||
const size_t minibatchSize = 50;
|
||||
const size_t maxSamples = 150;
|
||||
auto minibatchSource = TextFormatMinibatchSource(L"Train-28x28_cntk_text.txt", { { featureStreamName, inputDim },{ labelsStreamName, numOutputClasses } }, 2 * maxSamples, false);
|
||||
|
||||
auto featureStreamInfo = minibatchSource->StreamInfo(features);
|
||||
auto labelStreamInfo = minibatchSource->StreamInfo(labels);
|
||||
|
||||
for (int i = 0; i < maxSamples; i+=minibatchSize)
|
||||
{
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] },{ labels, minibatchData[labelStreamInfo] } };
|
||||
|
||||
trainer1->TrainMinibatch(minibatch, device);
|
||||
trainer2->TrainMinibatch(minibatch, device);
|
||||
trainer3->TrainMinibatch(minibatch, device);
|
||||
auto loss1 = trainer1->PreviousMinibatchLossAverage();
|
||||
auto loss2 = trainer2->PreviousMinibatchLossAverage();
|
||||
auto loss3 = trainer3->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(loss1, loss2, "Training loss does not match expectation");
|
||||
BOOST_TEST((abs(loss1 - loss2) <= abs(loss2 - loss3)));
|
||||
}
|
||||
|
||||
trainer1->SaveCheckpoint(L"seeded_stateful_nodes.model");
|
||||
auto state = minibatchSource->GetCheckpointState();
|
||||
|
||||
vector<double> expectedLoss;
|
||||
for (int i = 0; i < maxSamples; i += minibatchSize)
|
||||
{
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] },{ labels, minibatchData[labelStreamInfo] } };
|
||||
|
||||
trainer1->TrainMinibatch(minibatch, device);
|
||||
expectedLoss.push_back(trainer1->PreviousMinibatchLossAverage());
|
||||
}
|
||||
|
||||
trainer1->RestoreFromCheckpoint(L"seeded_stateful_nodes.model");
|
||||
minibatchSource->RestoreFromCheckpoint(state);
|
||||
|
||||
for (int i = 0; i*minibatchSize < maxSamples; i++)
|
||||
{
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] },{ labels, minibatchData[labelStreamInfo] } };
|
||||
|
||||
trainer1->TrainMinibatch(minibatch, device);
|
||||
double loss = trainer1->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
|
||||
}
|
||||
}
|
||||
|
||||
void TestLoadingModelFromMemoryBuffer()
|
||||
{
|
||||
ifstream modelFileStream("batch.norm.no.sample.count.v2.bin", ifstream::binary);
|
||||
|
@ -992,6 +1060,18 @@ BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesInGPU)
|
|||
TestCheckpointingWithStatefulNodes(DeviceDescriptor::GPUDevice(0));
|
||||
}
|
||||
|
||||
|
||||
BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesAndExplicitSeedsOnCPU)
|
||||
{
|
||||
TestCheckpointingWithStatefulNodesAndExplicitSeeds(DeviceDescriptor::CPUDevice());
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesAndExplicitSeedsOnGPU)
|
||||
{
|
||||
if (ShouldRunOnGpu())
|
||||
TestCheckpointingWithStatefulNodesAndExplicitSeeds(DeviceDescriptor::GPUDevice(0));
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_SUITE_END()
|
||||
|
||||
}}
|
||||
|
|
|
@ -283,7 +283,7 @@
|
|||
" with C.layers.default_options(initial_state = 0.1):\n",
|
||||
" m = C.layers.Recurrence(C.layers.LSTM(N))(x)\n",
|
||||
" m = C.ops.sequence.last(m)\n",
|
||||
" m = C.layers.Dropout(0.2)(m)\n",
|
||||
" m = C.layers.Dropout(0.2, seed=1)(m)\n",
|
||||
" m = cntk.layers.Dense(1)(m)\n",
|
||||
" return m"
|
||||
]
|
||||
|
|
|
@ -573,6 +573,9 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
%ignore CNTK::Dictionary::Keys;
|
||||
|
||||
%extend CNTK::Dictionary {
|
||||
PyObject* __getitem__(const wchar_t* key) {
|
||||
PyObject *DictionaryValueToPy(const CNTK::DictionaryValue&);
|
||||
|
|
|
@ -14,6 +14,7 @@ from ..variables import Variable, Record, Constant
|
|||
from ..ops import parameter, input, placeholder, combine
|
||||
from ..ops import times, element_times, convolution, convolution_transpose, pooling, unpooling, batch_normalization, dropout, splice, reshape, sequence, softmax, tanh, reduce_sum, reduce_mean, sqrt
|
||||
from cntk.internal import _as_tuple
|
||||
from cntk.cntk_py import sentinel_value_for_auto_select_random_seed as SentinelValueForAutoSelectRandomSeed
|
||||
from .blocks import *
|
||||
from .higher_order_layers import *
|
||||
from .blocks import _initializer_for, _get_initial_state_or_default, _INFERRED # helpers
|
||||
|
@ -1023,7 +1024,10 @@ def MaxUnpooling(filter_shape, # shape of receptive field, e.g. (3,3)
|
|||
|
||||
|
||||
# TODO: should the rate(s) be default_options?
|
||||
def Dropout(dropout_rate=None, keep_prob=None, name=''):
|
||||
def Dropout(dropout_rate=None,
|
||||
keep_prob=None,
|
||||
seed = SentinelValueForAutoSelectRandomSeed,
|
||||
name=''):
|
||||
'''
|
||||
Layer factory function to create a drop-out layer.
|
||||
|
||||
|
@ -1043,6 +1047,7 @@ def Dropout(dropout_rate=None, keep_prob=None, name=''):
|
|||
Args:
|
||||
dropout_rate (float): probability of dropping out an element, mutually exclusive with ``keep_prob``
|
||||
keep_prob (float): probability of keeping an element, mutually exclusive with ``dropout_rate``
|
||||
seed (int): random seed.
|
||||
name (str, defaults to ''): the name of the function instance in the network
|
||||
|
||||
Returns:
|
||||
|
@ -1057,7 +1062,7 @@ def Dropout(dropout_rate=None, keep_prob=None, name=''):
|
|||
dropout_rate = 1-keep_prob
|
||||
@BlockFunction('Dropout', name)
|
||||
def dropout_f(x):
|
||||
return dropout(x, dropout_rate=dropout_rate)
|
||||
return dropout(x, dropout_rate=dropout_rate, seed=seed)
|
||||
return dropout_f
|
||||
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from cntk.internal import sanitize_input, sanitize_shape, sanitize_axis, sanitiz
|
|||
from cntk.internal.utils import get_data_type
|
||||
from ..axis import Axis
|
||||
from .. import cntk_py
|
||||
from ..cntk_py import sentinel_value_for_auto_select_random_seed as SentinelValueForAutoSelectRandomSeed
|
||||
from ..default_options import get_default_override, default_override_or
|
||||
|
||||
TIMES_NO_INFERRED_INPUT_RANK = cntk_py.TimesNoInferredInputRank
|
||||
|
@ -2383,7 +2384,12 @@ def argmin(x, axis=None, name=''):
|
|||
#######################################################################
|
||||
|
||||
@typemap
|
||||
def random_sample(weights, num_samples, allow_duplicates, name=''):
|
||||
def random_sample(
|
||||
weights,
|
||||
num_samples,
|
||||
allow_duplicates,
|
||||
seed = SentinelValueForAutoSelectRandomSeed,
|
||||
name=''):
|
||||
'''
|
||||
Estimates inclusion frequencies for random sampling with or without
|
||||
replacement.
|
||||
|
@ -2404,6 +2410,8 @@ def random_sample(weights, num_samples, allow_duplicates, name=''):
|
|||
num_samples (int): number of expected samples
|
||||
allow_duplicates (bool): If sampling is done
|
||||
with replacement (`True`) or without (`False`).
|
||||
seed (int): random seed.
|
||||
name (:class:`str`, optional): the name of the Function instance in the network.
|
||||
|
||||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
|
@ -2412,14 +2420,15 @@ def random_sample(weights, num_samples, allow_duplicates, name=''):
|
|||
from cntk.cntk_py import random_sample
|
||||
weights = sanitize_input(weights)
|
||||
|
||||
return random_sample(weights, num_samples, allow_duplicates, name)
|
||||
return random_sample(weights, num_samples, allow_duplicates, seed, name)
|
||||
|
||||
|
||||
@typemap
|
||||
def random_sample_inclusion_frequency(
|
||||
weights,
|
||||
num_samples,
|
||||
allow_duplicates,
|
||||
allow_duplicates,
|
||||
seed = SentinelValueForAutoSelectRandomSeed,
|
||||
name=''):
|
||||
'''
|
||||
For weighted sampling with the specifed sample size (`num_samples`)
|
||||
|
@ -2438,6 +2447,8 @@ def random_sample_inclusion_frequency(
|
|||
num_samples (int): number of expected samples
|
||||
allow_duplicates (bool): If sampling is done
|
||||
with replacement (`True`) or without (`False`).
|
||||
seed (int): random seed.
|
||||
name (:class:`str`, optional): the name of the Function instance in the network.
|
||||
|
||||
Example:
|
||||
>>> import numpy as np
|
||||
|
@ -2470,11 +2481,12 @@ def random_sample_inclusion_frequency(
|
|||
weights,
|
||||
num_samples,
|
||||
allow_duplicates,
|
||||
seed,
|
||||
name)
|
||||
|
||||
|
||||
@typemap
|
||||
def dropout(x, dropout_rate=0.0, name=''):
|
||||
def dropout(x, dropout_rate=0.0, seed = SentinelValueForAutoSelectRandomSeed, name=''):
|
||||
'''
|
||||
Each element of the input is independently set to 0 with probabily ``dropout_rate``
|
||||
or to 1 / (1 - ``dropout_rate``) times its original value (with probability 1-``dropout_rate``).
|
||||
|
@ -2500,6 +2512,7 @@ def dropout(x, dropout_rate=0.0, name=''):
|
|||
Args:
|
||||
x: input tensor
|
||||
dropout_rate (float, [0,1)): probability that an element of ``x`` will be set to zero
|
||||
seed (int): random seed.
|
||||
name (:class:`str`, optional): the name of the Function instance in the network
|
||||
|
||||
Returns:
|
||||
|
@ -2511,7 +2524,7 @@ def dropout(x, dropout_rate=0.0, name=''):
|
|||
from cntk.cntk_py import dropout
|
||||
x = sanitize_input(x)
|
||||
|
||||
return dropout(x, dropout_rate, name)
|
||||
return dropout(x, dropout_rate, seed, name)
|
||||
|
||||
##########################################################################
|
||||
# variables_and_parameters ops
|
||||
|
|
|
@ -186,6 +186,41 @@ def test_op_dropout(shape, dropout_rate, device_id, precision):
|
|||
assert(abs(resulted_non_zeros - expected_non_zeros) <
|
||||
max_off)
|
||||
|
||||
def test_op_dropout_with_explicit_seed(device_id, precision):
|
||||
from cntk import combine, dropout, input
|
||||
|
||||
value = np.ones(shape=(10,10), dtype=PRECISION_TO_TYPE[precision])
|
||||
|
||||
a = input(shape=value.shape,
|
||||
dtype=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
|
||||
needs_gradient=True,
|
||||
name='a')
|
||||
|
||||
seed = 123;
|
||||
|
||||
dropout_nodes= [
|
||||
dropout(a, dropout_rate=0.5, seed=seed),
|
||||
dropout(a, dropout_rate=0.5, seed=seed),
|
||||
dropout(a, dropout_rate=0.5, seed=seed+1),
|
||||
dropout(a, dropout_rate=0.5)
|
||||
]
|
||||
|
||||
value.shape = (1, 1) + value.shape
|
||||
forward_input = {a: value}
|
||||
results = []
|
||||
for node in dropout_nodes:
|
||||
forward, backward = cntk_eval(node,
|
||||
forward_input,
|
||||
precision,
|
||||
cntk_device(device_id),
|
||||
backward_pass=True)
|
||||
|
||||
results.append(forward[node.output])
|
||||
|
||||
assert np.allclose(results[0], results[1])
|
||||
assert not np.allclose(results[0], results[2])
|
||||
assert not np.allclose(results[0], results[3])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dropout_rate", [-0.1, 1.0, 100])
|
||||
def test_op_dropout_bad_input(dropout_rate):
|
||||
|
|
|
@ -114,3 +114,18 @@ def test_random_sample_without_replacement(weights, num_samples, expected_count,
|
|||
denseResult = times(result, identity)
|
||||
observed_count = np.sum(denseResult.eval(), 0)
|
||||
assert np.allclose(observed_count, expected_count, atol=tolerance)
|
||||
|
||||
def test_random_sample_with_explicit_seed(device_id, precision):
|
||||
weights = AA([x for x in range(0, 10)], precision)
|
||||
identity = np.identity(weights.size)
|
||||
allow_duplicates = False # sample without replacement
|
||||
num_samples = 5;
|
||||
seed = 123
|
||||
to_dense = lambda x: times(x, identity).eval()
|
||||
result1 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed))
|
||||
result2 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed))
|
||||
result3 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed+1))
|
||||
result4 = to_dense(random_sample(weights, num_samples, allow_duplicates))
|
||||
assert np.allclose(result1, result2)
|
||||
assert not np.allclose(result1, result3)
|
||||
assert not np.allclose(result1, result4)
|
||||
|
|
|
@ -62,9 +62,9 @@ def test_convolution_transpose_attributes():
|
|||
|
||||
def test_dropout_attributes():
|
||||
x = C.input( (1, 5, 5) )
|
||||
f = C.dropout(x, 0.5)
|
||||
f = C.dropout(x, 0.5, 42)
|
||||
d = f.root_function.attributes
|
||||
expected = {'dropoutRate': 0.5}
|
||||
expected = {'dropoutRate': 0.5, 'rngSeed' : 42, 'rngOffset' : 0}
|
||||
_check(expected, d)
|
||||
|
||||
def test_slice_attributes():
|
||||
|
|
Загрузка…
Ссылка в новой задаче