Factor worker rank into the auto-generated seed

This commit is contained in:
Alexey Reznichenko 2017-03-14 12:04:01 +01:00
Родитель 4f575448b4
Коммит e3334128bf
3 изменённых файлов: 29 добавлений и 7 удалений

Просмотреть файл

@ -242,6 +242,8 @@ namespace CNTK
CNTK_API size_t NewUniqueId();
CNTK_API size_t GenerateRandomSeed();
// Internal hooks for testing and higher-level bindings
// These should not be directly called by C++ API users
CNTK_API void EnableReversingTensorShapesInErrorMessages();

Просмотреть файл

@ -28,12 +28,26 @@ namespace CNTK
{
namespace Internal
{
static std::atomic<unsigned long long> s_nextUniqueId(0);
static std::atomic_ullong s_nextUniqueId = ATOMIC_VAR_INIT(0);
size_t NewUniqueId()
{
return s_nextUniqueId++;
}
static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0);
size_t GenerateRandomSeed()
{
DistributedCommunicatorPtr communicator = MPICommunicator();
auto numWorkers = communicator->Workers().size();
auto rank = communicator->CurrentWorker().m_globalRank;
if (numWorkers < 1)
numWorkers = 1;
return (numWorkers * ++s_currentRandomSeed) + rank;
}
std::atomic<bool> s_reverseTensorShapesInErrorMessages(false);
void EnableReversingTensorShapesInErrorMessages()
{

Просмотреть файл

@ -1043,8 +1043,10 @@ namespace CNTK
additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
if (seed != SentinelValueForAutoSelectRandomSeed)
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
if (seed == SentinelValueForAutoSelectRandomSeed)
seed = Internal::GenerateRandomSeed();
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
return UnaryOp(PrimitiveOpType::RandomSample, operand, std::move(additionalProperties), name);
}
@ -1055,8 +1057,10 @@ namespace CNTK
additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
if (seed != SentinelValueForAutoSelectRandomSeed)
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
if (seed == SentinelValueForAutoSelectRandomSeed)
seed = Internal::GenerateRandomSeed();
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
return UnaryOp(PrimitiveOpType::RandomSampleInclusionFrequency, operand, std::move(additionalProperties), name);
}
@ -1066,8 +1070,10 @@ namespace CNTK
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate;
if (seed != SentinelValueForAutoSelectRandomSeed)
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
if (seed == SentinelValueForAutoSelectRandomSeed)
seed = Internal::GenerateRandomSeed();
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
return UnaryOp(PrimitiveOpType::Dropout, operand, std::move(additionalProperties), name);
}