Integrate alrezni/v2_dropout into master

Commit 94993f3c81
@@ -1531,7 +1531,15 @@ namespace CNTK
         DictionaryIterator end() const { return m_dictionaryData->end(); }
         ConstDictionaryIterator cend() const { return m_dictionaryData->cend(); }

-        size_t Size() { return m_dictionaryData->size(); }
+        size_t Size() const { return m_dictionaryData->size(); }

+        std::unordered_set<std::wstring> Keys()
+        {
+            std::unordered_set<std::wstring> keys;
+            for (const auto& kv : *m_dictionaryData)
+                keys.insert(kv.first);
+            return keys;
+        }
+
         friend CNTK_API std::istream& operator>>(std::istream& stream, Dictionary& us);
         friend CNTK_API std::ostream& operator<<(std::ostream& stream, const Dictionary& us);
@@ -3391,20 +3399,17 @@ namespace CNTK
     ///
     /// Create an instance of the random_sample operation on specified sampling weights input vector
     ///
-    // TODO: The initial random seed should be specifiable
-    CNTK_API FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name /*= L""*/);
+    CNTK_API FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");

     ///
     /// Create an instance of the random_sample_inclusion_frequency operation on specified sampling weights input vector
     ///
-    // TODO: The initial random seed should be specifiable
-    CNTK_API FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name /*= L""*/);
+    CNTK_API FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");

     ///
     /// Create an instance of the dropout operation on specified tensor input operand
     ///
-    // TODO: The initial random seed should be specifiable
-    CNTK_API FunctionPtr Dropout(const Variable& operand, double dropoutRate, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Dropout(const Variable& operand, double dropoutRate, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");

     ///
     /// Create an instance of the reshape operation on specified tensor input operand
     ///
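These three factory functions now take an explicit RNG seed; leaving it at the sentinel keeps the automatic per-call seed selection. A minimal usage sketch (the input variable `x` is hypothetical):

    // Reproducible dropout: pin the seed explicitly.
    auto dropped = CNTK::Dropout(x, /*dropoutRate=*/ 0.5, /*seed=*/ 4711);

    // Auto-selected seed: rely on the SentinelValueForAutoSelectRandomSeed default.
    auto droppedAuto = CNTK::Dropout(x, 0.5);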
@@ -4605,7 +4610,8 @@ namespace CNTK
         bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
         bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);

-        void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState);
+        void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState,
+                  const Dictionary& externalState, const Dictionary& distributedState = {});

         void UpdateTrainingProgress(size_t numSamples, const ValuePtr& loss, const ValuePtr& evalCriterion, const DeviceDescriptor& computeDevice);
         void AddProgressWriters(const std::vector<ProgressWriterPtr>& progressWriters);
@@ -239,6 +239,8 @@ namespace CNTK

        CNTK_API size_t NewUniqueId();

+       CNTK_API size_t GenerateRandomSeed();
+
        // Internal hooks for testing and higher-level bindings
        // These should not be directly called by C++ API users
        CNTK_API void EnableReversingTensorShapesInErrorMessages();
@@ -8,6 +8,8 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "PrimitiveFunction.h"
+#include "Utils.h"
+#include "Variable.h"

 namespace CNTK
 {
@@ -193,4 +193,4 @@
     <Warning Condition="!$(HasProtobuf)" Text="CNTKv2LibraryDll requires Protocol Buffers to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#protobuf for installation instructions." />
     <Error Condition="!$(HasBoost)" Text="CNTKv2LibraryDll requires the Boost library to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#boost for installation instructions." />
   </Target>
 </Project>
@@ -28,12 +28,36 @@ namespace CNTK
 {
     namespace Internal
     {
-        static std::atomic<unsigned long long> s_nextUniqueId(0);
+        static std::atomic_ullong s_nextUniqueId = ATOMIC_VAR_INIT(0);
         size_t NewUniqueId()
         {
             return s_nextUniqueId++;
         }

+        static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0);
+
+        // This is used to generate a default seed for stateful nodes (dropout, and both
+        // flavors of random sample). As a result, in a distributed environment, each worker
+        // ends up having a different seed.
+        size_t GenerateRandomSeed()
+        {
+            static size_t numWorkers = 1, rank = 0;
+            static bool initialized = false;
+            if (MPIWrapper::GetTotalNumberOfMPINodes() != 0 && !initialized)
+            {
+                DistributedCommunicatorPtr communicator = MPICommunicator();
+                numWorkers = communicator->Workers().size();
+                rank = communicator->CurrentWorker().m_globalRank;
+
+                if (numWorkers < 1)
+                    numWorkers = 1;
+            }
+
+            initialized = true;
+            return (numWorkers * s_currentRandomSeed++) + rank;
+        }
+
         std::atomic<bool> s_reverseTensorShapesInErrorMessages(false);
         void EnableReversingTensorShapesInErrorMessages()
         {
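Aside: a standalone sketch (not part of this commit) of the seed schedule that return expression yields. With numWorkers = 2, rank 0 draws 0, 2, 4, ... and rank 1 draws 1, 3, 5, ..., so per-worker seed sequences never collide:

    #include <atomic>
    #include <cstdio>

    static std::atomic_ullong s_counter(0);

    // Mirrors the return expression of GenerateRandomSeed above.
    size_t NextSeed(size_t numWorkers, size_t rank)
    {
        return (numWorkers * s_counter++) + rank;
    }

    int main()
    {
        for (int i = 0; i < 3; ++i)
            printf("%zu\n", NextSeed(2, 1)); // prints 1, 3, 5
    }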
@@ -46,8 +46,58 @@ namespace CNTK
         return dict;
     }

+    // Copy the internal state from the network into the function graph,
+    // specifically from RngUser nodes into the attributes dictionaries of
+    // the corresponding stateful primitive functions.
+    void CompositeFunction::UpdateInternalState() const
+    {
+        if (!m_computationNetwork)
+            return;
+
+        for (auto& function : m_allPrimitiveFunctions)
+        {
+            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            // TODO: same for BatchNorm
+
+            auto& outputs = primitiveFunction->RawOutputs();
+            if (outputs.size() != 1)
+                LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());
+
+            const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();
+
+            Dictionary state;
+            state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
+            state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
+            primitiveFunction->SetState(state);
+        }
+    }
+
+    // Generate a dictionary representing the internal (local) state of the function graph.
+    Dictionary CompositeFunction::GetInternalState() const
+    {
+        UpdateInternalState();
+
+        Dictionary stateDictionary;
+        for (auto& function : m_allPrimitiveFunctions)
+        {
+            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            // TODO: same for BatchNorm
+
+            stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
+        }
+        return stateDictionary;
+    }
+
     /*virtual*/ Dictionary CompositeFunction::Serialize() const
     {
+        UpdateInternalState();
+
         Dictionary dict = SerializeBlockComposite();

         // Find cycles in the graph and "break" them by inserting placeholders.
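The resulting dictionary keys each stateful primitive function's UID to its seed/offset pair; a hypothetical illustration of its shape (UIDs and values invented):

    // {
    //     L"Dropout1921":      { L"rngSeed": 4711, L"rngOffset": 2048 },
    //     L"RandomSample2044": { L"rngSeed": 4712, L"rngOffset": 0 }
    // }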
@@ -129,29 +179,6 @@ namespace CNTK
         }

         dict[functionsKey] = std::move(functionDictionaries);

-        // Now, collect and store the internal state for all non-pure (stateful) functions in the graph
-        // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
-        Dictionary stateDictionary;
-        for (const auto& kv : m_variableToNodeMap)
-        {
-            if (kv.second->Is<RngUser>() && kv.first.IsOutput())
-            {
-                // The RNG state should be associated with the actual function that the computation node
-                // corresponds to, and not the block primitives that wrap the actual function
-                auto ownerFunction = kv.first.Owner().get();
-                if (!ownerFunction->IsBlock())
-                {
-                    auto rng = kv.second->As<RngUser>();
-                    Dictionary state;
-                    state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
-                    state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
-                    stateDictionary[ownerFunction->Uid()] = state;
-                }
-            }
-        }
-
-        dict[stateKey] = std::move(stateDictionary);
-
         return dict;
     }
@@ -217,10 +244,6 @@ namespace CNTK
             uidToInputMap[inputVar.Uid()] = inputVar;
         }

-        Dictionary stateDictionary;
-        if (dict.Contains(stateKey))
-            stateDictionary = dict[stateKey].Value<Dictionary>();
-
         const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

         std::unordered_map<Variable, Variable> allPlaceholderReplacements;
@@ -238,25 +261,6 @@ namespace CNTK
             if (opType == PrimitiveOpType::Combine)
                 continue;

-            if (primitiveFunction->IsStateful())
-            {
-                if (stateDictionary.Contains(primitiveFunction->Uid()))
-                {
-                    auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
-                    auto seed = state[rngSeedKey].Value<size_t>();
-                    auto offset = state[rngOffsetKey].Value<size_t>();
-                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
-                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
-                }
-                else if (Internal::GetComputationNetworkTraceLevel() > 0)
-                {
-                    // TODO: all logging functionality should be refactored to live in a logging utility class.
-                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
-                                    "when deserializing from a dictionary (version=%zu). "
-                                    "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
-                }
-            }
-
         for (const auto& output : root->RawOutputs())
         {
             const auto& it = uidToInputMap.find(output.Uid());
@@ -276,63 +280,122 @@ namespace CNTK
             }
         }

+        // Starting with serialization version = 3, the state is preserved inside the attribute dictionaries of the
+        // corresponding primitive functions. Earlier versions have a dedicated key-value pair in the composite function dict.
+        if (version < 3)
+            RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);
+
         return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
     }

+    void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
+    {
+        Dictionary stateDictionary;
+        if (dict.Contains(stateKey))
+            stateDictionary = dict[stateKey].Value<Dictionary>();
+
+        for (auto& function : functions)
+        {
+            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            if (stateDictionary.Contains(primitiveFunction->Uid()))
+            {
+                auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
+                // Add key-value pairs expected by the SetState method to the state dictionary.
+                state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
+                state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
+                primitiveFunction->SetState(state);
+            }
+            else
+            {
+                if (Internal::GetComputationNetworkTraceLevel() > 0)
+                {
+                    // TODO: all logging functionality should be refactored to live in a logging utility class.
+                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
+                                    "when deserializing from a dictionary (version=%zu). "
+                                    "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
+                }
+
+                // Create state from scratch, so that function attributes contain all the required key-value pairs.
+                Dictionary state;
+                state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed();
+                state[PrimitiveFunction::AttributeNameRngOffset] = 0;
+                primitiveFunction->SetState(state);
+            }
+        }
+    }

     void CompositeFunction::CopyState(const CompositeFunction& source)
     {
-        // Create a map with all non-pure (stateful) functions in the function graph.
-        auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
-            std::map<std::wstring, FunctionPtr> functionMap;
-            for (auto funcPtr : allPrimitiveFunctions)
-            {
-                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
-                if (primitiveFunction->IsStateful())
-                {
-                    functionMap[primitiveFunction->Uid()] = funcPtr;
-                }
-            }
-            return functionMap;
-        };
+        // Collect a vector of stateful function UIDs using a pre-order traversal of the function graph.
+        auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
+            vector<wstring> uids;
+            PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
+                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
+                if (primitiveFunction->IsStateful())
+                {
+                    uids.push_back(funcPtr->Uid());
+                }
+            }, true);
+
+            return uids;
+        };

-        std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
-        std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);
+        auto theirUIDs = collectStatefulFunctionUIDs(source);
+        auto ourUIDs = collectStatefulFunctionUIDs(*this);

-        assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
-        if (statefulFunctionsFrom.size() == 0)
+        if (theirUIDs.size() != ourUIDs.size())
+            CNTK::LogicError("Cannot copy internal state, the source and the destination contain a different number of stateful functions.");
+
+        auto state = source.GetInternalState();
+
+        if (theirUIDs == ourUIDs)
         {
+            // UIDs are identical, no need to remap.
+            SetInternalState(state);
             return;
         }

-        // Copy state captured in the attributes dictionaries.
-        for (const auto& kv : statefulFunctionsFrom)
-        {
-            statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
-        }
-
-        UpdateInternalNetworkState();
+        // Build a map of source function UIDs to the destination ('this') function UIDs.
+        map<wstring, wstring> uidMap;
+        for (auto i = 0; i < theirUIDs.size(); i++)
+            uidMap[theirUIDs[i]] = ourUIDs[i];
+
+        Dictionary remappedState;
+        for (auto& kv : state)
+            remappedState[uidMap[kv.first]] = kv.second;
+
+        SetInternalState(remappedState);
     }

-    void CompositeFunction::UpdateInternalNetworkState()
+    void CompositeFunction::SetInternalState(const Dictionary& state)
     {
-        if (!m_computationNetwork)
-        {
+        if (state.Size() == 0)
             return;
-        }

         for (const auto& function : m_allPrimitiveFunctions)
         {
-            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
-            if (primitiveFunction->IsStateful())
+            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
+
+            primitiveFunction->SetState(functionState);
+
+            if (!m_computationNetwork)
+                continue;
+
+            auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
+            auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
+
+            // Copy the state directly into the network.
+            for (const auto& output : function->RawOutputs())
             {
-                for (const auto& output : function->RawOutputs())
-                {
-                    auto node = m_variableToNodeMap.at(output);
-                    auto attributes = function->Attributes();
-                    auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
-                    auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
-                    node->As<RngUser>()->SetRngState(seed, offset);
-                }
+                auto node = m_variableToNodeMap.at(output);
+                node->As<RngUser>()->SetRngState(seed, offset);
             }
         }
     }
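The remap branch assumes both graphs enumerate their stateful functions in the same pre-order, so UIDs can be paired positionally; a standalone sketch of that pairing (helper name hypothetical):

    #include <map>
    #include <string>
    #include <vector>

    // Pair source UIDs with destination UIDs by position; this is only valid when
    // both vectors come from the same deterministic traversal of equivalent graphs.
    std::map<std::wstring, std::wstring> BuildUidMap(const std::vector<std::wstring>& theirUIDs,
                                                     const std::vector<std::wstring>& ourUIDs)
    {
        std::map<std::wstring, std::wstring> uidMap;
        for (size_t i = 0; i < theirUIDs.size(); ++i)
            uidMap[theirUIDs[i]] = ourUIDs[i];
        return uidMap;
    }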
@@ -895,16 +958,9 @@ namespace CNTK

         if (computationNodePtr->Is<RngUser>())
         {
-            if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
-            {
-                auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
-                uint64_t offset = 0;
-                if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
-                {
-                    offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
-                }
-                computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
-            }
+            auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
+            auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
+            computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
         }
     }
     else
@@ -110,7 +110,7 @@ namespace CNTK
         Dictionary SerializeBlockComposite() const;

         virtual Dictionary Serialize() const override;

         virtual size_t CurrentVersion() const override { return s_serializationVersion; }

         static FunctionPtr DeserializeBlockComposite(const Dictionary& dict,
@@ -238,12 +238,26 @@ namespace CNTK
             return inputs;
         }

-        // If the network is already created, copy internal state over from the functions in the graph into the underlying network.
-        void UpdateInternalNetworkState();
-
-        // Copy state info from source function graph into 'this' function graph.
+        // Copy the internal state from the network into the function graph.
+        void UpdateInternalState() const;
+
+        // Generate a dictionary representing the internal (local) state of the function graph.
+        Dictionary GetInternalState() const;
+
+        // Update the internal state using the provided dictionary.
+        // If the network is already created, directly update its state. Otherwise, copy the state from the
+        // dictionary into the function graph.
+        void SetInternalState(const Dictionary& state);
+
+        // Copy state info from the source function graph into 'this' function graph.
+        // Both graphs must be equivalent.
         void CopyState(const CompositeFunction& source);

+        // This function is only needed for backwards compatibility, to support deserializing composite functions that
+        // stored the internal state inside a dedicated value in the dictionary.
+        static void RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> PrimitiveFunctions);
+
         static Variable GetMappingForNoOpOutput(const Variable& variable, bool recursive = false);
         static Variable GetMappingVariable(const Variable& variable, bool recursive = false);
@@ -328,6 +342,7 @@ namespace CNTK
         // Version history:
         // 1 -- initial version.
         // 2 -- add support for stateful functions (with corresponding nodes inheriting from RngUser).
-        static const size_t s_serializationVersion = 2;
+        // 3 -- store internal function state directly in the attributes dictionary.
+        static const size_t s_serializationVersion = 3;
     };
 }
@@ -8,6 +8,7 @@
 #include "PrimitiveFunction.h"
 #include "CompositeFunction.h"
 #include "BlockFunction.h"
+#include "Utils.h"

 using namespace Microsoft::MSR::CNTK;
@@ -1037,29 +1038,47 @@ namespace CNTK
             LogicError("Slice: Invalid axis argument provided. Slice along the dynamic batch axis is currently unsupported. To slice a sequence along its ordered dynamic axis use Sequence::Slice.");
         }

-    FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
+    FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
         additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
+
+        if (seed == SentinelValueForAutoSelectRandomSeed)
+            seed = Internal::GenerateRandomSeed();
+
+        additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
+        additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
+
         return UnaryOp(PrimitiveOpType::RandomSample, operand, std::move(additionalProperties), name);
     }

-    FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
+    FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
         additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
+
+        if (seed == SentinelValueForAutoSelectRandomSeed)
+            seed = Internal::GenerateRandomSeed();
+
+        additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
+        additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
+
         return UnaryOp(PrimitiveOpType::RandomSampleInclusionFrequency, operand, std::move(additionalProperties), name);
     }

-    FunctionPtr Dropout(const Variable& operand, double dropoutRate, const std::wstring& name)
+    FunctionPtr Dropout(const Variable& operand, double dropoutRate, unsigned long seed, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate;
+
+        if (seed == SentinelValueForAutoSelectRandomSeed)
+            seed = Internal::GenerateRandomSeed();
+
+        additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
+        additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
+
         return UnaryOp(PrimitiveOpType::Dropout, operand, std::move(additionalProperties), name);
     }
@@ -18,6 +18,9 @@
 #include "BlockFunction.h"
 #include "CompositeFunction.h"
 #include "SpecialPurposeNodes.h"
+#include "ConvolveGeometry.h"
+#include "ConvolutionalNodes.h"
+#include "Variable.h"

 using namespace Microsoft::MSR::CNTK;
@@ -37,7 +40,7 @@ namespace CNTK

     // Names of the various attributes of CNTK primitive Functions
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis = L"axis";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxisVec = L"axisVec";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis1 = L"axis1";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis2 = L"axis2";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAllowDuplicates = L"allowDuplicates";
@@ -976,4 +979,118 @@ namespace CNTK
         return std::shared_ptr<PrimitiveFunction>(new PrimitiveFunction(op, inputs, std::move(attributes), name, uid),
                                                   [](PrimitiveFunction* ptr) { delete ptr; });
     }

+    static const vector<wstring> s_stateAttributes = { PrimitiveFunction::AttributeNameRngSeed, PrimitiveFunction::AttributeNameRngOffset };
+
+    Dictionary PrimitiveFunction::GetState() const
+    {
+        if (!IsStateful())
+            LogicError("Function '%S' is not stateful.", AsString().c_str());
+
+        Dictionary state;
+        for (auto& key : s_stateAttributes)
+        {
+            state[key] = m_attributes[key];
+        }
+
+        return state;
+    }
+
+    void PrimitiveFunction::SetState(const Dictionary& state)
+    {
+        if (!IsStateful())
+            LogicError("Function '%S' is not stateful.", AsString().c_str());
+
+        for (auto& key : s_stateAttributes)
+        {
+            m_attributes[key] = state[key];
+        }
+    }
+
+    /*static*/ void PrimitiveFunction::FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from/* = NDShape()*/)
+    {
+        auto dims = shape.Dimensions();
+        Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(filterRank, inputRank, dims, deflt, from.Dimensions());
+        shape = NDShape(dims);
+    }
+
+    NDShape PrimitiveFunction::ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
+                                                        std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
+                                                        bool transpose, bool inferDimensions, bool ceilOutputDim/* = false*/) const
+    {
+        if (inferDimensions)
+        {
+            size_t inputRank = operandShape.Rank();
+
+            // An unknown kernel shape is valid only for pooling; however, the shape should have been expanded before
+            // this call.
+            if (kernelShape == NDShape::Unknown)
+                RuntimeError("Convolution: Kernel shape can't be Unknown.");
+
+            // Infer reduction dimensions if not given.
+            // If the kernel has a lower rank than the input, then the remaining dimensions are to be reduced over.
+            size_t filterRank = kernelShape.Rank();
+
+            // If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
+            // picking the corresponding operand shape dimensionality.
+            // This is done by shrinking the filter rank and letting the dimensions be inferred from the operand's shape.
+            // TODO: Should we do this for all of the axes in kernelShape that have a dimensionality of NDShape::InferredDimension?
+            if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
+            {
+                filterRank--;
+                kernelShape = kernelShape.SubShape(0, filterRank);
+            }
+
+            NDShape fromShape;
+            if (op == PrimitiveOpType::Convolution)
+                fromShape = operandShape;
+
+            size_t fillRank = (!transpose) ? filterRank : filterRank - 1;
+            FixNDShape(fillRank, inputRank, kernelShape, 1, fromShape); // convolve over red dim; pool over 1
+            FixNDShape(fillRank, inputRank, strides, 1, fromShape); // stride for reduction dims is red dim or 1
+            FixNDShape(fillRank, inputRank, lowerPad, 0);
+            FixNDShape(fillRank, inputRank, upperPad, 0);
+            Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, sharing, true);
+            Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, autoPad, false); // no padding for reduction dims
+        }
+
+        decltype(&Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape) computeOutputShapeFunc;
+        if (!transpose)
+            computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape;
+        else
+            computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
+
+        return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutputDim));
+    }
+
+    /*static*/ bool PrimitiveFunction::UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes)
+    {
+        bool anyParameterOperandDimsInferred = false;
+        auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
+            if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
+            {
+                operand.m_dataFields->m_shape = newOperandShape;
+                return true;
+            }

+            return false;
+        };
+
+        for (auto& newOperandShapePair : newOperandShapes)
+            anyParameterOperandDimsInferred = updateOperandShapeFunc(newOperandShapePair.first, newOperandShapePair.second) || anyParameterOperandDimsInferred;
+
+        return anyParameterOperandDimsInferred;
+    }
+
+    NDShape PrimitiveFunction::NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const
+    {
+        assert(operands.size() > 1);
+
+        // TODO: Is this logic of transitively constructing the output shape from the operands correct?
+        Variable dummyOutputVariable = PlaceholderVariable(NDShape());
+        for (auto& operand : operands)
+            dummyOutputVariable.m_dataFields->m_shape = BinaryElementwiseOpOutputShape(op, dummyOutputVariable, operand, broadcastAllowed, inferInputDimensions);
+
+        return dummyOutputVariable.Shape();
+    }
 }
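A minimal sketch of the round trip these two accessors enable, assuming f and g are pointers to equivalent stateful primitive functions:

    Dictionary state = f->GetState(); // snapshots the rngSeed/rngOffset attributes
    g->SetState(state);               // g now reproduces f's random stream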
@@ -8,10 +8,6 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "PrimitiveOpType.h"
-#include "Utils.h"
-#include "ConvolveGeometry.h"
-#include "ConvolutionalNodes.h"
-#include "Variable.h"

 namespace std
 {
|
@ -301,6 +297,10 @@ namespace CNTK
|
||||||
(OpType() == PrimitiveOpType::RandomSampleInclusionFrequency);
|
(OpType() == PrimitiveOpType::RandomSampleInclusionFrequency);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Dictionary GetState() const;
|
||||||
|
|
||||||
|
void SetState(const Dictionary& state);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
// The following helper functions are used to determine the output shape for different
|
// The following helper functions are used to determine the output shape for different
|
||||||
|
@ -425,24 +425,7 @@ namespace CNTK
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a boolean indicating if any operand shape was updated
|
// Returns a boolean indicating if any operand shape was updated
|
||||||
static bool UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes)
|
static bool UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes);
|
||||||
{
|
|
||||||
bool anyParameterOperandDimsInferred = false;
|
|
||||||
auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
|
|
||||||
if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
|
|
||||||
{
|
|
||||||
operand.m_dataFields->m_shape = newOperandShape;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto& newOperandShapePair : newOperandShapes)
|
|
||||||
anyParameterOperandDimsInferred = updateOperandShapeFunc(newOperandShapePair.first, newOperandShapePair.second) || anyParameterOperandDimsInferred;
|
|
||||||
|
|
||||||
return anyParameterOperandDimsInferred;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
|
// Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
|
||||||
/*static*/ NDShape BinaryElementwiseOpOutputShape(PrimitiveOpType op, Variable& leftOperand, Variable& rightOperand, bool broadcastAllowed, bool inferInputDimensions) const
|
/*static*/ NDShape BinaryElementwiseOpOutputShape(PrimitiveOpType op, Variable& leftOperand, Variable& rightOperand, bool broadcastAllowed, bool inferInputDimensions) const
|
||||||
|
@ -493,6 +476,10 @@ namespace CNTK
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
UNUSED(broadcastAllowed);
|
||||||
|
// BUGBUG: if (broadcastAllowed) is missing here?
|
||||||
|
|
||||||
// Broadcast in remaining axes
|
// Broadcast in remaining axes
|
||||||
for (size_t i = shapeWithSmallerNumAxes.Rank(); i < numOutputAxes; ++i)
|
for (size_t i = shapeWithSmallerNumAxes.Rank(); i < numOutputAxes; ++i)
|
||||||
outputDims[i] = shapeWithLargerNumAxes[i];
|
outputDims[i] = shapeWithLargerNumAxes[i];
|
||||||
|
@@ -507,17 +494,7 @@ namespace CNTK
             return NDShape(std::move(outputDims));
         }

-        /*static*/ NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const
-        {
-            assert(operands.size() > 1);
-
-            // TODO: Is this logic of transitively constructing the output shape from the operands correct?
-            Variable dummyOutputVariable = PlaceholderVariable(NDShape());
-            for (auto& operand : operands)
-                dummyOutputVariable.m_dataFields->m_shape = BinaryElementwiseOpOutputShape(op, dummyOutputVariable, operand, broadcastAllowed, inferInputDimensions);
-
-            return dummyOutputVariable.Shape();
-        }
+        /*static*/ NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const;

         // Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
         /*static*/ NDShape TimesOpOutputShape(Variable& leftOperand, Variable& rightOperand, size_t outputRank, int inferInputRankToMap, bool inferInputDimensions) const
@@ -643,61 +620,11 @@ namespace CNTK
             return NDShape(std::move(outputDims));
         }

-        static void FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from = NDShape())
-        {
-            auto dims = shape.Dimensions();
-            Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(filterRank, inputRank, dims, deflt, from.Dimensions());
-            shape = NDShape(dims);
-        }
+        static void FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from = NDShape());

         /*static*/ NDShape ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
                                                     std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
-                                                    bool transpose, bool inferDimensions, bool ceilOutputDim = false) const
-        {
-            if (inferDimensions)
-            {
-                size_t inputRank = operandShape.Rank();
-
-                // Unknown kernel shape valid only for pooling, however, the shape should have expanded before
-                // this call.
-                if (kernelShape == NDShape::Unknown)
-                    RuntimeError("Convolution: Kernel shape can't be Unknown.");
-
-                // infer reduction dimensions if not given
-                // If kernel has a lower rank than the input then the remaining dimensions are to be reduced over.
-                size_t filterRank = kernelShape.Rank();
-
-                // If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
-                // picking the corresponding operand shape dimensionality
-                // This is done by shrinking the filter rank and let the dimensions be inferred from the operand's shape
-                // TODO: Should we do this for all of the axes in kernelShape that have a dimensionailty of NDShape::InferredDimension?
-                if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
-                {
-                    filterRank--;
-                    kernelShape = kernelShape.SubShape(0, filterRank);
-                }
-
-                NDShape fromShape;
-                if (op == PrimitiveOpType::Convolution)
-                    fromShape = operandShape;
-
-                size_t fillRank = (!transpose)? filterRank : filterRank - 1;
-                FixNDShape(fillRank, inputRank, kernelShape, 1, fromShape); // convolve over red dim; pool over 1
-                FixNDShape(fillRank, inputRank, strides, 1, fromShape); // stride for reduction dims is red dim or 1
-                FixNDShape(fillRank, inputRank, lowerPad, 0);
-                FixNDShape(fillRank, inputRank, upperPad, 0);
-                Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, sharing, true);
-                Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, autoPad, false); // no padding for reduction dims
-            }
-
-            decltype(&Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape) computeOutputShapeFunc;
-            if (!transpose)
-                computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape;
-            else
-                computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
-
-            return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutputDim));
-        }
+                                                    bool transpose, bool inferDimensions, bool ceilOutputDim = false) const;

         /*static*/ NDShape BatchNormalizationOutputShape(std::vector<Variable>& operands, bool spatial, bool inferDimensions) const
         {
@@ -41,6 +41,8 @@ namespace CNTK
     const std::wstring blockFunctionOpNameKey = L"block_function_op_name";
     const std::wstring blockFunctionCompositeArgumentsMapKeysKey = L"block_function_composite_arguments_map_keys";
     const std::wstring blockFunctionCompositeArgumentsMapValuesKey = L"block_function_composite_arguments_map_values";
+    const std::wstring internalWorkerStateKey = L"internal_worker_state";
+    const std::wstring externalWorkerStateKey = L"external_worker_state";

     template <typename T>
     inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion)
@@ -8,11 +8,21 @@
 #include "Utils.h"
 #include "Learner.h"
 #include "PerformanceProfiler.h"
+#include "CompositeFunction.h"
+#include "Serialization.h"

 namespace
 {
+    const std::wstring versionPropertyName = L"Version";
     const std::wstring learnersPropertyName = L"Learners";
     const std::wstring externalStatePropertyName = L"ExternalState";
+    const std::wstring distributedStatePropertyName = L"DistributedState";
+
+    // Version history:
+    // 0 -- a version number before the versioning was introduced for the trainer's checkpoints.
+    // 1 -- initial version: added a key-value pair for the checkpoint version info, added
+    //      distributed state key to save all local state collected from distributed workers.
+    static const size_t trainerCheckpointVersion = 1;
 }

 namespace CNTK
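For orientation, a hypothetical sketch of the version-1 checkpoint layout assembled below (key names from the code; contents illustrative):

    // {
    //     L"Version":          1,
    //     L"Learners":         [ ...per-learner checkpoint state... ],
    //     L"ExternalState":    { ...user-supplied state... },
    //     L"DistributedState": {
    //         L"0": { L"internal_worker_state": {...}, L"external_worker_state": {...} },
    //         L"1": { L"internal_worker_state": {...}, L"external_worker_state": {...} }
    //     }
    // }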
@@ -307,15 +317,22 @@ namespace CNTK
     void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState)
     {
         auto learnersState = m_parameterLearners->CreateCheckpoint();

         if (!m_distributed)
             return Save(modelFilePath, learnersState, externalState);

+        auto compositeFunction = dynamic_cast<CompositeFunction*>(m_combinedTrainingFunction.get());
+
+        Dictionary state;
+        state[internalWorkerStateKey] = compositeFunction->GetInternalState(); // this is the local worker's state.
+        state[externalWorkerStateKey] = externalState;
+
         // Collect distributed external state.
         DistributedCommunicatorPtr communicator = MPICommunicator();
         communicator->Barrier();

         std::vector<DictionaryPtr> remoteState;
-        communicator->Gather(externalState, remoteState, communicator->Workers());
+        communicator->Gather(state, remoteState, communicator->Workers());

         Dictionary aggregatedState;
         for (const auto& w : communicator->Workers())
@@ -324,19 +341,21 @@ namespace CNTK
         }

         if (communicator->CurrentWorker().IsMain())
-            Save(modelFilePath, learnersState, aggregatedState);
+            Save(modelFilePath, learnersState, externalState, aggregatedState);

         // all workers need to sync up after saving model to avoid read-after-write hazard
         // i.e. one worker is in the middle of write while another tries to read
         communicator->Barrier();
     }

-    void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState)
+    void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState, const Dictionary& distributedState)
     {
         std::wstring tempModelFile = modelFilePath + L".tmp";
         Dictionary state;
+        state[versionPropertyName] = trainerCheckpointVersion;
         state[learnersPropertyName] = learnerState;
         state[externalStatePropertyName] = externalState;
+        state[distributedStatePropertyName] = distributedState;

         m_combinedTrainingFunction->SaveModel(tempModelFile);
         std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
@ -359,25 +378,57 @@ namespace CNTK
|
||||||
|
|
||||||
Dictionary checkpoint = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));
|
Dictionary checkpoint = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));
|
||||||
|
|
||||||
|
size_t version = 0;
|
||||||
|
|
||||||
|
if (checkpoint.Contains(versionPropertyName))
|
||||||
|
version = checkpoint[versionPropertyName].Value<size_t>();
|
||||||
|
|
||||||
auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
|
auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
|
||||||
auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();
|
auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();
|
||||||
|
|
||||||
|
m_parameterLearners->RestoreFromCheckpoint(learnerState);
|
||||||
|
|
||||||
if (!m_distributed)
|
if (!m_distributed)
|
||||||
{
|
{
|
||||||
m_parameterLearners->RestoreFromCheckpoint(learnerState);
|
|
||||||
return externalState;
|
return externalState;
|
||||||
}
|
}
|
||||||
|
|
||||||
m_parameterLearners->RestoreFromCheckpoint(learnerState);
|
// this ensures that nobody will start writing to the model/checkpoint files, until
|
||||||
|
// everybody is done reading them.
|
||||||
DistributedCommunicatorPtr communicator = MPICommunicator();
|
DistributedCommunicatorPtr communicator = MPICommunicator();
|
||||||
communicator->Barrier();
|
communicator->Barrier();
|
||||||
|
|
||||||
auto key = std::to_wstring(communicator->CurrentWorker().m_globalRank);
|
auto mainWorkerId = std::to_wstring(0);
|
||||||
|
auto localWorkerId = std::to_wstring(communicator->CurrentWorker().m_globalRank);
|
||||||
|
|
||||||
if (externalState.Contains(key))
|
// before version 1, there was no distributed state per se. Instead, the external state
|
||||||
|
// contained a dictionary of worker-specific external states.
|
||||||
|
if (version == 0)
|
||||||
|
{
|
||||||
|
auto key = externalState.Contains(localWorkerId) ? localWorkerId : mainWorkerId;
|
||||||
return externalState[key].Value<Dictionary>();
|
return externalState[key].Value<Dictionary>();
|
||||||
else
|
}
|
||||||
return externalState[std::to_wstring(0)].Value<Dictionary>();
|
|
||||||
|
Dictionary distributedState = checkpoint[distributedStatePropertyName].Value<Dictionary>();
|
||||||
|
|
||||||
|
if (communicator->CurrentWorker().IsMain() || !distributedState.Contains(localWorkerId))
|
||||||
|
{
|
||||||
|
return externalState;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the checkpoint contains internal state for this worker.
|
||||||
|
Dictionary localState = distributedState[localWorkerId].Value<Dictionary>();
|
||||||
|
|
||||||
|
auto internalState = localState[internalWorkerStateKey].Value<Dictionary>();
|
||||||
|
auto compositeFunction = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
|
||||||
|
if (compositeFunction == nullptr)
|
||||||
|
RuntimeError("Combined training function is not a CompositeFunction.");
|
||||||
|
|
||||||
|
// this assumes the compositeFunction (restored form a checkpoint made by the main node) and
|
||||||
|
// the internal worker state both have identical UIDs.
|
||||||
|
compositeFunction->SetInternalState(internalState);
|
||||||
|
|
||||||
|
return localState[externalWorkerStateKey].Value<Dictionary>();
|
||||||
}
|
}
|
||||||
|
|
||||||
double Trainer::PreviousMinibatchLossAverage() const
|
double Trainer::PreviousMinibatchLossAverage() const
|
||||||
|
|
|
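To make the version-0 versus version-1 checkpoint handling above easier to follow, here is a rough Python sketch of the same worker-state selection on plain dicts. The string keys are illustrative stand-ins for the versionPropertyName, externalStatePropertyName, and distributedStatePropertyName constants, whose actual values are not shown in this diff.

def restore_external_state(checkpoint, worker_rank, is_main, is_distributed):
    # Mirrors the C++ selection logic above, with dicts in place of CNTK's Dictionary.
    version = checkpoint.get("version", 0)
    external = checkpoint["externalState"]
    local_id, main_id = str(worker_rank), "0"

    if not is_distributed:
        return external

    if version == 0:
        # Pre-version-1 layout: the external state itself held per-worker sub-dictionaries.
        key = local_id if local_id in external else main_id
        return external[key]

    distributed = checkpoint["distributedState"]
    if is_main or local_id not in distributed:
        return external

    # Worker-specific entry: the real code also pushes the worker's internal
    # (RNG) state back into the composite function before returning.
    local_state = distributed[local_id]
    return local_state["externalWorkerState"]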
@@ -8,7 +8,6 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include <fstream>
-#include "Utils.h"

 namespace CNTK
 {
@@ -11,10 +11,10 @@
 namespace Microsoft { namespace MSR { namespace CNTK {

 CPURNGHandle::CPURNGHandle(int deviceId, uint64_t seed, uint64_t offset)
-    : RNGHandle(deviceId)
+    : RNGHandle(deviceId),
+      m_generator(seed)
 {
-    m_generator.reset(new std::mt19937_64(seed));
-    m_generator->discard(offset);
+    m_generator.discard(offset);
 }

 }}}
@@ -20,12 +20,11 @@ public:

     std::mt19937_64& Generator()
     {
-        return *m_generator;
+        return m_generator;
     }

 private:
-    std::unique_ptr<std::mt19937_64> m_generator;
-    // TODO: why is this a ptr?
+    std::mt19937_64 m_generator;
 };

 }}}
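The handle now owns the Mersenne Twister engine by value, so a (seed, offset) pair fully determines its state: seed the engine, then discard the first `offset` draws. A minimal Python sketch of that idea, with `random.Random` standing in for `std::mt19937_64` (the discard loop mirrors `m_generator.discard(offset)`):

import random

def make_rng(seed, offset):
    # Build a generator whose state is fully determined by (seed, offset).
    rng = random.Random(seed)
    for _ in range(offset):
        rng.random()  # skip past draws already consumed before the checkpoint
    return rng

# Two handles restored with the same (seed, offset) continue identically.
a = make_rng(123, 1000)
b = make_rng(123, 1000)
assert [a.random() for _ in range(5)] == [b.random() for _ in range(5)]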
@@ -9,6 +9,7 @@
 #include "Common.h"

 using namespace CNTK;
+using namespace std;
 using namespace std::placeholders;

 extern bool Is1bitSGDAvailable();

@@ -35,14 +36,23 @@ namespace
 const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
 const size_t totalNumberOfSamples = numSamplesPerSweep * numSweepsToTrainWith;

+const std::wstring g_attributeNameRngSeed = L"rngSeed";
+const std::wstring g_attributeNameRngOffset = L"rngOffset";
+
+inline MinibatchSourcePtr GetMinibatchSource(const FeedForwardClassifier& classifier)
+{
+    return TextFormatMinibatchSource(g_inputFile,
+        { { g_featureStreamName, classifier.inputDim },
+          { g_labelsStreamName, classifier.ouputDim } },
+        totalNumberOfSamples, true);
+}
+
 void LoopBasedOnSamples(const std::wstring& name, const DeviceDescriptor& device, std::function<DistributedLearnerPtr(LearnerPtr)> factory, const FeedForwardClassifier& classifier)
 {
     printf("Training loop thru samples with %ls.\n", name.c_str());

-    auto minibatchSource = TextFormatMinibatchSource(g_inputFile,
-        { { g_featureStreamName, classifier.inputDim }, { g_labelsStreamName, classifier.ouputDim } },
-        totalNumberOfSamples,
-        true);
+    auto minibatchSource = GetMinibatchSource(classifier);

     auto featureStreamInfo = minibatchSource->StreamInfo(g_featureStreamName);
     auto labelStreamInfo = minibatchSource->StreamInfo(g_labelsStreamName);
@@ -135,3 +145,92 @@ void TestFrameMode()
     }
     sync->Barrier();
 }
+
+void TestDistributedCheckpointing()
+{
+    std::vector<DeviceDescriptor> devices;
+    if (ShouldRunOnCpu())
+        devices.push_back(DeviceDescriptor::CPUDevice());
+    if (ShouldRunOnGpu())
+        devices.push_back(DeviceDescriptor::GPUDevice(0));
+
+    auto sync = MPICommunicator();
+
+    auto numWorkers = sync->Workers().size();
+    auto workerRank = sync->CurrentWorker().m_globalRank;
+
+    for (auto device : devices)
+    {
+        auto ff = BuildFeedForwardClassifier(device);
+        ff.output = Dropout(ff.output, 0.5);
+        ff.trainingLoss = CNTK::CrossEntropyWithSoftmax(ff.output, ff.labels, L"lossFunction");
+        ff.prediction = CNTK::ClassificationError(ff.output, ff.labels, L"classificationError");
+
+        {
+            auto& attributes = ff.output->RootFunction()->Attributes();
+            size_t seed = attributes[g_attributeNameRngSeed].Value<size_t>();
+            // Check that (1) the seed is in the attributes dictionary and
+            // (2) the auto-generated seed value reflects the workerRank.
+            if (numWorkers > 1 && seed % numWorkers != workerRank)
+                ReportFailure("Unexpected seed value");
+        }
+
+        auto learner = SGDLearner(ff.output->Parameters(), LearningRatePerSampleSchedule(0.02));
+        auto distributedLearner = CreateDataParallelDistributedLearner(MPICommunicator(), learner, 0);
+        auto trainer = CreateTrainer(ff.output, ff.trainingLoss, ff.prediction, { distributedLearner });
+
+        auto minibatchSource = GetMinibatchSource(ff);
+
+        auto featureStreamInfo = minibatchSource->StreamInfo(g_featureStreamName);
+        auto labelStreamInfo = minibatchSource->StreamInfo(g_labelsStreamName);
+
+        vector<double> expectedLoss(100);
+        for (int i = 0; i < 100; i++)
+        {
+            if (i % 10 == 0)
+            {
+                auto checkpoint = minibatchSource->GetCheckpointState();
+                trainer->SaveCheckpoint(L"distributed_checkpoint_test." + to_wstring(i), checkpoint);
+            }
+
+            auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+            unordered_map<Variable, MinibatchData> minibatch = { { ff.features, minibatchData[featureStreamInfo] }, { ff.labels, minibatchData[labelStreamInfo] } };
+
+            trainer->TrainMinibatch(minibatch, device);
+            expectedLoss[i] = trainer->PreviousMinibatchLossAverage();
+        }
+
+        for (int i = 0; i < 100; i++)
+        {
+            if (i % 10 == 0)
+            {
+                auto checkpoint = trainer->RestoreFromCheckpoint(L"distributed_checkpoint_test." + to_wstring(i));
+                minibatchSource->RestoreFromCheckpoint(checkpoint);
+
+                auto& attributes = ff.output->RootFunction()->Attributes();
+                size_t seed = attributes[g_attributeNameRngSeed].Value<size_t>();
+                size_t offset = attributes[g_attributeNameRngOffset].Value<size_t>();
+
+                // Check that the worker-specific seed value was properly restored from the checkpoint.
+                if (numWorkers > 1 && seed % numWorkers != workerRank)
+                    ReportFailure("Unexpected seed value");
+                // Check the offset and verify that it changes depending on the number of processed minibatches.
+                if (offset != i * minibatchSize * ff.inputDim)
+                    ReportFailure("Unexpected offset value");
+            }
+
+            auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+            unordered_map<Variable, MinibatchData> minibatch = { { ff.features, minibatchData[featureStreamInfo] }, { ff.labels, minibatchData[labelStreamInfo] } };
+
+            trainer->TrainMinibatch(minibatch, device);
+            auto loss = trainer->PreviousMinibatchLossAverage();
+
+            FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
+        }
+    }
+
+    sync->Barrier();
+}
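The offset assertion in the test above encodes the invariant that dropout consumes one random draw per input element, so after i minibatches the recorded RNG offset must equal i * minibatchSize * inputDim. A tiny Python sketch of that bookkeeping (the helper name and concrete numbers are just illustrations):

# Dropout draws one random number per element it processes, so the offset
# grows linearly with the total number of elements seen so far.
def expected_rng_offset(num_minibatches, minibatch_size, input_dim):
    return num_minibatches * minibatch_size * input_dim

# e.g. after 10 minibatches of 50 samples with 784-dimensional inputs:
assert expected_rng_offset(10, 50, 784) == 392000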
@@ -21,6 +21,7 @@ void MNISTClassifierTests();
 void TrainSequenceToSequenceTranslator();
 void TrainTruncatedLSTMAcousticModelClassifier();
 void TestFrameMode();
+void TestDistributedCheckpointing();

 int main(int argc, char *argv[])
 {

@@ -58,6 +59,8 @@ int main(int argc, char *argv[])
     TestFrameMode();
+
+    TestDistributedCheckpointing();

     std::string testsPassedMsg = "\nCNTKv2Library-Distribution tests: Passed\n";

     printf("%s", testsPassedMsg.c_str());
@@ -438,6 +438,10 @@ Test module "V2LibraryTests" has passed with:

 Test case "SerializationSuite/CheckpointingWithStatefulNodesInGPU" has passed

+Test case "SerializationSuite/CheckpointingWithStatefulNodesAndExplicitSeedsOnCPU" has passed
+
+Test case "SerializationSuite/CheckpointingWithStatefulNodesAndExplicitSeedsOnGPU" has passed
+
 Test suite "FeedForwardSuite" has passed with:
 6 test cases out of 6 passed
 8 assertions out of 8 passed
@@ -815,6 +815,74 @@ void TestCheckpointingWithStatefulNodes(const DeviceDescriptor& device)
     }
 }

+void TestCheckpointingWithStatefulNodesAndExplicitSeeds(const DeviceDescriptor& device)
+{
+    auto featureStreamName = L"features";
+    auto labelsStreamName = L"labels";
+
+    size_t inputDim = 784;
+    size_t numOutputClasses = 10;
+    auto features = InputVariable({ inputDim }, false /*isSparse*/, DataType::Float, featureStreamName);
+    auto labels = InputVariable({ numOutputClasses }, DataType::Float, labelsStreamName);
+
+    auto net1 = BuildFFClassifierNet(features, numOutputClasses, device, 1);
+    auto net2 = net1->Clone(ParameterCloningMethod::Clone, { { features, features } });
+    auto net3 = net1->Clone(ParameterCloningMethod::Clone, { { features, features } });
+
+    auto trainer1 = BuildTrainer(Dropout(net1, 0.5, 123), labels);
+    auto trainer2 = BuildTrainer(Dropout(net2, 0.5, 123), labels);
+    auto trainer3 = BuildTrainer(Dropout(net3, 0.5, 321), labels);
+
+    const size_t minibatchSize = 50;
+    const size_t maxSamples = 150;
+    auto minibatchSource = TextFormatMinibatchSource(L"Train-28x28_cntk_text.txt", { { featureStreamName, inputDim }, { labelsStreamName, numOutputClasses } }, 2 * maxSamples, false);
+
+    auto featureStreamInfo = minibatchSource->StreamInfo(features);
+    auto labelStreamInfo = minibatchSource->StreamInfo(labels);
+
+    for (int i = 0; i < maxSamples; i += minibatchSize)
+    {
+        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+        unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } };
+
+        trainer1->TrainMinibatch(minibatch, device);
+        trainer2->TrainMinibatch(minibatch, device);
+        trainer3->TrainMinibatch(minibatch, device);
+        auto loss1 = trainer1->PreviousMinibatchLossAverage();
+        auto loss2 = trainer2->PreviousMinibatchLossAverage();
+        auto loss3 = trainer3->PreviousMinibatchLossAverage();
+        FloatingPointCompare(loss1, loss2, "Training loss does not match expectation");
+        BOOST_TEST((abs(loss1 - loss2) <= abs(loss2 - loss3)));
+    }
+
+    trainer1->SaveCheckpoint(L"seeded_stateful_nodes.model");
+    auto state = minibatchSource->GetCheckpointState();
+
+    vector<double> expectedLoss;
+    for (int i = 0; i < maxSamples; i += minibatchSize)
+    {
+        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+        unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } };
+
+        trainer1->TrainMinibatch(minibatch, device);
+        expectedLoss.push_back(trainer1->PreviousMinibatchLossAverage());
+    }
+
+    trainer1->RestoreFromCheckpoint(L"seeded_stateful_nodes.model");
+    minibatchSource->RestoreFromCheckpoint(state);
+
+    for (int i = 0; i * minibatchSize < maxSamples; i++)
+    {
+        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+        unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } };
+
+        trainer1->TrainMinibatch(minibatch, device);
+        double loss = trainer1->PreviousMinibatchLossAverage();
+        FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
+    }
+}
+
 void TestLoadingModelFromMemoryBuffer()
 {
     ifstream modelFileStream("batch.norm.no.sample.count.v2.bin", ifstream::binary);
@@ -992,6 +1060,18 @@ BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesInGPU)
     TestCheckpointingWithStatefulNodes(DeviceDescriptor::GPUDevice(0));
 }

+BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesAndExplicitSeedsOnCPU)
+{
+    TestCheckpointingWithStatefulNodesAndExplicitSeeds(DeviceDescriptor::CPUDevice());
+}
+
+BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesAndExplicitSeedsOnGPU)
+{
+    if (ShouldRunOnGpu())
+        TestCheckpointingWithStatefulNodesAndExplicitSeeds(DeviceDescriptor::GPUDevice(0));
+}
+
 BOOST_AUTO_TEST_SUITE_END()

 }}
@@ -283,7 +283,7 @@
 " with C.layers.default_options(initial_state = 0.1):\n",
 " m = C.layers.Recurrence(C.layers.LSTM(N))(x)\n",
 " m = C.ops.sequence.last(m)\n",
-" m = C.layers.Dropout(0.2)(m)\n",
+" m = C.layers.Dropout(0.2, seed=1)(m)\n",
 " m = cntk.layers.Dense(1)(m)\n",
 " return m"
 ]
@@ -573,6 +573,9 @@ public:
     }
 }

+%ignore CNTK::Dictionary::Keys;
+
 %extend CNTK::Dictionary {
     PyObject* __getitem__(const wchar_t* key) {
         PyObject *DictionaryValueToPy(const CNTK::DictionaryValue&);
@@ -14,6 +14,7 @@ from ..variables import Variable, Record, Constant
 from ..ops import parameter, input, placeholder, combine
 from ..ops import times, element_times, convolution, convolution_transpose, pooling, unpooling, batch_normalization, dropout, splice, reshape, sequence, softmax, tanh, reduce_sum, reduce_mean, sqrt
 from cntk.internal import _as_tuple
+from cntk.cntk_py import sentinel_value_for_auto_select_random_seed as SentinelValueForAutoSelectRandomSeed
 from .blocks import *
 from .higher_order_layers import *
 from .blocks import _initializer_for, _get_initial_state_or_default, _INFERRED # helpers

@@ -1023,7 +1024,10 @@ def MaxUnpooling(filter_shape, # shape of receptive field, e.g. (3,3)

 # TODO: should the rate(s) be default_options?
-def Dropout(dropout_rate=None, keep_prob=None, name=''):
+def Dropout(dropout_rate=None,
+            keep_prob=None,
+            seed=SentinelValueForAutoSelectRandomSeed,
+            name=''):
     '''
     Layer factory function to create a drop-out layer.

@@ -1043,6 +1047,7 @@ def Dropout(dropout_rate=None, keep_prob=None, name=''):
     Args:
      dropout_rate (float): probability of dropping out an element, mutually exclusive with ``keep_prob``
      keep_prob (float): probability of keeping an element, mutually exclusive with ``dropout_rate``
+     seed (int): random seed.
      name (str, defaults to ''): the name of the function instance in the network

     Returns:

@@ -1057,7 +1062,7 @@ def Dropout(dropout_rate=None, keep_prob=None, name=''):
         dropout_rate = 1-keep_prob
     @BlockFunction('Dropout', name)
     def dropout_f(x):
-        return dropout(x, dropout_rate=dropout_rate)
+        return dropout(x, dropout_rate=dropout_rate, seed=seed)
     return dropout_f
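A short usage sketch of the extended layer factory (the input shape and seed values are illustrative): two identically-seeded Dropout layers generate the same mask sequence, while a layer left at the sentinel default gets an auto-selected seed.

import cntk as C

x = C.input((10,))

# Identical explicit seeds give reproducible, matching dropout masks;
# omitting `seed` falls back to the auto-selected sentinel value.
d1 = C.layers.Dropout(0.5, seed=123)(x)
d2 = C.layers.Dropout(0.5, seed=123)(x)
d3 = C.layers.Dropout(0.5)(x)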
@@ -15,6 +15,7 @@ from cntk.internal import sanitize_input, sanitize_shape, sanitize_axis, sanitiz
 from cntk.internal.utils import get_data_type
 from ..axis import Axis
 from .. import cntk_py
+from ..cntk_py import sentinel_value_for_auto_select_random_seed as SentinelValueForAutoSelectRandomSeed
 from ..default_options import get_default_override, default_override_or

 TIMES_NO_INFERRED_INPUT_RANK = cntk_py.TimesNoInferredInputRank

@@ -2383,7 +2384,12 @@ def argmin(x, axis=None, name=''):
 #######################################################################

 @typemap
-def random_sample(weights, num_samples, allow_duplicates, name=''):
+def random_sample(
+    weights,
+    num_samples,
+    allow_duplicates,
+    seed=SentinelValueForAutoSelectRandomSeed,
+    name=''):
     '''
     Estimates inclusion frequencies for random sampling with or without
     replacement.

@@ -2404,6 +2410,8 @@ def random_sample(weights, num_samples, allow_duplicates, name=''):
     num_samples (int): number of expected samples
     allow_duplicates (bool): If sampling is done
      with replacement (`True`) or without (`False`).
+    seed (int): random seed.
+    name (:class:`str`, optional): the name of the Function instance in the network.

     Returns:
     :class:`~cntk.ops.functions.Function`

@@ -2412,14 +2420,15 @@ def random_sample(weights, num_samples, allow_duplicates, name=''):
     from cntk.cntk_py import random_sample
     weights = sanitize_input(weights)

-    return random_sample(weights, num_samples, allow_duplicates, name)
+    return random_sample(weights, num_samples, allow_duplicates, seed, name)


 @typemap
 def random_sample_inclusion_frequency(
     weights,
     num_samples,
     allow_duplicates,
+    seed=SentinelValueForAutoSelectRandomSeed,
     name=''):
     '''
     For weighted sampling with the specified sample size (`num_samples`)

@@ -2438,6 +2447,8 @@ def random_sample_inclusion_frequency(
     num_samples (int): number of expected samples
     allow_duplicates (bool): If sampling is done
      with replacement (`True`) or without (`False`).
+    seed (int): random seed.
+    name (:class:`str`, optional): the name of the Function instance in the network.

     Example:
     >>> import numpy as np

@@ -2470,11 +2481,12 @@ def random_sample_inclusion_frequency(
     weights,
     num_samples,
     allow_duplicates,
+    seed,
     name)


 @typemap
-def dropout(x, dropout_rate=0.0, name=''):
+def dropout(x, dropout_rate=0.0, seed=SentinelValueForAutoSelectRandomSeed, name=''):
     '''
     Each element of the input is independently set to 0 with probability ``dropout_rate``
     or to 1 / (1 - ``dropout_rate``) times its original value (with probability 1-``dropout_rate``).

@@ -2500,6 +2512,7 @@ def dropout(x, dropout_rate=0.0, name=''):
     Args:
         x: input tensor
         dropout_rate (float, [0,1)): probability that an element of ``x`` will be set to zero
+        seed (int): random seed.
         name (:class:`str`, optional): the name of the Function instance in the network

     Returns:

@@ -2511,7 +2524,7 @@ def dropout(x, dropout_rate=0.0, name=''):
     from cntk.cntk_py import dropout
     x = sanitize_input(x)

-    return dropout(x, dropout_rate, name)
+    return dropout(x, dropout_rate, seed, name)

 ##########################################################################
 # variables_and_parameters ops
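A minimal sketch of what the new `seed` parameter buys at the op level, following the conventions of the tests added in this commit (the input shape and seed value are illustrative): two fresh nodes built with the same explicit seed start from the same RNG state, so their first evaluations apply identical masks.

import numpy as np
import cntk as C

x = C.input((5,))
f1 = C.dropout(x, dropout_rate=0.5, seed=98052)
f2 = C.dropout(x, dropout_rate=0.5, seed=98052)

data = np.ones((1, 5), dtype=np.float32)
# Identical explicit seeds => identical masks on the first evaluation.
assert np.allclose(f1.eval({x: data}), f2.eval({x: data}))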
@@ -186,6 +186,41 @@ def test_op_dropout(shape, dropout_rate, device_id, precision):
     assert(abs(resulted_non_zeros - expected_non_zeros) <
            max_off)

+
+def test_op_dropout_with_explicit_seed(device_id, precision):
+    from cntk import combine, dropout, input
+
+    value = np.ones(shape=(10,10), dtype=PRECISION_TO_TYPE[precision])
+
+    a = input(shape=value.shape,
+              dtype=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
+              needs_gradient=True,
+              name='a')
+
+    seed = 123
+
+    dropout_nodes = [
+        dropout(a, dropout_rate=0.5, seed=seed),
+        dropout(a, dropout_rate=0.5, seed=seed),
+        dropout(a, dropout_rate=0.5, seed=seed+1),
+        dropout(a, dropout_rate=0.5)
+    ]
+
+    value.shape = (1, 1) + value.shape
+    forward_input = {a: value}
+    results = []
+    for node in dropout_nodes:
+        forward, backward = cntk_eval(node,
+                                      forward_input,
+                                      precision,
+                                      cntk_device(device_id),
+                                      backward_pass=True)
+
+        results.append(forward[node.output])
+
+    assert np.allclose(results[0], results[1])
+    assert not np.allclose(results[0], results[2])
+    assert not np.allclose(results[0], results[3])
+

 @pytest.mark.parametrize("dropout_rate", [-0.1, 1.0, 100])
 def test_op_dropout_bad_input(dropout_rate):
@@ -114,3 +114,18 @@ def test_random_sample_without_replacement(weights, num_samples, expected_count,
     denseResult = times(result, identity)
     observed_count = np.sum(denseResult.eval(), 0)
     assert np.allclose(observed_count, expected_count, atol=tolerance)
+
+def test_random_sample_with_explicit_seed(device_id, precision):
+    weights = AA([x for x in range(0, 10)], precision)
+    identity = np.identity(weights.size)
+    allow_duplicates = False # sample without replacement
+    num_samples = 5
+    seed = 123
+    to_dense = lambda x: times(x, identity).eval()
+    result1 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed))
+    result2 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed))
+    result3 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed+1))
+    result4 = to_dense(random_sample(weights, num_samples, allow_duplicates))
+    assert np.allclose(result1, result2)
+    assert not np.allclose(result1, result3)
+    assert not np.allclose(result1, result4)
@@ -62,9 +62,9 @@ def test_convolution_transpose_attributes():

 def test_dropout_attributes():
     x = C.input( (1, 5, 5) )
-    f = C.dropout(x, 0.5)
+    f = C.dropout(x, 0.5, 42)
     d = f.root_function.attributes
-    expected = {'dropoutRate': 0.5}
+    expected = {'dropoutRate': 0.5, 'rngSeed' : 42, 'rngOffset' : 0}
     _check(expected, d)

 def test_slice_attributes():