Factor worker rank into the auto-generated seed
This commit is contained in:
Родитель
4f575448b4
Коммит
e3334128bf
|
@ -242,6 +242,8 @@ namespace CNTK
|
|||
|
||||
CNTK_API size_t NewUniqueId();
|
||||
|
||||
CNTK_API size_t GenerateRandomSeed();
|
||||
|
||||
// Internal hooks for testing and higher-level bindings
|
||||
// These should not be directly called by C++ API users
|
||||
CNTK_API void EnableReversingTensorShapesInErrorMessages();
|
||||
|
|
|
@ -28,12 +28,26 @@ namespace CNTK
|
|||
{
|
||||
namespace Internal
|
||||
{
|
||||
static std::atomic<unsigned long long> s_nextUniqueId(0);
|
||||
static std::atomic_ullong s_nextUniqueId = ATOMIC_VAR_INIT(0);
|
||||
size_t NewUniqueId()
|
||||
{
|
||||
return s_nextUniqueId++;
|
||||
}
|
||||
|
||||
static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0);
|
||||
|
||||
size_t GenerateRandomSeed()
|
||||
{
|
||||
DistributedCommunicatorPtr communicator = MPICommunicator();
|
||||
auto numWorkers = communicator->Workers().size();
|
||||
auto rank = communicator->CurrentWorker().m_globalRank;
|
||||
|
||||
if (numWorkers < 1)
|
||||
numWorkers = 1;
|
||||
|
||||
return (numWorkers * ++s_currentRandomSeed) + rank;
|
||||
}
|
||||
|
||||
std::atomic<bool> s_reverseTensorShapesInErrorMessages(false);
|
||||
void EnableReversingTensorShapesInErrorMessages()
|
||||
{
|
||||
|
|
|
@ -1043,8 +1043,10 @@ namespace CNTK
|
|||
additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
|
||||
|
||||
if (seed != SentinelValueForAutoSelectRandomSeed)
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
if (seed == SentinelValueForAutoSelectRandomSeed)
|
||||
seed = Internal::GenerateRandomSeed();
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
|
||||
return UnaryOp(PrimitiveOpType::RandomSample, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
@ -1055,8 +1057,10 @@ namespace CNTK
|
|||
additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates;
|
||||
|
||||
if (seed != SentinelValueForAutoSelectRandomSeed)
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
if (seed == SentinelValueForAutoSelectRandomSeed)
|
||||
seed = Internal::GenerateRandomSeed();
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
|
||||
return UnaryOp(PrimitiveOpType::RandomSampleInclusionFrequency, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
@ -1066,8 +1070,10 @@ namespace CNTK
|
|||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate;
|
||||
|
||||
if (seed != SentinelValueForAutoSelectRandomSeed)
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
if (seed == SentinelValueForAutoSelectRandomSeed)
|
||||
seed = Internal::GenerateRandomSeed();
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed);
|
||||
|
||||
return UnaryOp(PrimitiveOpType::Dropout, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче