Integrate alrezni/v2_dropout into master

This commit is contained in:
Project Philly 2017-04-04 09:16:31 -07:00
Parent: d192653c11 859296557d
Commit: 94993f3c81
26 changed files: 701 additions and 225 deletions

View file

@@ -1531,7 +1531,15 @@ namespace CNTK
         DictionaryIterator end() const { return m_dictionaryData->end(); }
         ConstDictionaryIterator cend() const { return m_dictionaryData->cend(); }

-        size_t Size() { return m_dictionaryData->size(); }
+        size_t Size() const { return m_dictionaryData->size(); }
+
+        std::unordered_set<std::wstring> Keys()
+        {
+            std::unordered_set<std::wstring> keys;
+            for (const auto& kv : *m_dictionaryData)
+                keys.insert(kv.first);
+            return keys;
+        }

         friend CNTK_API std::istream& operator>>(std::istream& stream, Dictionary& us);
         friend CNTK_API std::ostream& operator<<(std::ostream& stream, const Dictionary& us);
@@ -3391,20 +3399,17 @@ namespace CNTK
     ///
     /// Create an instance of the random_sample operation on specified sampling weights input vector
     ///
-    // TODO: The initial random seed should be specifiable
-    CNTK_API FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name /*= L""*/);
+    CNTK_API FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");

     ///
     /// Create an instance of the random_sample_inclusion_frequency operation on specified sampling weights input vector
     ///
-    // TODO: The initial random seed should be specifiable
-    CNTK_API FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name /*= L""*/);
+    CNTK_API FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");

     ///
     /// Create an instance of the dropout operation on specified tensor input operand
     ///
-    // TODO: The initial random seed should be specifiable
-    CNTK_API FunctionPtr Dropout(const Variable& operand, double dropoutRate, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Dropout(const Variable& operand, double dropoutRate, unsigned long seed = SentinelValueForAutoSelectRandomSeed, const std::wstring& name = L"");

     ///
     /// Create an instance of the reshape operation on specified tensor input operand
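For illustration only (not part of this commit's diff): with the new signatures above, a caller can pin the dropout mask sequence by passing an explicit seed, or omit it to get an auto-selected, per-worker seed. A minimal sketch:

    #include "CNTKLibrary.h"
    using namespace CNTK;

    // Explicit seed: dropout masks are reproducible across runs.
    FunctionPtr ReproducibleDropout(const Variable& input)
    {
        return Dropout(input, /*dropoutRate=*/0.5, /*seed=*/123);
    }

    // Default sentinel: the seed is auto-selected (via Internal::GenerateRandomSeed()),
    // so each distributed worker draws a distinct mask sequence.
    FunctionPtr AutoSeededDropout(const Variable& input)
    {
        return Dropout(input, /*dropoutRate=*/0.5);
    }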
@@ -4605,7 +4610,8 @@ namespace CNTK
         bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
         bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);

-        void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState);
+        void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState,
+                  const Dictionary& externalState, const Dictionary& distributedState = {});

         void UpdateTrainingProgress(size_t numSamples, const ValuePtr& loss, const ValuePtr& evalCriterion, const DeviceDescriptor& computeDevice);
         void AddProgressWriters(const std::vector<ProgressWriterPtr>& progressWriters);

View file

@@ -239,6 +239,8 @@ namespace CNTK
         CNTK_API size_t NewUniqueId();

+        CNTK_API size_t GenerateRandomSeed();
+
         // Internal hooks for testing and higher-level bindings
         // These should not be directly called by C++ API users
         CNTK_API void EnableReversingTensorShapesInErrorMessages();

View file

@@ -8,6 +8,8 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "PrimitiveFunction.h"
+#include "Utils.h"
+#include "Variable.h"

 namespace CNTK
 {

View file

@@ -193,4 +193,4 @@
     <Warning Condition="!$(HasProtobuf)" Text="CNTKv2LibraryDll requires Protocol Buffers to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#protobuf for installation instructions." />
     <Error Condition="!$(HasBoost)" Text="CNTKv2LibraryDll requires the Boost library to build. Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#boost for installation instructions." />
   </Target>
 </Project>

View file

@@ -28,12 +28,36 @@ namespace CNTK
 {
     namespace Internal
     {
-        static std::atomic<unsigned long long> s_nextUniqueId(0);
+        static std::atomic_ullong s_nextUniqueId = ATOMIC_VAR_INIT(0);
         size_t NewUniqueId()
         {
             return s_nextUniqueId++;
         }

+        static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0);
+
+        // This is used to generate a default seed for stateful nodes (dropout, and both
+        // flavors of random sample). As a result, in a distributed environment, each worker
+        // ends up having a different seed.
+        size_t GenerateRandomSeed()
+        {
+            static size_t numWorkers = 1, rank = 0;
+            static bool initialized = false;
+            if (MPIWrapper::GetTotalNumberOfMPINodes() != 0 && !initialized)
+            {
+                DistributedCommunicatorPtr communicator = MPICommunicator();
+                numWorkers = communicator->Workers().size();
+                rank = communicator->CurrentWorker().m_globalRank;
+
+                if (numWorkers < 1)
+                    numWorkers = 1;
+            }
+
+            initialized = true;
+            return (numWorkers * s_currentRandomSeed++) + rank;
+        }
+
         std::atomic<bool> s_reverseTensorShapesInErrorMessages(false);
         void EnableReversingTensorShapesInErrorMessages()
         {
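For illustration only (not part of the diff): the arithmetic above hands out the seed sequence rank, numWorkers + rank, 2*numWorkers + rank, ..., so seeds are disjoint across workers and seed % numWorkers recovers the rank, which is the invariant the distributed test later in this commit checks. A standalone sketch with hypothetical names:

    #include <atomic>
    #include <cassert>

    static std::atomic_ullong s_counter(0);

    // Simplified model of Internal::GenerateRandomSeed() for one worker.
    size_t NextSeed(size_t numWorkers, size_t rank)
    {
        return numWorkers * s_counter++ + rank;
    }

    int main()
    {
        const size_t numWorkers = 4, rank = 2; // pretend to be worker 2 of 4
        for (int i = 0; i < 3; ++i)
        {
            size_t seed = NextSeed(numWorkers, rank); // yields 2, 6, 10, ...
            assert(seed % numWorkers == rank);        // never collides with other ranks
        }
        return 0;
    }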

View file

@@ -46,8 +46,58 @@ namespace CNTK
         return dict;
     }

+    // Copy the internal state from the network into the function graph,
+    // specifically from RngUser nodes into the attributes dictionaries of
+    // the corresponding stateful primitive functions.
+    void CompositeFunction::UpdateInternalState() const
+    {
+        if (!m_computationNetwork)
+            return;
+
+        for (auto& function : m_allPrimitiveFunctions)
+        {
+            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            // TODO: same for BatchNorm
+
+            auto& outputs = primitiveFunction->RawOutputs();
+            if (outputs.size() != 1)
+                LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());
+
+            const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();
+
+            Dictionary state;
+            state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
+            state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
+            primitiveFunction->SetState(state);
+        }
+    }
+
+    // Generate a dictionary representing the internal (local) state of the function graph.
+    Dictionary CompositeFunction::GetInternalState() const
+    {
+        UpdateInternalState();
+
+        Dictionary stateDictionary;
+        for (auto& function : m_allPrimitiveFunctions)
+        {
+            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            // TODO: same for BatchNorm
+            stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
+        }
+        return stateDictionary;
+    }
+
     /*virtual*/ Dictionary CompositeFunction::Serialize() const
     {
+        UpdateInternalState();
+
         Dictionary dict = SerializeBlockComposite();

         // Find cycles in the graph and "break" them by inserting placeholders.
@@ -129,29 +179,6 @@ namespace CNTK
         }
         dict[functionsKey] = std::move(functionDictionaries);

-        // Now, collect and store the internal state for all non-pure (stateful) functions in the graph
-        // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
-        Dictionary stateDictionary;
-        for (const auto& kv : m_variableToNodeMap)
-        {
-            if (kv.second->Is<RngUser>() && kv.first.IsOutput())
-            {
-                // The RNG state should be associated with the actual function that the computation node
-                // corresponds to, and not the block primitives that wrap the actual function
-                auto ownerFunction = kv.first.Owner().get();
-                if (!ownerFunction->IsBlock())
-                {
-                    auto rng = kv.second->As<RngUser>();
-
-                    Dictionary state;
-                    state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
-                    state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
-                    stateDictionary[ownerFunction->Uid()] = state;
-                }
-            }
-        }
-        dict[stateKey] = std::move(stateDictionary);
-
         return dict;
     }
@@ -217,10 +244,6 @@ namespace CNTK
             uidToInputMap[inputVar.Uid()] = inputVar;
         }

-        Dictionary stateDictionary;
-        if (dict.Contains(stateKey))
-            stateDictionary = dict[stateKey].Value<Dictionary>();
-
         const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

         std::unordered_map<Variable, Variable> allPlaceholderReplacements;
@@ -238,25 +261,6 @@ namespace CNTK
             if (opType == PrimitiveOpType::Combine)
                 continue;

-            if (primitiveFunction->IsStateful())
-            {
-                if (stateDictionary.Contains(primitiveFunction->Uid()))
-                {
-                    auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
-                    auto seed = state[rngSeedKey].Value<size_t>();
-                    auto offset = state[rngOffsetKey].Value<size_t>();
-                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
-                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
-                }
-                else if (Internal::GetComputationNetworkTraceLevel() > 0)
-                {
-                    // TODO: all logging functionality should be refactored to live in a logging utility class.
-                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
-                            "when deserializing from a dictionary (version=%zu). "
-                            "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
-                }
-            }

             for (const auto& output : root->RawOutputs())
             {
                 const auto& it = uidToInputMap.find(output.Uid());
@@ -276,63 +280,122 @@ namespace CNTK
             }
         }

+        // Starting with serialization version 3, the state is preserved inside the attribute dictionaries of the
+        // corresponding primitive functions. Earlier versions have a dedicated key-value pair in the composite function dict.
+        if (version < 3)
+            RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);
+
         return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
     }

+    void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
+    {
+        Dictionary stateDictionary;
+        if (dict.Contains(stateKey))
+            stateDictionary = dict[stateKey].Value<Dictionary>();
+
+        for (auto& function : functions)
+        {
+            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            if (stateDictionary.Contains(primitiveFunction->Uid()))
+            {
+                auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
+                // Add key-value pairs expected by the SetState method to the state dictionary.
+                state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
+                state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
+                primitiveFunction->SetState(state);
+            }
+            else
+            {
+                if (Internal::GetComputationNetworkTraceLevel() > 0)
+                {
+                    // TODO: all logging functionality should be refactored to live in a logging utility class.
+                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
+                            "when deserializing from a dictionary (version=%zu). "
+                            "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
+                }
+
+                // Create state from scratch, so that function attributes contain all the required key-value pairs.
+                Dictionary state;
+                state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed();
+                state[PrimitiveFunction::AttributeNameRngOffset] = 0;
+                primitiveFunction->SetState(state);
+            }
+        }
+    }
+
     void CompositeFunction::CopyState(const CompositeFunction& source)
     {
-        // Create a map with all non-pure (stateful) functions in the function graph.
-        auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
-            std::map<std::wstring, FunctionPtr> functionMap;
-            for (auto funcPtr : allPrimitiveFunctions)
-            {
-                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
-                if (primitiveFunction->IsStateful())
-                {
-                    functionMap[primitiveFunction->Uid()] = funcPtr;
-                }
-            }
-            return functionMap;
-        };
-
-        std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
-        std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);
-
-        assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
-        if (statefulFunctionsFrom.size() == 0)
-        {
-            return;
-        }
-
-        // Copy state captured in the attributes dictionaries.
-        for (const auto& kv : statefulFunctionsFrom)
-        {
-            statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
-        }
-
-        UpdateInternalNetworkState();
+        // Collect a vector of stateful function uids using a pre-order traversal of the function graphs.
+        auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
+            vector<wstring> uids;
+            PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
+                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
+                if (primitiveFunction->IsStateful())
+                {
+                    uids.push_back(funcPtr->Uid());
+                }
+            }, true);
+
+            return uids;
+        };
+
+        auto theirUIDs = collectStatefulFunctionUIDs(source);
+        auto ourUIDs = collectStatefulFunctionUIDs(*this);
+
+        if (theirUIDs.size() != ourUIDs.size())
+            CNTK::LogicError("Cannot copy internal state, the source and the destination contain a different number of stateful functions.");
+
+        auto state = source.GetInternalState();
+
+        if (theirUIDs == ourUIDs)
+        {
+            // UIDs are identical, no need to remap.
+            SetInternalState(state);
+            return;
+        }
+
+        // Build a map from source function UIDs to the destination ('this') function UIDs.
+        map<wstring, wstring> uidMap;
+        for (auto i = 0; i < theirUIDs.size(); i++)
+            uidMap[theirUIDs[i]] = ourUIDs[i];
+
+        Dictionary remappedState;
+        for (auto& kv : state)
+            remappedState[uidMap[kv.first]] = kv.second;
+
+        SetInternalState(remappedState);
     }

-    void CompositeFunction::UpdateInternalNetworkState()
+    void CompositeFunction::SetInternalState(const Dictionary& state)
     {
-        if (!m_computationNetwork)
-        {
+        if (state.Size() == 0)
             return;
-        }

         for (const auto& function : m_allPrimitiveFunctions)
         {
-            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
-            if (primitiveFunction->IsStateful())
-            {
-                for (const auto& output : function->RawOutputs())
-                {
-                    auto node = m_variableToNodeMap.at(output);
-                    auto attributes = function->Attributes();
-                    auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
-                    auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
-                    node->As<RngUser>()->SetRngState(seed, offset);
-                }
-            }
+            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
+            if (!primitiveFunction->IsStateful())
+                continue;
+
+            auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
+            primitiveFunction->SetState(functionState);
+
+            if (!m_computationNetwork)
+                continue;
+
+            auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
+            auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
+
+            // Copy the state directly into the network.
+            for (const auto& output : function->RawOutputs())
+            {
+                auto node = m_variableToNodeMap.at(output);
+                node->As<RngUser>()->SetRngState(seed, offset);
+            }
         }
     }
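For illustration only (not part of the diff): when the two graphs are equivalent but were built with different UIDs, CopyState relies on the pre-order traversal visiting stateful functions in the same order on both sides, so a positional map suffices. A toy sketch of that remapping, with hypothetical UIDs:

    #include <map>
    #include <string>
    #include <vector>

    int main()
    {
        // Stateful-function UIDs collected in identical traversal order.
        std::vector<std::wstring> theirUIDs = { L"Dropout7", L"RandomSample9" };
        std::vector<std::wstring> ourUIDs   = { L"Dropout3", L"RandomSample5" };

        std::map<std::wstring, std::wstring> uidMap;
        for (size_t i = 0; i < theirUIDs.size(); i++)
            uidMap[theirUIDs[i]] = ourUIDs[i]; // positional correspondence

        // State keyed by source UIDs would now be re-keyed through uidMap
        // before being applied to the destination graph.
        return 0;
    }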
@@ -895,16 +958,9 @@ namespace CNTK
                 if (computationNodePtr->Is<RngUser>())
                 {
-                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
-                    {
-                        auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
-                        uint64_t offset = 0;
-                        if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
-                        {
-                            offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
-                        }
-                        computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
-                    }
+                    auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
+                    auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
+                    computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
                 }
             }
             else

View file

@@ -110,7 +110,7 @@ namespace CNTK
         Dictionary SerializeBlockComposite() const;

         virtual Dictionary Serialize() const override;

         virtual size_t CurrentVersion() const override { return s_serializationVersion; }

         static FunctionPtr DeserializeBlockComposite(const Dictionary& dict,
@@ -238,12 +238,26 @@ namespace CNTK
             return inputs;
         }

-        // If the network is already created, copy internal state over from the functions in the graph into the underlying network.
-        void UpdateInternalNetworkState();
+        // Copy the internal state from the network into the function graph.
+        void UpdateInternalState() const;
+
+        // Generate a dictionary representing the internal (local) state of the function graph.
+        Dictionary GetInternalState() const;
+
+        // Update the internal state using the provided dictionary.
+        // If the network is already created, directly update its state. Otherwise, copy the state from the
+        // dictionary into the function graph.
+        void SetInternalState(const Dictionary& state);

-        // Copy state info from source function graph into' this' function graph.
+        // Copy state info from the source function graph into 'this' function graph.
         // Both graphs must be equivalent.
         void CopyState(const CompositeFunction& source);

+        // This function is only needed for backwards compatibility, to support deserializing composite functions that
+        // stored the internal state inside a dedicated value in the dictionary.
+        static void RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> PrimitiveFunctions);
+
         static Variable GetMappingForNoOpOutput(const Variable& variable, bool recursive = false);
         static Variable GetMappingVariable(const Variable& variable, bool recursive = false);
@@ -328,6 +342,7 @@ namespace CNTK
         // Version history:
         // 1 -- initial version.
        // 2 -- add support for stateful functions (with corresponding nodes inheriting from RngUser).
-        static const size_t s_serializationVersion = 2;
+        // 3 -- store internal function state directly in the attributes dictionary.
+        static const size_t s_serializationVersion = 3;
     };
}

View file

@@ -8,6 +8,7 @@
 #include "PrimitiveFunction.h"
 #include "CompositeFunction.h"
 #include "BlockFunction.h"
+#include "Utils.h"

 using namespace Microsoft::MSR::CNTK;
@@ -1037,29 +1038,47 @@ namespace CNTK
             LogicError("Slice: Invalid axis argument provided. Slice along the dynamic batch axis is currently unsupported. To slice a sequence along its ordered dynamic axis use Sequence::Slice.");
        }

-    FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
+    FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
         additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
+
+        if (seed == SentinelValueForAutoSelectRandomSeed)
+            seed = Internal::GenerateRandomSeed();
+
+        additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
+        additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
+
         return UnaryOp(PrimitiveOpType::RandomSample, operand, std::move(additionalProperties), name);
     }

-    FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
+    FunctionPtr RandomSampleInclusionFrequency(const Variable& operand, size_t numSamples, bool allowDuplicates, unsigned long seed, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
         additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
+
+        if (seed == SentinelValueForAutoSelectRandomSeed)
+            seed = Internal::GenerateRandomSeed();
+
+        additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
+        additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
+
         return UnaryOp(PrimitiveOpType::RandomSampleInclusionFrequency, operand, std::move(additionalProperties), name);
     }

-    FunctionPtr Dropout(const Variable& operand, double dropoutRate, const std::wstring& name)
+    FunctionPtr Dropout(const Variable& operand, double dropoutRate, unsigned long seed, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate;
+
+        if (seed == SentinelValueForAutoSelectRandomSeed)
+            seed = Internal::GenerateRandomSeed();
+
+        additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
+        additionalProperties[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
+
         return UnaryOp(PrimitiveOpType::Dropout, operand, std::move(additionalProperties), name);
     }

View file

@@ -18,6 +18,9 @@
 #include "BlockFunction.h"
 #include "CompositeFunction.h"
 #include "SpecialPurposeNodes.h"
+#include "ConvolveGeometry.h"
+#include "ConvolutionalNodes.h"
+#include "Variable.h"

 using namespace Microsoft::MSR::CNTK;
@@ -37,7 +40,7 @@ namespace CNTK
     // Names of the various attributes of CNTK primitive Functions
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis = L"axis";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxisVec = L"axisVec";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis1 = L"axis1";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis2 = L"axis2";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameAllowDuplicates = L"allowDuplicates";
@@ -976,4 +979,118 @@ namespace CNTK
         return std::shared_ptr<PrimitiveFunction>(new PrimitiveFunction(op, inputs, std::move(attributes), name, uid),
                                                   [](PrimitiveFunction* ptr) { delete ptr; });
     }
+
+    static const vector<wstring> s_stateAttributes = { PrimitiveFunction::AttributeNameRngSeed, PrimitiveFunction::AttributeNameRngOffset };
+
+    Dictionary PrimitiveFunction::GetState() const
+    {
+        if (!IsStateful())
+            LogicError("Function '%S' is not stateful.", AsString().c_str());
+
+        Dictionary state;
+        for (auto& key : s_stateAttributes)
+        {
+            state[key] = m_attributes[key];
+        }
+
+        return state;
+    }
+
+    void PrimitiveFunction::SetState(const Dictionary& state)
+    {
+        if (!IsStateful())
+            LogicError("Function '%S' is not stateful.", AsString().c_str());
+
+        for (auto& key : s_stateAttributes)
+        {
+            m_attributes[key] = state[key];
+        }
+    }
+
+    /*static*/ void PrimitiveFunction::FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from/* = NDShape()*/)
+    {
+        auto dims = shape.Dimensions();
+        Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(filterRank, inputRank, dims, deflt, from.Dimensions());
+        shape = NDShape(dims);
+    }
+
+    NDShape PrimitiveFunction::ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
+                                                        std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
+                                                        bool transpose, bool inferDimensions, bool ceilOutputDim/* = false*/) const
+    {
+        if (inferDimensions)
+        {
+            size_t inputRank = operandShape.Rank();
+
+            // An Unknown kernel shape is valid only for pooling; however, the shape should have been expanded
+            // before this call.
+            if (kernelShape == NDShape::Unknown)
+                RuntimeError("Convolution: Kernel shape can't be Unknown.");
+
+            // Infer reduction dimensions if not given.
+            // If the kernel has a lower rank than the input, then the remaining dimensions are to be reduced over.
+            size_t filterRank = kernelShape.Rank();
+
+            // If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
+            // picking the corresponding operand shape dimensionality.
+            // This is done by shrinking the filter rank and letting the dimensions be inferred from the operand's shape.
+            // TODO: Should we do this for all of the axes in kernelShape that have a dimensionality of NDShape::InferredDimension?
+            if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
+            {
+                filterRank--;
+                kernelShape = kernelShape.SubShape(0, filterRank);
+            }
+
+            NDShape fromShape;
+            if (op == PrimitiveOpType::Convolution)
+                fromShape = operandShape;
+
+            size_t fillRank = (!transpose) ? filterRank : filterRank - 1;
+            FixNDShape(fillRank, inputRank, kernelShape, 1, fromShape); // convolve over red dim; pool over 1
+            FixNDShape(fillRank, inputRank, strides, 1, fromShape);     // stride for reduction dims is red dim or 1
+            FixNDShape(fillRank, inputRank, lowerPad, 0);
+            FixNDShape(fillRank, inputRank, upperPad, 0);
+            Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, sharing, true);
+            Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, autoPad, false); // no padding for reduction dims
+        }
+
+        decltype(&Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape) computeOutputShapeFunc;
+        if (!transpose)
+            computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape;
+        else
+            computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
+
+        return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutputDim));
+    }
+
+    /*static*/ bool PrimitiveFunction::UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes)
+    {
+        bool anyParameterOperandDimsInferred = false;
+        auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
+            if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
+            {
+                operand.m_dataFields->m_shape = newOperandShape;
+                return true;
+            }
+
+            return false;
+        };
+
+        for (auto& newOperandShapePair : newOperandShapes)
+            anyParameterOperandDimsInferred = updateOperandShapeFunc(newOperandShapePair.first, newOperandShapePair.second) || anyParameterOperandDimsInferred;
+
+        return anyParameterOperandDimsInferred;
+    }
+
+    NDShape PrimitiveFunction::NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const
+    {
+        assert(operands.size() > 1);
+
+        // TODO: Is this logic of transitively constructing the output shape from the operands correct?
+        Variable dummyOutputVariable = PlaceholderVariable(NDShape());
+        for (auto& operand : operands)
+            dummyOutputVariable.m_dataFields->m_shape = BinaryElementwiseOpOutputShape(op, dummyOutputVariable, operand, broadcastAllowed, inferInputDimensions);
+
+        return dummyOutputVariable.Shape();
+    }
 }

View file

@@ -8,10 +8,6 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "PrimitiveOpType.h"
-#include "Utils.h"
-#include "ConvolveGeometry.h"
-#include "ConvolutionalNodes.h"
-#include "Variable.h"

 namespace std
 {
@@ -301,6 +297,10 @@ namespace CNTK
                    (OpType() == PrimitiveOpType::RandomSampleInclusionFrequency);
         }

+        Dictionary GetState() const;
+
+        void SetState(const Dictionary& state);
+
     private:

         // The following helper functions are used to determine the output shape for different
@@ -425,24 +425,7 @@ namespace CNTK
         }

         // Returns a boolean indicating if any operand shape was updated
-        static bool UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes)
-        {
-            bool anyParameterOperandDimsInferred = false;
-            auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
-                if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
-                {
-                    operand.m_dataFields->m_shape = newOperandShape;
-                    return true;
-                }
-
-                return false;
-            };
-
-            for (auto& newOperandShapePair : newOperandShapes)
-                anyParameterOperandDimsInferred = updateOperandShapeFunc(newOperandShapePair.first, newOperandShapePair.second) || anyParameterOperandDimsInferred;
-
-            return anyParameterOperandDimsInferred;
-        }
+        static bool UpdateOperandShapes(std::vector<std::pair<Variable, NDShape>>& newOperandShapes);

         // Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
         /*static*/ NDShape BinaryElementwiseOpOutputShape(PrimitiveOpType op, Variable& leftOperand, Variable& rightOperand, bool broadcastAllowed, bool inferInputDimensions) const
@@ -493,6 +476,10 @@ namespace CNTK
             }
         }

+        UNUSED(broadcastAllowed);
+        // BUGBUG: if (broadcastAllowed) is missing here?
+
         // Broadcast in remaining axes
         for (size_t i = shapeWithSmallerNumAxes.Rank(); i < numOutputAxes; ++i)
             outputDims[i] = shapeWithLargerNumAxes[i];
@@ -507,17 +494,7 @@ namespace CNTK
            return NDShape(std::move(outputDims));
        }

-        /*static*/ NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const
-        {
-            assert(operands.size() > 1);
-
-            // TODO: Is this logic of transitively constructing the output shape from the operands correct?
-            Variable dummyOutputVariable = PlaceholderVariable(NDShape());
-            for (auto& operand : operands)
-                dummyOutputVariable.m_dataFields->m_shape = BinaryElementwiseOpOutputShape(op, dummyOutputVariable, operand, broadcastAllowed, inferInputDimensions);
-
-            return dummyOutputVariable.Shape();
-        }
+        /*static*/ NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, std::vector<Variable>& operands, bool broadcastAllowed, bool inferInputDimensions) const;

         // Returns a pair comprising of the output shape and boolean indicating if any input operand shape was modified
         /*static*/ NDShape TimesOpOutputShape(Variable& leftOperand, Variable& rightOperand, size_t outputRank, int inferInputRankToMap, bool inferInputDimensions) const
@@ -643,61 +620,11 @@ namespace CNTK
            return NDShape(std::move(outputDims));
        }

-        static void FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from = NDShape())
-        {
-            auto dims = shape.Dimensions();
-            Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(filterRank, inputRank, dims, deflt, from.Dimensions());
-            shape = NDShape(dims);
-        }
+        static void FixNDShape(size_t filterRank, size_t inputRank, NDShape& shape, size_t deflt, const NDShape& from = NDShape());

-        /*static*/ NDShape ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
-                                                    std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
-                                                    bool transpose, bool inferDimensions, bool ceilOutputDim = false) const
-        {
-            if (inferDimensions)
-            {
-                size_t inputRank = operandShape.Rank();
-
-                // Unknown kernel shape valid only for pooling, however, the shape should have expanded before
-                // this call.
-                if (kernelShape == NDShape::Unknown)
-                    RuntimeError("Convolution: Kernel shape can't be Unknown.");
-
-                // infer reduction dimensions if not given
-                // If kernel has a lower rank than the input then the remaining dimensions are to be reduced over.
-                size_t filterRank = kernelShape.Rank();
-
-                // If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
-                // picking the corresponding operand shape dimensionality
-                // This is done by shrinking the filter rank and let the dimensions be inferred from the operand's shape
-                // TODO: Should we do this for all of the axes in kernelShape that have a dimensionality of NDShape::InferredDimension?
-                if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
-                {
-                    filterRank--;
-                    kernelShape = kernelShape.SubShape(0, filterRank);
-                }
-
-                NDShape fromShape;
-                if (op == PrimitiveOpType::Convolution)
-                    fromShape = operandShape;
-
-                size_t fillRank = (!transpose) ? filterRank : filterRank - 1;
-                FixNDShape(fillRank, inputRank, kernelShape, 1, fromShape); // convolve over red dim; pool over 1
-                FixNDShape(fillRank, inputRank, strides, 1, fromShape);     // stride for reduction dims is red dim or 1
-                FixNDShape(fillRank, inputRank, lowerPad, 0);
-                FixNDShape(fillRank, inputRank, upperPad, 0);
-                Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, sharing, true);
-                Microsoft::MSR::CNTK::ConvolutionNodeBase<float>::FixVectorShape(fillRank, inputRank, autoPad, false); // no padding for reduction dims
-            }
-
-            decltype(&Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape) computeOutputShapeFunc;
-            if (!transpose)
-                computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeOutputShape;
-            else
-                computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
-
-            return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutputDim));
-        }
+        /*static*/ NDShape ConvolutionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, NDShape& kernelShape, NDShape& outputMapCount, NDShape& strides,
+                                                    std::vector<bool>& sharing, std::vector<bool>& autoPad, NDShape& lowerPad, NDShape& upperPad,
+                                                    bool transpose, bool inferDimensions, bool ceilOutputDim = false) const;

         /*static*/ NDShape BatchNormalizationOutputShape(std::vector<Variable>& operands, bool spatial, bool inferDimensions) const
         {

View file

@@ -41,6 +41,8 @@ namespace CNTK
     const std::wstring blockFunctionOpNameKey = L"block_function_op_name";
     const std::wstring blockFunctionCompositeArgumentsMapKeysKey = L"block_function_composite_arguments_map_keys";
     const std::wstring blockFunctionCompositeArgumentsMapValuesKey = L"block_function_composite_arguments_map_values";
+    const std::wstring internalWorkerStateKey = L"internal_worker_state";
+    const std::wstring externalWorkerStateKey = L"external_worker_state";

     template <typename T>
     inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion)

View file

@@ -8,11 +8,21 @@
 #include "Utils.h"
 #include "Learner.h"
 #include "PerformanceProfiler.h"
+#include "CompositeFunction.h"
+#include "Serialization.h"

 namespace
 {
+    const std::wstring versionPropertyName = L"Version";
     const std::wstring learnersPropertyName = L"Learners";
     const std::wstring externalStatePropertyName = L"ExternalState";
+    const std::wstring distributedStatePropertyName = L"DistributedState";
+
+    // Version history:
+    // 0 -- a version number before versioning was introduced for the trainer's checkpoints.
+    // 1 -- initial version: added a key-value pair for the checkpoint version info, added a
+    //      distributed state key to save all local state collected from distributed workers.
+    static const size_t trainerCheckpointVersion = 1;
}
namespace CNTK namespace CNTK
@@ -307,15 +317,22 @@ namespace CNTK
     void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState)
     {
         auto learnersState = m_parameterLearners->CreateCheckpoint();

         if (!m_distributed)
             return Save(modelFilePath, learnersState, externalState);

+        auto compositeFunction = dynamic_cast<CompositeFunction*>(m_combinedTrainingFunction.get());
+
+        Dictionary state;
+        state[internalWorkerStateKey] = compositeFunction->GetInternalState(); // this is the local worker's state.
+        state[externalWorkerStateKey] = externalState;
+
         // Collect distributed external state.
         DistributedCommunicatorPtr communicator = MPICommunicator();
         communicator->Barrier();

         std::vector<DictionaryPtr> remoteState;
-        communicator->Gather(externalState, remoteState, communicator->Workers());
+        communicator->Gather(state, remoteState, communicator->Workers());

         Dictionary aggregatedState;
         for (const auto& w : communicator->Workers())
@@ -324,19 +341,21 @@ namespace CNTK
         }

         if (communicator->CurrentWorker().IsMain())
-            Save(modelFilePath, learnersState, aggregatedState);
+            Save(modelFilePath, learnersState, externalState, aggregatedState);

         // all workers need to sync up after saving model to avoid read-after-write hazard
         // i.e. one worker is in the middle of write while another tries to read
         communicator->Barrier();
     }

-    void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState)
+    void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState, const Dictionary& distributedState)
     {
         std::wstring tempModelFile = modelFilePath + L".tmp";
         Dictionary state;
+        state[versionPropertyName] = trainerCheckpointVersion;
         state[learnersPropertyName] = learnerState;
         state[externalStatePropertyName] = externalState;
+        state[distributedStatePropertyName] = distributedState;

         m_combinedTrainingFunction->SaveModel(tempModelFile);

         std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
@@ -359,25 +378,57 @@ namespace CNTK
         Dictionary checkpoint = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));

+        size_t version = 0;
+        if (checkpoint.Contains(versionPropertyName))
+            version = checkpoint[versionPropertyName].Value<size_t>();
+
         auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
         auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();

+        m_parameterLearners->RestoreFromCheckpoint(learnerState);
+
         if (!m_distributed)
         {
-            m_parameterLearners->RestoreFromCheckpoint(learnerState);
             return externalState;
         }

-        m_parameterLearners->RestoreFromCheckpoint(learnerState);
+        // This ensures that nobody will start writing to the model/checkpoint files until
+        // everybody is done reading them.
         DistributedCommunicatorPtr communicator = MPICommunicator();
         communicator->Barrier();

-        auto key = std::to_wstring(communicator->CurrentWorker().m_globalRank);
+        auto mainWorkerId = std::to_wstring(0);
+        auto localWorkerId = std::to_wstring(communicator->CurrentWorker().m_globalRank);

-        if (externalState.Contains(key))
-            return externalState[key].Value<Dictionary>();
-        else
-            return externalState[std::to_wstring(0)].Value<Dictionary>();
+        // Before version 1, there was no distributed state per se. Instead, the external state
+        // contained a dictionary of worker-specific external states.
+        if (version == 0)
+        {
+            auto key = externalState.Contains(localWorkerId) ? localWorkerId : mainWorkerId;
+            return externalState[key].Value<Dictionary>();
+        }
+
+        Dictionary distributedState = checkpoint[distributedStatePropertyName].Value<Dictionary>();
+
+        if (communicator->CurrentWorker().IsMain() || !distributedState.Contains(localWorkerId))
+        {
+            return externalState;
+        }
+
+        // The checkpoint contains internal state for this worker.
+        Dictionary localState = distributedState[localWorkerId].Value<Dictionary>();
+
+        auto internalState = localState[internalWorkerStateKey].Value<Dictionary>();
+        auto compositeFunction = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
+        if (compositeFunction == nullptr)
+            RuntimeError("Combined training function is not a CompositeFunction.");
+
+        // This assumes the compositeFunction (restored from a checkpoint made by the main node) and
+        // the internal worker state both have identical UIDs.
+        compositeFunction->SetInternalState(internalState);
+
+        return localState[externalWorkerStateKey].Value<Dictionary>();
     }

     double Trainer::PreviousMinibatchLossAverage() const

View file

@@ -8,7 +8,6 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include <fstream>
-#include "Utils.h"

 namespace CNTK
 {

View file

@@ -11,10 +11,10 @@
 namespace Microsoft { namespace MSR { namespace CNTK {

 CPURNGHandle::CPURNGHandle(int deviceId, uint64_t seed, uint64_t offset)
-    : RNGHandle(deviceId)
+    : RNGHandle(deviceId),
+      m_generator(seed)
 {
-    m_generator.reset(new std::mt19937_64(seed));
-    m_generator->discard(offset);
+    m_generator.discard(offset);
 }

 }}}
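For illustration only (not part of the diff): a (seed, offset) pair fully determines the generator position, which is why checkpoints only need to store those two numbers and why SetRngState can resume a stream exactly. A self-contained sketch:

    #include <cassert>
    #include <random>

    int main()
    {
        std::mt19937_64 original(42);
        for (int i = 0; i < 1000; ++i)
            original(); // consume 1000 values during "training"

        std::mt19937_64 restored(42); // same seed...
        restored.discard(1000);       // ...fast-forwarded by the saved offset

        assert(original() == restored()); // streams are back in lockstep
        return 0;
    }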

View file

@@ -20,12 +20,11 @@ public:
     std::mt19937_64& Generator()
     {
-        return *m_generator;
+        return m_generator;
     }

 private:
-    std::unique_ptr<std::mt19937_64> m_generator;
-    // TODO: why is this a ptr?
+    std::mt19937_64 m_generator;
 };

 }}}

View file

@@ -9,6 +9,7 @@
 #include "Common.h"

 using namespace CNTK;
+using namespace std;
 using namespace std::placeholders;

 extern bool Is1bitSGDAvailable();

@@ -35,14 +36,23 @@ namespace
     const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
     const size_t totalNumberOfSamples = numSamplesPerSweep * numSweepsToTrainWith;

+    const std::wstring g_attributeNameRngSeed = L"rngSeed";
+    const std::wstring g_attributeNameRngOffset = L"rngOffset";
+
+    inline MinibatchSourcePtr GetMinibatchSource(const FeedForwardClassifier& classifier)
+    {
+        return TextFormatMinibatchSource(g_inputFile,
+                                         { { g_featureStreamName, classifier.inputDim },
+                                           { g_labelsStreamName, classifier.ouputDim } },
+                                         totalNumberOfSamples, true);
+    }

 void LoopBasedOnSamples(const std::wstring& name, const DeviceDescriptor& device, std::function<DistributedLearnerPtr(LearnerPtr)> factory, const FeedForwardClassifier& classifier)
 {
     printf("Training loop thru samples with %ls.\n", name.c_str());
-    auto minibatchSource = TextFormatMinibatchSource(g_inputFile,
-        { { g_featureStreamName, classifier.inputDim }, { g_labelsStreamName, classifier.ouputDim } },
-        totalNumberOfSamples,
-        true);
+    auto minibatchSource = GetMinibatchSource(classifier);

     auto featureStreamInfo = minibatchSource->StreamInfo(g_featureStreamName);
     auto labelStreamInfo = minibatchSource->StreamInfo(g_labelsStreamName);
@@ -135,3 +145,92 @@ void TestFrameMode()
         }
         sync->Barrier();
     }
}
+
+void TestDistributedCheckpointing()
+{
+    std::vector<DeviceDescriptor> devices;
+    if (ShouldRunOnCpu())
+        devices.push_back(DeviceDescriptor::CPUDevice());
+    if (ShouldRunOnGpu())
+        devices.push_back(DeviceDescriptor::GPUDevice(0));
+
+    auto sync = MPICommunicator();
+    auto numWorkers = sync->Workers().size();
+    auto workerRank = sync->CurrentWorker().m_globalRank;
+
+    for (auto device : devices)
+    {
+        auto ff = BuildFeedForwardClassifier(device);
+        ff.output = Dropout(ff.output, 0.5);
+        ff.trainingLoss = CNTK::CrossEntropyWithSoftmax(ff.output, ff.labels, L"lossFunction");
+        ff.prediction = CNTK::ClassificationError(ff.output, ff.labels, L"classificationError");
+
+        {
+            auto& attributes = ff.output->RootFunction()->Attributes();
+            size_t seed = attributes[g_attributeNameRngSeed].Value<size_t>();
+            // Check that (1) the seed is in the attributes dictionary and
+            // (2) the auto-generated seed value reflects the workerRank.
+            if (numWorkers > 1 && seed % numWorkers != workerRank)
+                ReportFailure("Unexpected seed value");
+        }
+
+        auto learner = SGDLearner(ff.output->Parameters(), LearningRatePerSampleSchedule(0.02));
+        auto distributedLearner = CreateDataParallelDistributedLearner(MPICommunicator(), learner, 0);
+        auto trainer = CreateTrainer(ff.output, ff.trainingLoss, ff.prediction, { distributedLearner });
+
+        auto minibatchSource = GetMinibatchSource(ff);
+        auto featureStreamInfo = minibatchSource->StreamInfo(g_featureStreamName);
+        auto labelStreamInfo = minibatchSource->StreamInfo(g_labelsStreamName);
+
+        vector<double> expectedLoss(100);
+        for (int i = 0; i < 100; i++)
+        {
+            if (i % 10 == 0)
+            {
+                auto checkpoint = minibatchSource->GetCheckpointState();
+                trainer->SaveCheckpoint(L"distributed_checkpoint_test." + to_wstring(i), checkpoint);
+            }
+
+            auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+            unordered_map<Variable, MinibatchData> minibatch = { { ff.features, minibatchData[featureStreamInfo] }, { ff.labels, minibatchData[labelStreamInfo] } };
+            trainer->TrainMinibatch(minibatch, device);
+            expectedLoss[i] = trainer->PreviousMinibatchLossAverage();
+        }
+
+        for (int i = 0; i < 100; i++)
+        {
+            if (i % 10 == 0)
+            {
+                auto checkpoint = trainer->RestoreFromCheckpoint(L"distributed_checkpoint_test." + to_wstring(i));
+                minibatchSource->RestoreFromCheckpoint(checkpoint);
+
+                auto& attributes = ff.output->RootFunction()->Attributes();
+                size_t seed = attributes[g_attributeNameRngSeed].Value<size_t>();
+                size_t offset = attributes[g_attributeNameRngOffset].Value<size_t>();
+
+                // Check that the worker-specific seed value was properly restored from the checkpoint.
+                if (numWorkers > 1 && seed % numWorkers != workerRank)
+                    ReportFailure("Unexpected seed value");
+
+                // Check the offset and verify that it changes depending on the number of processed minibatches.
+                if (offset != i * minibatchSize * ff.inputDim)
+                    ReportFailure("Unexpected offset value");
+            }
+
+            auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+            unordered_map<Variable, MinibatchData> minibatch = { { ff.features, minibatchData[featureStreamInfo] }, { ff.labels, minibatchData[labelStreamInfo] } };
+            trainer->TrainMinibatch(minibatch, device);
+
+            auto loss = trainer->PreviousMinibatchLossAverage();
+            FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
+        }
+    }
+
+    sync->Barrier();
+}
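For illustration only (not part of the diff): the offset assertion in the test above works because dropout draws one random value per input element, so after i minibatches the generator has advanced by i * minibatchSize * inputDim values. A quick sanity check of that arithmetic, with hypothetical values (the test's actual constants live in this file):

    #include <cassert>

    int main()
    {
        const size_t minibatchSize = 25, inputDim = 784; // assumed values for illustration
        const size_t i = 10;                             // minibatches processed so far
        assert(i * minibatchSize * inputDim == 196000);  // expected RNG offset
        return 0;
    }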

View file

@@ -21,6 +21,7 @@ void MNISTClassifierTests();
 void TrainSequenceToSequenceTranslator();
 void TrainTruncatedLSTMAcousticModelClassifier();
 void TestFrameMode();
+void TestDistributedCheckpointing();

 int main(int argc, char *argv[])
 {

@@ -58,6 +59,8 @@ int main(int argc, char *argv[])
         TestFrameMode();

+        TestDistributedCheckpointing();
+
         std::string testsPassedMsg = "\nCNTKv2Library-Distribution tests: Passed\n";
         printf("%s", testsPassedMsg.c_str());

View file

@@ -438,6 +438,10 @@ Test module "V2LibraryTests" has passed with:
     Test case "SerializationSuite/CheckpointingWithStatefulNodesInGPU" has passed

+    Test case "SerializationSuite/CheckpointingWithStatefulNodesAndExplicitSeedsOnCPU" has passed
+
+    Test case "SerializationSuite/CheckpointingWithStatefulNodesAndExplicitSeedsOnGPU" has passed
+
     Test suite "FeedForwardSuite" has passed with:
         6 test cases out of 6 passed
         8 assertions out of 8 passed

View file

@@ -815,6 +815,74 @@ void TestCheckpointingWithStatefulNodes(const DeviceDescriptor& device)
     }
 }

+void TestCheckpointingWithStatefulNodesAndExplicitSeeds(const DeviceDescriptor& device)
+{
+    auto featureStreamName = L"features";
+    auto labelsStreamName = L"labels";
+    size_t inputDim = 784;
+    size_t numOutputClasses = 10;
+    auto features = InputVariable({ inputDim }, false /*isSparse*/, DataType::Float, featureStreamName);
+    auto labels = InputVariable({ numOutputClasses }, DataType::Float, labelsStreamName);
+
+    auto net1 = BuildFFClassifierNet(features, numOutputClasses, device, 1);
+    auto net2 = net1->Clone(ParameterCloningMethod::Clone, { { features, features } });
+    auto net3 = net1->Clone(ParameterCloningMethod::Clone, { { features, features } });
+
+    auto trainer1 = BuildTrainer(Dropout(net1, 0.5, 123), labels);
+    auto trainer2 = BuildTrainer(Dropout(net2, 0.5, 123), labels);
+    auto trainer3 = BuildTrainer(Dropout(net3, 0.5, 321), labels);
+
+    const size_t minibatchSize = 50;
+    const size_t maxSamples = 150;
+    auto minibatchSource = TextFormatMinibatchSource(L"Train-28x28_cntk_text.txt", { { featureStreamName, inputDim }, { labelsStreamName, numOutputClasses } }, 2 * maxSamples, false);
+
+    auto featureStreamInfo = minibatchSource->StreamInfo(features);
+    auto labelStreamInfo = minibatchSource->StreamInfo(labels);
+
+    for (int i = 0; i < maxSamples; i += minibatchSize)
+    {
+        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+        unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } };
+
+        trainer1->TrainMinibatch(minibatch, device);
+        trainer2->TrainMinibatch(minibatch, device);
+        trainer3->TrainMinibatch(minibatch, device);
+
+        auto loss1 = trainer1->PreviousMinibatchLossAverage();
+        auto loss2 = trainer2->PreviousMinibatchLossAverage();
+        auto loss3 = trainer3->PreviousMinibatchLossAverage();
+
+        FloatingPointCompare(loss1, loss2, "Training loss does not match expectation");
+        BOOST_TEST((abs(loss1 - loss2) <= abs(loss2 - loss3)));
+    }
+
+    trainer1->SaveCheckpoint(L"seeded_stateful_nodes.model");
+    auto state = minibatchSource->GetCheckpointState();
+
+    vector<double> expectedLoss;
+    for (int i = 0; i < maxSamples; i += minibatchSize)
+    {
+        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+        unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } };
+        trainer1->TrainMinibatch(minibatch, device);
+        expectedLoss.push_back(trainer1->PreviousMinibatchLossAverage());
+    }
+
+    trainer1->RestoreFromCheckpoint(L"seeded_stateful_nodes.model");
+    minibatchSource->RestoreFromCheckpoint(state);
+
+    for (int i = 0; i * minibatchSize < maxSamples; i++)
+    {
+        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
+        unordered_map<Variable, MinibatchData> minibatch = { { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } };
+        trainer1->TrainMinibatch(minibatch, device);
+        double loss = trainer1->PreviousMinibatchLossAverage();
+        FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
+    }
+}
 void TestLoadingModelFromMemoryBuffer()
 {
     ifstream modelFileStream("batch.norm.no.sample.count.v2.bin", ifstream::binary);
@@ -992,6 +1060,18 @@ BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesInGPU)
     TestCheckpointingWithStatefulNodes(DeviceDescriptor::GPUDevice(0));
 }

+BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesAndExplicitSeedsOnCPU)
+{
+    TestCheckpointingWithStatefulNodesAndExplicitSeeds(DeviceDescriptor::CPUDevice());
+}
+
+BOOST_AUTO_TEST_CASE(CheckpointingWithStatefulNodesAndExplicitSeedsOnGPU)
+{
+    if (ShouldRunOnGpu())
+        TestCheckpointingWithStatefulNodesAndExplicitSeeds(DeviceDescriptor::GPUDevice(0));
+}
+
 BOOST_AUTO_TEST_SUITE_END()

 }}
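For reference, a rough Python analogue of the guarantee the C++ test above exercises: with an explicit dropout seed, saving and restoring a trainer checkpoint replays the same dropout masks and hence the same losses. This is a sketch, not part of the commit; the model structure and file name are made up, and it assumes the CNTK 2.x Python Trainer API (save_checkpoint / restore_from_checkpoint).

import numpy as np
import cntk as C

x = C.input((10,))
y = C.input((2,))
z = C.layers.Dense(2)(C.dropout(x, 0.5, seed=123))
loss = C.cross_entropy_with_softmax(z, y)
metric = C.classification_error(z, y)
learner = C.sgd(z.parameters, lr=C.learning_rate_schedule(0.1, C.UnitType.minibatch))
trainer = C.Trainer(z, (loss, metric), [learner])

rng = np.random.RandomState(0)
data = rng.rand(50, 10).astype(np.float32)
labels = np.eye(2, dtype=np.float32)[rng.randint(0, 2, 50)]

def losses(n=3):
    # Train a few minibatches on fixed data and record the loss sequence.
    out = []
    for _ in range(n):
        trainer.train_minibatch({x: data, y: labels})
        out.append(trainer.previous_minibatch_loss_average)
    return out

trainer.save_checkpoint('seeded_sketch.model')
before = losses()
trainer.restore_from_checkpoint('seeded_sketch.model')
after = losses()
assert np.allclose(before, after)  # seeded dropout replays identically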

View file

@@ -283,7 +283,7 @@
 " with C.layers.default_options(initial_state = 0.1):\n",
 " m = C.layers.Recurrence(C.layers.LSTM(N))(x)\n",
 " m = C.ops.sequence.last(m)\n",
-" m = C.layers.Dropout(0.2)(m)\n",
+" m = C.layers.Dropout(0.2, seed=1)(m)\n",
 " m = cntk.layers.Dense(1)(m)\n",
 " return m"
 ]

View file

@@ -573,6 +573,9 @@ public:
     }
 }

+%ignore CNTK::Dictionary::Keys;
+
 %extend CNTK::Dictionary {
     PyObject* __getitem__(const wchar_t* key) {
         PyObject *DictionaryValueToPy(const CNTK::DictionaryValue&);

View file

@@ -14,6 +14,7 @@ from ..variables import Variable, Record, Constant
 from ..ops import parameter, input, placeholder, combine
 from ..ops import times, element_times, convolution, convolution_transpose, pooling, unpooling, batch_normalization, dropout, splice, reshape, sequence, softmax, tanh, reduce_sum, reduce_mean, sqrt
 from cntk.internal import _as_tuple
+from cntk.cntk_py import sentinel_value_for_auto_select_random_seed as SentinelValueForAutoSelectRandomSeed
 from .blocks import *
 from .higher_order_layers import *
 from .blocks import _initializer_for, _get_initial_state_or_default, _INFERRED # helpers
@@ -1023,7 +1024,10 @@ def MaxUnpooling(filter_shape, # shape of receptive field, e.g. (3,3)

 # TODO: should the rate(s) be default_options?
-def Dropout(dropout_rate=None, keep_prob=None, name=''):
+def Dropout(dropout_rate=None,
+            keep_prob=None,
+            seed = SentinelValueForAutoSelectRandomSeed,
+            name=''):
     '''
     Layer factory function to create a drop-out layer.
@@ -1043,6 +1047,7 @@ def Dropout(dropout_rate=None, keep_prob=None, name=''):
     Args:
         dropout_rate (float): probability of dropping out an element, mutually exclusive with ``keep_prob``
         keep_prob (float): probability of keeping an element, mutually exclusive with ``dropout_rate``
+        seed (int): random seed.
         name (str, defaults to ''): the name of the function instance in the network

     Returns:
@@ -1057,7 +1062,7 @@ def Dropout(dropout_rate=None, keep_prob=None, name=''):
         dropout_rate = 1-keep_prob

     @BlockFunction('Dropout', name)
     def dropout_f(x):
-        return dropout(x, dropout_rate=dropout_rate)
+        return dropout(x, dropout_rate=dropout_rate, seed=seed)
     return dropout_f
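As a usage note for the extended factory: two Dropout layers built with the same seed produce the same mask. A minimal sketch, assuming the API above; note that dropout only fires in a training-style forward pass (a plain eval() treats the node as the identity), which is why the sketch retains backward state explicitly.

import numpy as np
import cntk as C

x = C.input((6,), needs_gradient=True)
d1 = C.layers.Dropout(0.5, seed=42)(x)
d2 = C.layers.Dropout(0.5, seed=42)(x)

v = np.ones((1, 6), dtype=np.float32)

def training_forward(f):
    # Retaining backward state puts the node into training mode,
    # which is what makes the dropout mask fire.
    _, output = f.forward({x: v}, [f.output], set([f.output]))
    return output[f.output]

m1, m2 = training_forward(d1), training_forward(d2)
assert np.allclose(m1, m2)  # identical seeds -> identical masks
# surviving elements are scaled by 1/(1 - 0.5) = 2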

View file

@@ -15,6 +15,7 @@ from cntk.internal import sanitize_input, sanitize_shape, sanitize_axis, sanitiz
 from cntk.internal.utils import get_data_type
 from ..axis import Axis
 from .. import cntk_py
+from ..cntk_py import sentinel_value_for_auto_select_random_seed as SentinelValueForAutoSelectRandomSeed
 from ..default_options import get_default_override, default_override_or

 TIMES_NO_INFERRED_INPUT_RANK = cntk_py.TimesNoInferredInputRank
@@ -2383,7 +2384,12 @@ def argmin(x, axis=None, name=''):
 #######################################################################

 @typemap
-def random_sample(weights, num_samples, allow_duplicates, name=''):
+def random_sample(
+        weights,
+        num_samples,
+        allow_duplicates,
+        seed = SentinelValueForAutoSelectRandomSeed,
+        name=''):
     '''
     Estimates inclusion frequencies for random sampling with or without
     replacement.
@@ -2404,6 +2410,8 @@ def random_sample(weights, num_samples, allow_duplicates, name=''):
         num_samples (int): number of expected samples
         allow_duplicates (bool): If sampling is done
             with replacement (`True`) or without (`False`).
+        seed (int): random seed.
+        name (:class:`str`, optional): the name of the Function instance in the network.

     Returns:
         :class:`~cntk.ops.functions.Function`
@@ -2412,14 +2420,15 @@ def random_sample(weights, num_samples, allow_duplicates, name=''):
     from cntk.cntk_py import random_sample
     weights = sanitize_input(weights)

-    return random_sample(weights, num_samples, allow_duplicates, name)
+    return random_sample(weights, num_samples, allow_duplicates, seed, name)


 @typemap
 def random_sample_inclusion_frequency(
         weights,
         num_samples,
         allow_duplicates,
+        seed = SentinelValueForAutoSelectRandomSeed,
         name=''):
     '''
     For weighted sampling with the specified sample size (`num_samples`)
@@ -2438,6 +2447,8 @@ def random_sample_inclusion_frequency(
         num_samples (int): number of expected samples
         allow_duplicates (bool): If sampling is done
             with replacement (`True`) or without (`False`).
+        seed (int): random seed.
+        name (:class:`str`, optional): the name of the Function instance in the network.

     Example:
         >>> import numpy as np
@@ -2470,11 +2481,12 @@ def random_sample_inclusion_frequency(
         weights,
         num_samples,
         allow_duplicates,
+        seed,
         name)


 @typemap
-def dropout(x, dropout_rate=0.0, name=''):
+def dropout(x, dropout_rate=0.0, seed = SentinelValueForAutoSelectRandomSeed, name=''):
     '''
     Each element of the input is independently set to 0 with probability ``dropout_rate``
     or to 1 / (1 - ``dropout_rate``) times its original value (with probability 1-``dropout_rate``).
@@ -2500,6 +2512,7 @@ def dropout(x, dropout_rate=0.0, name=''):
     Args:
         x: input tensor
         dropout_rate (float, [0,1)): probability that an element of ``x`` will be set to zero
+        seed (int): random seed.
         name (:class:`str`, optional): the name of the Function instance in the network

     Returns:
@@ -2511,7 +2524,7 @@ def dropout(x, dropout_rate=0.0, name=''):
     from cntk.cntk_py import dropout
     x = sanitize_input(x)

-    return dropout(x, dropout_rate, name)
+    return dropout(x, dropout_rate, seed, name)

 ##########################################################################
 # variables_and_parameters ops
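Taken together, all three ops above now accept an explicit seed that defaults to the auto-select sentinel: passing the same seed twice reproduces the randomness exactly, while omitting it lets each node draw its own seed. A short sketch of the resulting behavior, assuming the CNTK 2.x Python API (the identity-matrix trick for densifying the one-hot samples mirrors the tests below):

import numpy as np
import cntk as C

weights = np.arange(10, dtype=np.float32)
identity = np.identity(len(weights), dtype=np.float32)
dense = lambda f: C.times(f, identity).eval()  # densify sparse one-hot samples

# The same explicit seed reproduces the sample exactly.
s1 = dense(C.random_sample(weights, 5, allow_duplicates=False, seed=98052))
s2 = dense(C.random_sample(weights, 5, allow_duplicates=False, seed=98052))
assert np.allclose(s1, s2)

# Expected inclusion counts for the same sampling setup.
freq = C.random_sample_inclusion_frequency(weights, 5, allow_duplicates=False, seed=98052).eval()
print(freq)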

View file

@@ -186,6 +186,41 @@ def test_op_dropout(shape, dropout_rate, device_id, precision):
     assert(abs(resulted_non_zeros - expected_non_zeros) <
            max_off)

+def test_op_dropout_with_explicit_seed(device_id, precision):
+    from cntk import combine, dropout, input
+
+    value = np.ones(shape=(10,10), dtype=PRECISION_TO_TYPE[precision])
+
+    a = input(shape=value.shape,
+              dtype=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
+              needs_gradient=True,
+              name='a')
+
+    seed = 123
+
+    dropout_nodes = [
+        dropout(a, dropout_rate=0.5, seed=seed),
+        dropout(a, dropout_rate=0.5, seed=seed),
+        dropout(a, dropout_rate=0.5, seed=seed+1),
+        dropout(a, dropout_rate=0.5)
+    ]
+
+    value.shape = (1, 1) + value.shape
+    forward_input = {a: value}
+
+    results = []
+    for node in dropout_nodes:
+        forward, backward = cntk_eval(node,
+                                      forward_input,
+                                      precision,
+                                      cntk_device(device_id),
+                                      backward_pass=True)
+        results.append(forward[node.output])
+
+    # Nodes sharing an explicit seed must produce identical masks; a different
+    # or auto-selected seed must not.
+    assert np.allclose(results[0], results[1])
+    assert not np.allclose(results[0], results[2])
+    assert not np.allclose(results[0], results[3])
+
 @pytest.mark.parametrize("dropout_rate", [-0.1, 1.0, 100])
 def test_op_dropout_bad_input(dropout_rate):

View file

@@ -114,3 +114,18 @@ def test_random_sample_without_replacement(weights, num_samples, expected_count,
     denseResult = times(result, identity)
     observed_count = np.sum(denseResult.eval(), 0)
     assert np.allclose(observed_count, expected_count, atol=tolerance)
+
+def test_random_sample_with_explicit_seed(device_id, precision):
+    weights = AA([x for x in range(0, 10)], precision)
+    identity = np.identity(weights.size)
+    allow_duplicates = False  # sample without replacement
+    num_samples = 5
+    seed = 123
+    to_dense = lambda x: times(x, identity).eval()
+    result1 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed))
+    result2 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed))
+    result3 = to_dense(random_sample(weights, num_samples, allow_duplicates, seed+1))
+    result4 = to_dense(random_sample(weights, num_samples, allow_duplicates))
+    # Equal seeds reproduce the sample; a different or auto-selected seed does not.
+    assert np.allclose(result1, result2)
+    assert not np.allclose(result1, result3)
+    assert not np.allclose(result1, result4)

View file

@@ -62,9 +62,9 @@ def test_convolution_transpose_attributes():

 def test_dropout_attributes():
     x = C.input( (1, 5, 5) )
-    f = C.dropout(x, 0.5)
+    f = C.dropout(x, 0.5, 42)
     d = f.root_function.attributes
-    expected = {'dropoutRate': 0.5}
+    expected = {'dropoutRate': 0.5, 'rngSeed' : 42, 'rngOffset' : 0}
     _check(expected, d)

 def test_slice_attributes():