Integrate alrezni/8byte_seed into master
This commit is contained in:
Коммит
e26cde816f
|
@ -763,7 +763,7 @@ namespace CNTK
|
|||
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
|
||||
{
|
||||
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
|
||||
unsigned long long offset = 0;
|
||||
uint64_t offset = 0;
|
||||
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
|
||||
{
|
||||
offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
|
||||
|
@ -1191,9 +1191,9 @@ namespace CNTK
|
|||
return m_perOutputVarArgumentDependencies[output];
|
||||
}
|
||||
|
||||
std::unordered_map<Variable, int64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
|
||||
std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
|
||||
{
|
||||
std::unordered_map<Variable, int64_t> currentBackpropRootsTimeStamps;
|
||||
std::unordered_map<Variable, uint64_t> currentBackpropRootsTimeStamps;
|
||||
assert(m_computationNetwork != nullptr);
|
||||
|
||||
for (auto& backpropRoot : m_currentBackpropRoots)
|
||||
|
@ -1328,7 +1328,7 @@ namespace CNTK
|
|||
InvalidArgument("Invalid backprop state specified");
|
||||
|
||||
// TODO: Support multiple concurrent backprop states
|
||||
std::unordered_map<Variable, int64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
|
||||
std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
|
||||
if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
|
||||
LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
|
||||
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
|
||||
|
|
|
@ -16,17 +16,17 @@ namespace CNTK
|
|||
class CNTKBackPropState final : public BackPropState
|
||||
{
|
||||
public:
|
||||
CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::unordered_map<Variable, int64_t>& backpropRootsForwardTimeStamps)
|
||||
CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::unordered_map<Variable, uint64_t>& backpropRootsForwardTimeStamps)
|
||||
: BackPropState(function, computeDevice), m_backpropRootsForwardTimeStamps(backpropRootsForwardTimeStamps)
|
||||
{}
|
||||
|
||||
const std::unordered_map<Variable, int64_t>& BackpropRootsForwardTimeStamps() const
|
||||
const std::unordered_map<Variable, uint64_t>& BackpropRootsForwardTimeStamps() const
|
||||
{
|
||||
return m_backpropRootsForwardTimeStamps;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<Variable, int64_t> m_backpropRootsForwardTimeStamps;
|
||||
std::unordered_map<Variable, uint64_t> m_backpropRootsForwardTimeStamps;
|
||||
};
|
||||
typedef std::shared_ptr<CNTKBackPropState> CNTKBackPropStatePtr;
|
||||
|
||||
|
@ -210,7 +210,7 @@ namespace CNTK
|
|||
|
||||
const std::vector<Variable>& GetArgumentDependencies(const Variable& output);
|
||||
|
||||
std::unordered_map<Variable, int64_t> GetCurrentBackpropRootsTimeStamps() const;
|
||||
std::unordered_map<Variable, uint64_t> GetCurrentBackpropRootsTimeStamps() const;
|
||||
|
||||
private:
|
||||
|
||||
|
|
|
@ -46,7 +46,8 @@
|
|||
#define CNTK_MODEL_VERSION_14 14 // axis parameter in OptimizedRNNStackNode
|
||||
#define CNTK_MODEL_VERSION_15 15 // add new nodes: LambdaRankNode and NDCG1Eval
|
||||
#define CNTK_MODEL_VERSION_16 16 // save/load rng state for Dropout and RandomSample nodes.
|
||||
#define CURRENT_CNTK_MODEL_VERSION CNTK_MODEL_VERSION_16
|
||||
#define CNTK_MODEL_VERSION_17 17 // use 8 bytes for rng seeds on both platforms
|
||||
#define CURRENT_CNTK_MODEL_VERSION CNTK_MODEL_VERSION_17
|
||||
|
||||
// helper mode for debugging
|
||||
// If TRACK_GAP_NANS is defined then initialize layout gaps to NaN and do NaN checks. Also do detailed logging of node computations.
|
||||
|
@ -250,7 +251,7 @@ public:
|
|||
{
|
||||
m_evalTimeStamp = 0;
|
||||
}
|
||||
int64_t GetEvalTimeStamp() const
|
||||
uint64_t GetEvalTimeStamp() const
|
||||
{
|
||||
return m_evalTimeStamp;
|
||||
}
|
||||
|
@ -263,18 +264,17 @@ public:
|
|||
|
||||
bool IsOlderThan(const TimeStamp& other) const
|
||||
{
|
||||
// the time stamps are compared directly as unsigned 64-bit values; numeric overflow is not accounted for (it really should never happen for a 64-bit integer)
|
||||
return GetEvalTimeStamp() - other.GetEvalTimeStamp() < 0;
|
||||
return GetEvalTimeStamp() < other.GetEvalTimeStamp();
|
||||
}
|
||||
|
||||
int64_t CreateUniqId() const
|
||||
uint64_t CreateUniqId() const
|
||||
{
|
||||
return ++s_timeStampCounter;
|
||||
}
|
||||
|
||||
private:
|
||||
static atomic_ullong s_timeStampCounter;
|
||||
int64_t m_evalTimeStamp; // this is used to reduce unnecessary recomputation when a different node in the model is reevaluated
|
||||
uint64_t m_evalTimeStamp; // this is used to reduce unnecessary recomputation when a different node in the model is reevaluated
|
||||
};
|
||||
|
||||
// =======================================================================
|
||||
|
@ -374,7 +374,7 @@ public:
|
|||
RpcStringFreeW((RPC_WSTR*) &szUuid);
|
||||
}
|
||||
#else
|
||||
int64_t id = CreateUniqId();
|
||||
uint64_t id = CreateUniqId();
|
||||
std::wstring base = L"AutoName";
|
||||
std::wstringstream sstm;
|
||||
sstm << base.c_str() << id;
|
||||
|
|
|
@ -45,8 +45,7 @@ void RandomSampleNodeBase<ElemType>::Save(File& fstream) const
|
|||
Base::Save(fstream);
|
||||
fstream << m_allowDuplicates;
|
||||
fstream << m_sizeOfSampledSet;
|
||||
fstream << GetRngSeed();
|
||||
fstream << GetRngOffset();
|
||||
RngUser::Save(fstream);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -55,14 +54,7 @@ void RandomSampleNodeBase<ElemType>::Load(File& fstream, size_t modelVersion)
|
|||
Base::Load(fstream, modelVersion);
|
||||
fstream >> m_allowDuplicates;
|
||||
fstream >> m_sizeOfSampledSet;
|
||||
if (modelVersion >= CNTK_MODEL_VERSION_16)
|
||||
{
|
||||
unsigned long seed;
|
||||
unsigned long long offset;
|
||||
fstream >> seed;
|
||||
fstream >> offset;
|
||||
SetRngState(seed, offset);
|
||||
}
|
||||
RngUser::Load(fstream, modelVersion);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -275,23 +267,14 @@ template<class ElemType>
|
|||
void DropoutNode<ElemType>::Save(File& fstream) const
|
||||
{
|
||||
Base::Save(fstream);
|
||||
fstream << GetRngSeed();
|
||||
fstream << GetRngOffset();
|
||||
RngUser::Save(fstream);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DropoutNode<ElemType>::Load(File& fstream, size_t modelVersion)
|
||||
{
|
||||
Base::Load(fstream, modelVersion);
|
||||
|
||||
if (modelVersion >= CNTK_MODEL_VERSION_16)
|
||||
{
|
||||
unsigned long seed;
|
||||
unsigned long long offset;
|
||||
fstream >> seed;
|
||||
fstream >> offset;
|
||||
SetRngState(seed, offset);
|
||||
}
|
||||
RngUser::Load(fstream, modelVersion);
|
||||
}
|
||||
|
||||
template class DropoutNode<float>;
|
||||
|
|
|
@ -1151,7 +1151,7 @@ class IRngUser
|
|||
{
|
||||
public:
|
||||
virtual RNGHandle& GetRNGHandle(DEVICEID_TYPE deviceId) = 0;
|
||||
virtual void SetRngState(unsigned long seed, unsigned long long offset = 0) = 0;
|
||||
virtual void SetRngState(uint64_t seed, uint64_t offset = 0) = 0;
|
||||
};
|
||||
|
||||
// This implements IRngUser using RNGHandle.
|
||||
|
@ -1167,31 +1167,61 @@ public:
|
|||
}
|
||||
|
||||
// E.g. called from ComputationNetwork to make sure that CNTK running on different nodes will have different seeds.
|
||||
void SetRngState(unsigned long seed, unsigned long long offset = 0) override
|
||||
void SetRngState(uint64_t seed, uint64_t offset = 0) override
|
||||
{
|
||||
m_rngSeed = seed;
|
||||
m_rngOffset = offset;
|
||||
m_RNGHandle.reset(); // Reset handle. New handle will be generated with next call of GetRNGHandle(...).
|
||||
}
|
||||
|
||||
unsigned long GetRngSeed() const
|
||||
uint64_t GetRngSeed() const
|
||||
{
|
||||
return m_rngSeed;
|
||||
}
|
||||
|
||||
unsigned long long GetRngOffset() const
|
||||
uint64_t GetRngOffset() const
|
||||
{
|
||||
return m_rngOffset;
|
||||
}
|
||||
|
||||
void UpdateRngOffset(unsigned long long val)
|
||||
void UpdateRngOffset(uint64_t val)
|
||||
{
|
||||
m_rngOffset = val;
|
||||
}
|
||||
|
||||
protected:
|
||||
unsigned long m_rngSeed = 0;
|
||||
unsigned long long m_rngOffset = 0;
|
||||
|
||||
void Load(File& fstream, size_t modelVersion)
|
||||
{
|
||||
if (modelVersion < CNTK_MODEL_VERSION_16)
|
||||
return;
|
||||
|
||||
uint64_t seed;
|
||||
uint64_t offset;
|
||||
|
||||
if (modelVersion == CNTK_MODEL_VERSION_16)
|
||||
{
|
||||
unsigned long seed_16;
|
||||
fstream >> seed_16;
|
||||
seed = seed_16;
|
||||
}
|
||||
else
|
||||
{
|
||||
fstream >> seed;
|
||||
}
|
||||
|
||||
fstream >> offset;
|
||||
SetRngState(seed, offset);
|
||||
}
|
||||
|
||||
void Save(File& fstream) const
|
||||
{
|
||||
fstream << GetRngSeed();
|
||||
fstream << GetRngOffset();
|
||||
}
|
||||
|
||||
uint64_t m_rngSeed = 0;
|
||||
uint64_t m_rngOffset = 0;
|
||||
std::shared_ptr<RNGHandle> m_RNGHandle;
|
||||
};
|
||||
|
||||
|
@ -1216,7 +1246,7 @@ public:
|
|||
RandomSampleNodeBase(DEVICEID_TYPE deviceId, const wstring& name, size_t sizeOfSampledSet = 0, bool allowDuplicates = false)
|
||||
: Base(deviceId, name), m_sizeOfSampledSet(sizeOfSampledSet), m_allowDuplicates(allowDuplicates)
|
||||
{
|
||||
SetRngState((unsigned long)CreateUniqId());
|
||||
SetRngState(CreateUniqId());
|
||||
}
|
||||
|
||||
RandomSampleNodeBase(const ScriptableObjects::IConfigRecordPtr configp)
|
||||
|
@ -2092,7 +2122,7 @@ public:
|
|||
: Base(deviceId, name),
|
||||
m_dropoutRate(0)
|
||||
{
|
||||
SetRngState((unsigned long)CreateUniqId());
|
||||
SetRngState(CreateUniqId());
|
||||
}
|
||||
|
||||
virtual void Save(File& fstream) const override;
|
||||
|
|
|
@ -10,12 +10,12 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
CPURNGHandle::CPURNGHandle(int deviceId, unsigned long seed, unsigned long long offset)
|
||||
CPURNGHandle::CPURNGHandle(int deviceId, uint64_t seed, uint64_t offset)
|
||||
: RNGHandle(deviceId)
|
||||
{
|
||||
#ifdef _MSC_VER // TODO: check if available under GCC/Linux
|
||||
m_generator.reset(new std::ranlux64_base_01());
|
||||
m_generator->seed(seed);
|
||||
m_generator->seed((unsigned long)seed);
|
||||
#else
|
||||
m_generator.reset(new std::default_random_engine(seed));
|
||||
#endif
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
class CPURNGHandle : public RNGHandle
|
||||
{
|
||||
public:
|
||||
CPURNGHandle(int deviceId, unsigned long seed, unsigned long long offset = 0);
|
||||
CPURNGHandle(int deviceId, uint64_t seed, uint64_t offset = 0);
|
||||
|
||||
#ifdef _MSC_VER // TODO: check if available under GCC/Linux
|
||||
std::ranlux64_base_01& Generator()
|
||||
|
@ -25,6 +25,9 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
// TODO: replace with mt19937_64 once we're on VS2015
|
||||
// (this will require re-generating baselines for
|
||||
// Speech/DNN/Dropout and Speech/HTKDeserializers/DNN/Dropout).
|
||||
std::unique_ptr<std::ranlux64_base_01> m_generator;
|
||||
|
||||
#else
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
GPURNGHandle::GPURNGHandle(int deviceId, unsigned long seed, unsigned long long offset)
|
||||
GPURNGHandle::GPURNGHandle(int deviceId, uint64_t seed, uint64_t offset)
|
||||
: RNGHandle(deviceId)
|
||||
{
|
||||
unsigned long long cudaSeed = seed;
|
||||
|
|
|
@ -18,7 +18,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
class GPURNGHandle : public RNGHandle
|
||||
{
|
||||
public:
|
||||
GPURNGHandle(int deviceId, unsigned long seed, unsigned long long offset = 0);
|
||||
GPURNGHandle(int deviceId, uint64_t seed, uint64_t offset = 0);
|
||||
virtual ~GPURNGHandle();
|
||||
|
||||
#ifndef CPUONLY
|
||||
|
|
|
@ -2241,7 +2241,7 @@ void GPUDataTransferer::WaitForCopyCPUToGPUAsync(){}
|
|||
|
||||
#pragma region GPURNGHandle functions
|
||||
|
||||
GPURNGHandle::GPURNGHandle(int deviceId, unsigned long seed, unsigned long long offset)
|
||||
GPURNGHandle::GPURNGHandle(int deviceId, uint64_t seed, uint64_t offset)
|
||||
: RNGHandle(deviceId)
|
||||
{
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
/*static*/ std::shared_ptr<RNGHandle> RNGHandle::Create(DEVICEID_TYPE deviceId, unsigned long seed, unsigned long long offset)
|
||||
/*static*/ std::shared_ptr<RNGHandle> RNGHandle::Create(DEVICEID_TYPE deviceId, uint64_t seed, uint64_t offset)
|
||||
{
|
||||
if (deviceId == CPUDEVICE)
|
||||
{
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
class MATH_API RNGHandle
|
||||
{
|
||||
public:
|
||||
static std::shared_ptr<RNGHandle> Create(DEVICEID_TYPE deviceId, unsigned long seed, unsigned long long offset = 0);
|
||||
static std::shared_ptr<RNGHandle> Create(DEVICEID_TYPE deviceId, uint64_t seed, uint64_t offset = 0);
|
||||
|
||||
virtual ~RNGHandle() {}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче