diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h index d5a918512..f3d37fcae 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h @@ -242,6 +242,8 @@ namespace CNTK CNTK_API size_t NewUniqueId(); + CNTK_API size_t GenerateRandomSeed(); + // Internal hooks for testing and higher-level bindings // These should not be directly called by C++ API users CNTK_API void EnableReversingTensorShapesInErrorMessages(); diff --git a/Source/CNTKv2LibraryDll/Common.cpp b/Source/CNTKv2LibraryDll/Common.cpp index 63ef64f3a..d1a7eacd4 100644 --- a/Source/CNTKv2LibraryDll/Common.cpp +++ b/Source/CNTKv2LibraryDll/Common.cpp @@ -28,12 +28,26 @@ namespace CNTK { namespace Internal { - static std::atomic s_nextUniqueId(0); + static std::atomic_ullong s_nextUniqueId = ATOMIC_VAR_INIT(0); size_t NewUniqueId() { return s_nextUniqueId++; } + static std::atomic_ullong s_currentRandomSeed = ATOMIC_VAR_INIT(0); + + size_t GenerateRandomSeed() + { + DistributedCommunicatorPtr communicator = MPICommunicator(); + auto numWorkers = communicator->Workers().size(); + auto rank = communicator->CurrentWorker().m_globalRank; + + if (numWorkers < 1) + numWorkers = 1; + + return (numWorkers * ++s_currentRandomSeed) + rank; + } + std::atomic s_reverseTensorShapesInErrorMessages(false); void EnableReversingTensorShapesInErrorMessages() { diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index 25c134bf8..ee63a8643 100755 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -1043,8 +1043,10 @@ namespace CNTK additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples; additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates; - if (seed != SentinelValueForAutoSelectRandomSeed) - additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed); + if (seed == SentinelValueForAutoSelectRandomSeed) + seed = Internal::GenerateRandomSeed(); + + additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed); return UnaryOp(PrimitiveOpType::RandomSample, operand, std::move(additionalProperties), name); } @@ -1055,8 +1057,10 @@ namespace CNTK additionalProperties[PrimitiveFunction::AttributeNameNumSamples] = numSamples; additionalProperties[PrimitiveFunction::AttributeNameAllowDuplicates] = allowDuplicates; - if (seed != SentinelValueForAutoSelectRandomSeed) - additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed); + if (seed == SentinelValueForAutoSelectRandomSeed) + seed = Internal::GenerateRandomSeed(); + + additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed); return UnaryOp(PrimitiveOpType::RandomSampleInclusionFrequency, operand, std::move(additionalProperties), name); } @@ -1066,8 +1070,10 @@ namespace CNTK auto additionalProperties = Dictionary(); additionalProperties[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate; - if (seed != SentinelValueForAutoSelectRandomSeed) - additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed); + if (seed == SentinelValueForAutoSelectRandomSeed) + seed = Internal::GenerateRandomSeed(); + + additionalProperties[PrimitiveFunction::AttributeNameRngSeed] = size_t(seed); return UnaryOp(PrimitiveOpType::Dropout, operand, std::move(additionalProperties), name); }