Integrate alrezni/v2_sweep5 into master
This commit is contained in:
Коммит
020869ff17
|
@ -1 +1 @@
|
|||
Subproject commit 8114febb0f491cd0fec4b60e389672cda8ababfb
|
||||
Subproject commit 4b2396f36b8129d035a0166cd2d1a1e457404249
|
|
@ -121,6 +121,7 @@ namespace CNTK
|
|||
struct MinibatchInfo
|
||||
{
|
||||
bool atEndOfData;
|
||||
bool atEndOfSweep;
|
||||
size_t numberOfSamples;
|
||||
NDArrayViewPtr trainingLossValue;
|
||||
NDArrayViewPtr evalCriterionValue;
|
||||
|
@ -3275,7 +3276,7 @@ namespace CNTK
|
|||
///
|
||||
/// A special value that can be used for the epochSize to indicate that the schedule is sweep-based.
|
||||
///
|
||||
static const size_t EntireSweep = 0;
|
||||
static const size_t FullDataSweep = 0;
|
||||
|
||||
///
|
||||
/// Create a schedule with a constant parameter value.
|
||||
|
@ -3287,7 +3288,7 @@ namespace CNTK
|
|||
/// schedule[0] is used for the first 'epochSize' samples, schedule[1] -- for the second,
|
||||
/// and so on. The last value is then used repeatedly until the end of training.
|
||||
///
|
||||
CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, UnitType unit, size_t epochSize = 1);
|
||||
CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, UnitType unit, size_t epochSize = FullDataSweep);
|
||||
|
||||
///
|
||||
/// Create a schedule using the list of key-value pairs, where the key specifies
|
||||
|
@ -3298,7 +3299,7 @@ namespace CNTK
|
|||
/// the first 100 samples, then '0.1' is used for the second 200 samples,
|
||||
/// after which the values is switched to '0.005'.
|
||||
///
|
||||
CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize = 1);
|
||||
CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize = FullDataSweep);
|
||||
|
||||
///
|
||||
/// Returns a value corresponding to the absolute sample (or sweep)
|
||||
|
@ -3313,7 +3314,7 @@ namespace CNTK
|
|||
///
|
||||
UnitType Unit() const { return m_unit; }
|
||||
|
||||
bool IsSweepBased() const { return m_epochSize == EntireSweep; }
|
||||
bool IsSweepBased() const { return m_epochSize == FullDataSweep; }
|
||||
|
||||
CNTK_API virtual ~TrainingParameterSchedule();
|
||||
|
||||
|
@ -3348,12 +3349,15 @@ namespace CNTK
|
|||
TrainingParameterPerUnitSchedule(T value)
|
||||
: TrainingParameterSchedule<T>::TrainingParameterSchedule(value, U)
|
||||
{ }
|
||||
|
||||
TrainingParameterPerUnitSchedule(const std::vector<T>& schedule, size_t epochSize = 1)
|
||||
|
||||
TrainingParameterPerUnitSchedule(const std::vector<T>& schedule,
|
||||
size_t epochSize = TrainingParameterSchedule<T>::FullDataSweep)
|
||||
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
|
||||
{ }
|
||||
|
||||
TrainingParameterPerUnitSchedule(const std::vector<std::pair<size_t, T>>& schedule, size_t epochSize = 1)
|
||||
|
||||
|
||||
TrainingParameterPerUnitSchedule(const std::vector<std::pair<size_t, T>>& schedule,
|
||||
size_t epochSize = TrainingParameterSchedule<T>::FullDataSweep)
|
||||
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
|
||||
{ }
|
||||
|
||||
|
@ -3401,13 +3405,13 @@ namespace CNTK
|
|||
ConvertToPerSampleValues();
|
||||
}
|
||||
|
||||
MomentumAsTimeConstantSchedule(const std::vector<double>& schedule, size_t epochSize = 1)
|
||||
MomentumAsTimeConstantSchedule(const std::vector<double>& schedule, size_t epochSize = FullDataSweep)
|
||||
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
|
||||
{
|
||||
ConvertToPerSampleValues();
|
||||
}
|
||||
|
||||
MomentumAsTimeConstantSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = 1)
|
||||
MomentumAsTimeConstantSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = FullDataSweep)
|
||||
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
|
||||
{
|
||||
ConvertToPerSampleValues();
|
||||
|
@ -3424,9 +3428,10 @@ namespace CNTK
|
|||
CNTK_API void ConvertToPerSampleValues();
|
||||
};
|
||||
|
||||
|
||||
///
|
||||
/// A collection of additional options that affect parameter updates and
|
||||
/// are applicable for all standard learners
|
||||
///
|
||||
struct AdditionalLearningOptions
|
||||
{
|
||||
double l1RegularizationWeight = 0.0;
|
||||
|
@ -3452,7 +3457,7 @@ namespace CNTK
|
|||
// Method to update the parameters associated with this learner. By returning false, this method indicates that
|
||||
// learning has stopped for all of the parameters associated with this learner
|
||||
//
|
||||
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) = 0;
|
||||
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd = false) = 0;
|
||||
|
||||
///
|
||||
/// Returns the set of parameters associated with this learner.
|
||||
|
@ -3610,9 +3615,9 @@ namespace CNTK
|
|||
return m_communicator;
|
||||
}
|
||||
|
||||
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t minibatchSampleCount) override
|
||||
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t minibatchSampleCount, bool sweepEnd = false) override
|
||||
{
|
||||
MinibatchInfo info{ false, minibatchSampleCount };
|
||||
MinibatchInfo info{ false, sweepEnd, minibatchSampleCount };
|
||||
return Update(gradientValues, info);
|
||||
}
|
||||
|
||||
|
@ -3726,6 +3731,11 @@ namespace CNTK
|
|||
/// Optimize model parameters using the specified 'arguments' minibatch of training samples.
|
||||
/// Returns false if all parameter learners indicate end of learning (through their Update method's return value).
|
||||
///
|
||||
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// An overload of the TrainMinibatch above that takes a map of variables and their values (as its first argument).
|
||||
///
|
||||
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
|
@ -3735,12 +3745,22 @@ namespace CNTK
|
|||
/// for the 'outputs' for which the ValuePtr mapping was left null by the caller.
|
||||
/// Returns false if all parameter learners indicate end of learning (through their Update method's return value).
|
||||
///
|
||||
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// An overload of the TrainMinibatch above that takes a map of variables and their values (as its first argument).
|
||||
///
|
||||
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// Test the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer
|
||||
/// Returns the average evaluation criterion value per sample for the tested minibatch of samples
|
||||
///
|
||||
CNTK_API double TestMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// An overload of the TestMinibatch above that takes a map of variables and their values (as its first argument).
|
||||
///
|
||||
CNTK_API double TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
|
@ -3806,8 +3826,8 @@ namespace CNTK
|
|||
const DeviceDescriptor& computeDevice,
|
||||
std::unordered_map<Variable, ValuePtr>& parameterGradients);
|
||||
|
||||
bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice);
|
||||
bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice);
|
||||
bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
|
||||
bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
|
||||
|
||||
void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState);
|
||||
|
||||
|
@ -3854,11 +3874,34 @@ namespace std {
|
|||
|
||||
namespace CNTK
|
||||
{
|
||||
///
|
||||
/// A struct that combines the minibatch meta-data with the actual minibatch data.
|
||||
/// The former includes the number of sequences and samples in the minibatch,
|
||||
/// as well as the sweep-end flag, which is set to true to indicate that the minibatch
|
||||
/// concludes a data sweep (i.e, it's the last minibatch at the end of the sweep).
|
||||
///
|
||||
struct MinibatchData
|
||||
{
|
||||
size_t m_numSequences;
|
||||
size_t m_numSamples;
|
||||
ValuePtr m_data;
|
||||
MinibatchData() : MinibatchData(nullptr)
|
||||
{}
|
||||
|
||||
// a convenience constructor to allow passing ValuePtr arguments in place
|
||||
// of MinibatchData parameter (e.g., in Trainer::TrainMinibatch)
|
||||
MinibatchData(ValuePtr value) : MinibatchData(value, 0)
|
||||
{}
|
||||
|
||||
MinibatchData(ValuePtr value, size_t numSamples, bool sweepEnd = false)
|
||||
: MinibatchData(value, numSamples, numSamples, sweepEnd)
|
||||
{}
|
||||
|
||||
MinibatchData(ValuePtr value, size_t numSequences, size_t numSamples, bool sweepEnd)
|
||||
: data(value), numberOfSequences(numSequences), numberOfSamples(numSamples), sweepEnd(sweepEnd)
|
||||
{}
|
||||
|
||||
ValuePtr data;
|
||||
size_t numberOfSequences;
|
||||
size_t numberOfSamples;
|
||||
bool sweepEnd;
|
||||
};
|
||||
|
||||
///
|
||||
|
|
|
@ -160,6 +160,9 @@ namespace CNTK
|
|||
enum class PrimitiveOpType : unsigned int;
|
||||
enum class DataType : unsigned int;
|
||||
|
||||
struct MinibatchInfo;
|
||||
struct MinibatchData;
|
||||
|
||||
class Serializer;
|
||||
|
||||
// Similar to make_shared except that it associates a custom deleter with the shared_ptr to ensure
|
||||
|
|
|
@ -84,7 +84,7 @@ namespace CNTK
|
|||
break;
|
||||
|
||||
for (auto& currentStreamKV : computedMeanAndInvStdDevs)
|
||||
CompositeFunction::PopulateComputationNodeValue<float>({ streamToDummyInputVariableMap[currentStreamKV.first], minibatchData[currentStreamKV.first].m_data }, streamToInputNodeMap[currentStreamKV.first], layoutsPopulated);
|
||||
CompositeFunction::PopulateComputationNodeValue<float>({ streamToDummyInputVariableMap[currentStreamKV.first], minibatchData[currentStreamKV.first].data }, streamToInputNodeMap[currentStreamKV.first], layoutsPopulated);
|
||||
|
||||
ComputationNetwork::BumpEvalTimeStamp(allInputNodes);
|
||||
|
||||
|
|
|
@ -147,6 +147,6 @@ namespace CNTK
|
|||
if (info.IsEmpty())
|
||||
return false;
|
||||
|
||||
return m_learner->Update(gradientValues, info.numberOfSamples);
|
||||
return m_learner->Update(gradientValues, info.numberOfSamples, info.atEndOfSweep);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -225,7 +225,7 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
/*virtual*/ bool LearnerBase::Update(unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) /*override*/
|
||||
/*virtual*/ bool LearnerBase::Update(unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd) /*override*/
|
||||
{
|
||||
if (LearningRate(trainingSampleCount) == 0.0)
|
||||
{
|
||||
|
@ -233,7 +233,10 @@ namespace CNTK
|
|||
}
|
||||
|
||||
// make sure trainingSampleCount is a valid value
|
||||
assert(trainingSampleCount > 0);
|
||||
if (trainingSampleCount == 0)
|
||||
{
|
||||
InvalidArgument("Learner::Update(): cannot perform an update with an empty minibatch.");
|
||||
}
|
||||
|
||||
for (const auto& parameter : Parameters())
|
||||
{
|
||||
|
@ -273,7 +276,11 @@ namespace CNTK
|
|||
}
|
||||
m_sampleCount += trainingSampleCount;
|
||||
m_minibatchCount++;
|
||||
// TODO: sweep count also needs to be updated.
|
||||
if (sweepEnd)
|
||||
{
|
||||
m_sweepCount++;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace CNTK
|
|||
class LearnerBase : public Learner
|
||||
{
|
||||
public:
|
||||
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) override final;
|
||||
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd = false) override final;
|
||||
|
||||
virtual Dictionary CreateCheckpoint() override final;
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ namespace CNTK
|
|||
CompositeMinibatchSource::CompositeMinibatchSource(const Dictionary& configuration)
|
||||
: m_epochEndReached(false),
|
||||
m_prevMinibatchSize(0),
|
||||
m_epochSize(MinibatchSource::InfinitelyRepeat),
|
||||
m_maxNumSamplesToRead(MinibatchSource::InfinitelyRepeat),
|
||||
m_randomizedWindow(MinibatchSource::DefaultRandomizationWindow),
|
||||
m_truncationLength(0),
|
||||
m_numWorkers(1),
|
||||
|
@ -136,13 +136,7 @@ namespace CNTK
|
|||
|
||||
const wchar_t* epochSizeConfigurationKey = L"epochSize";
|
||||
if (augmentedConfiguration.Contains(epochSizeConfigurationKey))
|
||||
m_epochSize = augmentedConfiguration[epochSizeConfigurationKey].Value<size_t>();
|
||||
|
||||
if (m_epochSize == MinibatchSource::FullDataSweep)
|
||||
m_epochSize = Microsoft::MSR::CNTK::requestDataSize;
|
||||
// Setting big value, but not the max in order to aviod bit overflow.
|
||||
else if (m_epochSize == MinibatchSource::InfinitelyRepeat)
|
||||
m_epochSize = std::numeric_limits<size_t>::max() / 2;
|
||||
m_maxNumSamplesToRead = augmentedConfiguration[epochSizeConfigurationKey].Value<size_t>();
|
||||
|
||||
const wchar_t* randomizedWindowConfigurationKey = L"randomizationWindow";
|
||||
if (augmentedConfiguration.Contains(randomizedWindowConfigurationKey))
|
||||
|
@ -212,9 +206,24 @@ namespace CNTK
|
|||
epochConfig.m_workerRank = workerRank;
|
||||
epochConfig.m_minibatchSizeInSamples = minibatchSizeInSamples;
|
||||
epochConfig.m_truncationSize = m_truncationLength;
|
||||
epochConfig.m_allowMinibatchesToCrossSweepBoundaries = true;
|
||||
|
||||
if (m_maxNumSamplesToRead == MinibatchSource::FullDataSweep)
|
||||
{
|
||||
epochConfig.m_totalEpochSizeInSamples = Microsoft::MSR::CNTK::requestDataSize;
|
||||
}
|
||||
else if (m_maxNumSamplesToRead == MinibatchSource::InfinitelyRepeat)
|
||||
{
|
||||
// Setting big value, but not the max in order to aviod bit overflow.
|
||||
epochConfig.m_totalEpochSizeInSamples = std::numeric_limits<size_t>::max() / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
epochConfig.m_totalEpochSizeInSamples = m_maxNumSamplesToRead;
|
||||
}
|
||||
|
||||
epochConfig.m_totalEpochSizeInSamples = m_epochSize;
|
||||
epochConfig.m_epochIndex = 0;
|
||||
|
||||
m_matrices.clear();
|
||||
|
||||
std::unordered_set<InputStreamDescription> inputs;
|
||||
|
@ -257,6 +266,7 @@ namespace CNTK
|
|||
newConfig.m_workerRank = workerRank;
|
||||
newConfig.m_minibatchSizeInSamples = minibatchSizeInSamples;
|
||||
newConfig.m_truncationSize = m_truncationLength;
|
||||
newConfig.m_allowMinibatchesToCrossSweepBoundaries = true;
|
||||
|
||||
m_shim->SetConfiguration(newConfig, inputDescriptions);
|
||||
|
||||
|
@ -267,9 +277,12 @@ namespace CNTK
|
|||
|
||||
auto hasData = m_shim->GetMinibatch(m_matrices);
|
||||
m_epochEndReached = m_shim->IsEndOfEpoch();
|
||||
|
||||
if (m_epochEndReached && !hasData)
|
||||
return m_minibatchData;
|
||||
|
||||
bool hasReachedSweepEnd = m_shim->IsEndOfSweep();
|
||||
|
||||
for (const auto& s: m_streamInfos)
|
||||
{
|
||||
auto input = m_matrices.GetInput(s.m_name);
|
||||
|
@ -293,7 +306,7 @@ namespace CNTK
|
|||
size_t numSamples = input.pMBLayout->GetActualNumSamples();
|
||||
size_t numSequences = input.pMBLayout->GetNumSequences();
|
||||
|
||||
m_minibatchData[currentStreamInfo] = { numSequences, numSamples, minibatchValuePtr };
|
||||
m_minibatchData[currentStreamInfo] = { minibatchValuePtr, numSequences, numSamples, hasReachedSweepEnd };
|
||||
}
|
||||
else
|
||||
LogicError("Input data of type other than DataType::Float is currently unsupported by the CNTK built-in composite MinibatchSource!");
|
||||
|
|
|
@ -50,7 +50,7 @@ namespace CNTK
|
|||
size_t m_numWorkers;
|
||||
size_t m_workerRank;
|
||||
size_t m_prevMinibatchSize;
|
||||
size_t m_epochSize;
|
||||
size_t m_maxNumSamplesToRead;
|
||||
size_t m_randomizedWindow;
|
||||
size_t m_truncationLength;
|
||||
std::unordered_map<StreamInformation, MinibatchData> m_minibatchData;
|
||||
|
|
|
@ -116,6 +116,29 @@ namespace CNTK
|
|||
return (numSamplesInDataArrayView - numMaskedSamples);
|
||||
}
|
||||
|
||||
static std::unordered_map<Variable, ValuePtr> GetInputs(const std::unordered_map<Variable, MinibatchData>& arguments)
|
||||
{
|
||||
std::unordered_map<Variable, ValuePtr> inputs(arguments.size());
|
||||
for (const auto& kv : arguments)
|
||||
{
|
||||
inputs[kv.first] = kv.second.data;
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
static bool IsAtSweepEnd(const std::unordered_map<Variable, MinibatchData>& arguments)
|
||||
{
|
||||
return std::any_of(arguments.begin(), arguments.end(), [](const std::pair<const Variable, MinibatchData>& kv)
|
||||
{
|
||||
return kv.second.sweepEnd;
|
||||
});
|
||||
}
|
||||
|
||||
double Trainer::TestMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
return TestMinibatch(GetInputs(arguments), computeDevice);
|
||||
}
|
||||
|
||||
double Trainer::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
if (!m_aggregatedEvaluationFunction)
|
||||
|
@ -123,12 +146,26 @@ namespace CNTK
|
|||
|
||||
// TODO: Should we refactor this code that is somewhat similar to the prologue of the TrainMinibatch function
|
||||
std::unordered_map<Variable, ValuePtr> outputs = { { m_aggregatedEvaluationFunction, nullptr }, { m_testSampleCountVar, nullptr } };
|
||||
|
||||
m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice);
|
||||
|
||||
auto sampleCount = GetSampleCount(m_testSampleCountVar, outputs[m_testSampleCountVar]);
|
||||
return (GetScalarValue(outputs[m_aggregatedEvaluationFunction]) / sampleCount);
|
||||
}
|
||||
|
||||
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
std::unordered_map<Variable, ValuePtr> outputsToFetch = {};
|
||||
return TrainMinibatch(arguments, outputsToFetch, computeDevice);
|
||||
}
|
||||
|
||||
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
if (!m_distributed)
|
||||
return TrainLocalMinibatch(GetInputs(arguments), outputsToFetch, IsAtSweepEnd(arguments), computeDevice);
|
||||
return TrainDistributedMinibatch(GetInputs(arguments), outputsToFetch, IsAtSweepEnd(arguments), computeDevice);
|
||||
}
|
||||
|
||||
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
std::unordered_map<Variable, ValuePtr> outputsToFetch = {};
|
||||
|
@ -138,11 +175,11 @@ namespace CNTK
|
|||
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
if (!m_distributed)
|
||||
return TrainLocalMinibatch(arguments, outputsToFetch, computeDevice);
|
||||
return TrainDistributedMinibatch(arguments, outputsToFetch, computeDevice);
|
||||
return TrainLocalMinibatch(arguments, outputsToFetch, false, computeDevice);
|
||||
return TrainDistributedMinibatch(arguments, outputsToFetch, false, computeDevice);
|
||||
}
|
||||
|
||||
bool Trainer::TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
bool Trainer::TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
bool emptyMinibatch = arguments.empty() || (arguments.begin()->second == nullptr);
|
||||
if (emptyMinibatch) // Nothing to train with.
|
||||
|
@ -154,10 +191,10 @@ namespace CNTK
|
|||
std::unordered_map<Parameter, NDArrayViewPtr> gradients;
|
||||
for (const auto& parameter : m_combinedTrainingFunction->Parameters())
|
||||
gradients[parameter] = parameterGradients[parameter]->Data();
|
||||
return m_parameterLearners->Update(gradients, m_prevMinibatchNumSamples);
|
||||
return m_parameterLearners->Update(gradients, m_prevMinibatchNumSamples, sweepEnd);
|
||||
}
|
||||
|
||||
bool Trainer::TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
bool Trainer::TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
|
||||
{
|
||||
std::unordered_map<Parameter, NDArrayViewPtr> gradients;
|
||||
auto modelParameters = m_combinedTrainingFunction->Parameters();
|
||||
|
@ -184,7 +221,7 @@ namespace CNTK
|
|||
evalCriterion = m_prevMinibatchAggregateEvalCriterionValue->Data();
|
||||
}
|
||||
|
||||
MinibatchInfo info { arguments.empty(), m_prevMinibatchNumSamples, trainingLoss, evalCriterion };
|
||||
MinibatchInfo info{ arguments.empty(), sweepEnd, m_prevMinibatchNumSamples, trainingLoss, evalCriterion };
|
||||
bool updated = m_parameterLearners->Update(gradients, info);
|
||||
m_prevMinibatchNumSamples = info.numberOfSamples;
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ namespace CNTK
|
|||
if (!minibatchData.empty())
|
||||
{
|
||||
for (auto v : m_modelInputToMinibatchSourceStream)
|
||||
minibatch.insert({ v.first, minibatchData[v.second].m_data });
|
||||
minibatch.insert({ v.first, minibatchData[v.second].data });
|
||||
}
|
||||
|
||||
OnMinibatchStart();
|
||||
|
|
|
@ -241,7 +241,7 @@ namespace CNTK
|
|||
|
||||
template <typename T>
|
||||
TrainingParameterSchedule<T>::TrainingParameterSchedule(T value, UnitType unit)
|
||||
: m_schedule({ make_pair(0, value) }), m_unit(unit), m_epochSize(EntireSweep)
|
||||
: m_schedule({ make_pair(0, value) }), m_unit(unit), m_epochSize(FullDataSweep)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -268,13 +268,9 @@ namespace CNTK
|
|||
template <typename T>
|
||||
void TrainingParameterSchedule<T>::ConstructSchedule(const std::vector<std::pair<size_t, T>>& schedule)
|
||||
{
|
||||
if (m_epochSize == EntireSweep)
|
||||
{
|
||||
//Sweep based schedules are currently not functional (learners don't have sweep info).
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
const auto epochSize = (m_epochSize == EntireSweep) ? 1 : m_epochSize;
|
||||
// In case of the FullDataSweep, the scheduling unit is just 1 sweep,
|
||||
// otherwise, it's the epoch size in samples.
|
||||
const auto unitSize = (m_epochSize == FullDataSweep) ? 1 : m_epochSize;
|
||||
|
||||
if (schedule.size() == 0)
|
||||
RuntimeError("TrainingParameterSchedule::ConstructSchedule : schedule is empty.");
|
||||
|
@ -288,7 +284,7 @@ namespace CNTK
|
|||
RuntimeError("TrainingParameterSchedule::ConstructSchedule : unit count in the 'schedule' argument cannot be 0.");
|
||||
|
||||
unitCount += (pair.first != 0) ? pair.first : 1;
|
||||
m_schedule[epochSize * unitCount] = pair.second;
|
||||
m_schedule[unitSize * unitCount] = pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -880,14 +876,14 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
bool Learners::Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t sampleInMinibatch)
|
||||
bool Learners::Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t sampleInMinibatch, bool sweepEnd)
|
||||
{
|
||||
bool anyUpdatesPerformed = false;
|
||||
for (auto learner : m_learners)
|
||||
{
|
||||
std::unordered_map<Parameter, NDArrayViewPtr> learnerGradients;
|
||||
GetLearnerGradients(learner, gradientValues, learnerGradients);
|
||||
anyUpdatesPerformed |= learner->Update(learnerGradients, sampleInMinibatch);
|
||||
anyUpdatesPerformed |= learner->Update(learnerGradients, sampleInMinibatch, sweepEnd);
|
||||
}
|
||||
return anyUpdatesPerformed;
|
||||
}
|
||||
|
|
|
@ -501,7 +501,7 @@ namespace CNTK
|
|||
public:
|
||||
explicit Learners(const std::vector<LearnerPtr>& learners);
|
||||
|
||||
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount);
|
||||
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd);
|
||||
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, MinibatchInfo& minibatchInfo);
|
||||
|
||||
std::vector<DictionaryValue> CreateCheckpoint();
|
||||
|
|
|
@ -29,7 +29,7 @@ BlockRandomizer::BlockRandomizer(
|
|||
m_epochSize(SIZE_MAX),
|
||||
m_globalSamplePosition(SIZE_MAX),
|
||||
m_epochStartPosition(0),
|
||||
m_sweepTotalNumberOfSamples(0),
|
||||
m_sweepSizeInSamples(0),
|
||||
m_chunkRandomizer(std::make_shared<ChunkRandomizer>(deserializer, randomizationRangeInSamples)),
|
||||
m_multithreadedGetNextSequences(multithreadedGetNextSequence),
|
||||
m_prefetchedChunk(CHUNKID_MAX),
|
||||
|
@ -43,10 +43,10 @@ BlockRandomizer::BlockRandomizer(
|
|||
m_sequenceRandomizer = std::make_shared<SequenceRandomizer>(verbosity, m_deserializer, m_chunkRandomizer);
|
||||
|
||||
// Calculate total number of samples.
|
||||
m_sweepTotalNumberOfSamples = 0;
|
||||
m_sweepSizeInSamples = 0;
|
||||
for (auto const & chunk : m_deserializer->GetChunkDescriptions())
|
||||
{
|
||||
m_sweepTotalNumberOfSamples += chunk->m_numberOfSamples;
|
||||
m_sweepSizeInSamples += chunk->m_numberOfSamples;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -63,7 +63,7 @@ void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
m_config = config;
|
||||
if (config.m_totalEpochSizeInSamples == requestDataSize)
|
||||
{
|
||||
m_epochSize = m_sweepTotalNumberOfSamples;
|
||||
m_epochSize = m_sweepSizeInSamples;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -92,7 +92,7 @@ void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
// Prepares a new sweep if needed.
|
||||
void BlockRandomizer::PrepareNewSweepIfNeeded(size_t samplePosition)
|
||||
{
|
||||
size_t sweep = samplePosition / m_sweepTotalNumberOfSamples;
|
||||
size_t sweep = samplePosition / m_sweepSizeInSamples;
|
||||
if (m_sweep != sweep)
|
||||
{
|
||||
if (m_verbosity >= Notification)
|
||||
|
@ -115,32 +115,94 @@ Sequences BlockRandomizer::GetNextSequences(size_t globalSampleCount, size_t loc
|
|||
{
|
||||
// Get next sequence descriptions.
|
||||
Sequences result;
|
||||
ClosedOpenChunkInterval windowRange;
|
||||
m_sequenceBuffer.clear();
|
||||
result.m_endOfEpoch = GetNextSequenceDescriptions(globalSampleCount, localSampleCount, m_sequenceBuffer, windowRange);
|
||||
if (m_sequenceBuffer.size() == 0)
|
||||
size_t numGlobalSamplesLoaded = 0, numLocalSamplesLoaded = 0;
|
||||
do
|
||||
{
|
||||
return result;
|
||||
assert(globalSampleCount > numGlobalSamplesLoaded && localSampleCount > numLocalSamplesLoaded);
|
||||
bool atTheSweepBoundary = result.m_endOfSweep;
|
||||
// in case when we continue filling up a minibatch that crosses a sweep boundary,
|
||||
// make sure that it does not exceed the required number of samples. Set the atLeastOnceSequenceNeeded
|
||||
// flag to false.
|
||||
size_t numGlobalSamples = 0, numLocalSamples = 0;
|
||||
std::tie(numGlobalSamples, numLocalSamples) =
|
||||
LoadSequenceData(globalSampleCount - numGlobalSamplesLoaded,
|
||||
localSampleCount - numLocalSamplesLoaded,
|
||||
result, !atTheSweepBoundary);
|
||||
|
||||
if (atTheSweepBoundary && numGlobalSamples == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
numGlobalSamplesLoaded += numGlobalSamples;
|
||||
numLocalSamplesLoaded += numLocalSamples;
|
||||
|
||||
} while (m_config.m_allowMinibatchesToCrossSweepBoundaries &&
|
||||
!result.m_endOfEpoch &&
|
||||
result.m_endOfSweep &&
|
||||
globalSampleCount > numGlobalSamplesLoaded &&
|
||||
localSampleCount > numLocalSamplesLoaded);
|
||||
|
||||
m_cleaner.Clean(result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> BlockRandomizer::LoadSequenceData(size_t globalSampleCount, size_t localSampleCount, Sequences& sequences, bool atLeastOneSequenceNeeded)
|
||||
{
|
||||
ClosedOpenChunkInterval windowRange;
|
||||
m_sequenceBuffer.clear();
|
||||
size_t numGlobalSamples = 0, numLocalSamples = 0; // actual number of samples to load (filled in from the sequence descriptions)
|
||||
bool endOfSweep, endOfEpoch;
|
||||
std::tie(endOfSweep, endOfEpoch, numGlobalSamples, numLocalSamples) = GetNextSequenceDescriptions(globalSampleCount, localSampleCount, m_sequenceBuffer, windowRange, atLeastOneSequenceNeeded);
|
||||
sequences.m_endOfSweep |= endOfSweep;
|
||||
sequences.m_endOfEpoch |= endOfEpoch;
|
||||
|
||||
assert(atLeastOneSequenceNeeded || (numGlobalSamples <= globalSampleCount && numLocalSamples <= localSampleCount));
|
||||
|
||||
if (numGlobalSamples == 0)
|
||||
{
|
||||
assert(!atLeastOneSequenceNeeded || sequences.m_endOfEpoch);
|
||||
return {0, 0};
|
||||
}
|
||||
|
||||
// Retrieve new data chunks if required.
|
||||
LoadDataChunks(windowRange);
|
||||
|
||||
result.m_data.resize(m_streams.size(), std::vector<SequenceDataPtr>(m_sequenceBuffer.size()));
|
||||
auto& data = sequences.m_data;
|
||||
size_t offset = 0;
|
||||
|
||||
if (data.empty())
|
||||
{
|
||||
data.resize(m_streams.size(), std::vector<SequenceDataPtr>(m_sequenceBuffer.size()));
|
||||
}
|
||||
else
|
||||
{
|
||||
// sequence data is not empty, we're appending new items to exiting
|
||||
// sequence data vectors.
|
||||
offset = data.front().size();
|
||||
for (auto& sequenceDataVector : data)
|
||||
{
|
||||
// make sure that all streams contain the same number of sequences
|
||||
assert(sequenceDataVector.size() == offset);
|
||||
sequenceDataVector.resize(offset + m_sequenceBuffer.size());
|
||||
}
|
||||
}
|
||||
|
||||
auto process = [&](int i) -> void {
|
||||
const auto& description = m_sequenceBuffer[i];
|
||||
std::vector<SequenceDataPtr> sequence;
|
||||
std::vector<SequenceDataPtr> sequenceData;
|
||||
auto it = m_chunks.find(description.m_chunk->m_original->m_id);
|
||||
if (it == m_chunks.end())
|
||||
{
|
||||
LogicError("Invalid chunk requested.");
|
||||
}
|
||||
|
||||
it->second->GetSequence(description.m_id, sequence);
|
||||
it->second->GetSequence(description.m_id, sequenceData);
|
||||
for (int j = 0; j < m_streams.size(); ++j)
|
||||
{
|
||||
result.m_data[j][i] = sequence[j];
|
||||
assert(offset + i < data[j].size());
|
||||
data[j][offset + i] = sequenceData[j];
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -158,18 +220,16 @@ Sequences BlockRandomizer::GetNextSequences(size_t globalSampleCount, size_t loc
|
|||
process(i);
|
||||
}
|
||||
|
||||
m_cleaner.Clean(result);
|
||||
|
||||
// Now it is safe to start the new chunk prefetch.
|
||||
ChunkIdType chunkToPrefetchNext = GetChunkToPrefetch(windowRange);
|
||||
Prefetch(chunkToPrefetchNext);
|
||||
|
||||
return result;
|
||||
return { numGlobalSamples, numLocalSamples };
|
||||
}
|
||||
|
||||
// Get next sequence descriptions for that worker that do not exceed global and local sample count.
|
||||
// Returns true if epoch end is reached.
|
||||
bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange)
|
||||
std::tuple<bool, bool, size_t, size_t> BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange, bool atLeastOneSequenceNeeded)
|
||||
{
|
||||
if (globalSampleCount == 0)
|
||||
LogicError("Global sample count must not be zero.");
|
||||
|
@ -179,17 +239,22 @@ bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size
|
|||
|
||||
PrepareNewSweepIfNeeded(m_globalSamplePosition);
|
||||
|
||||
auto sweepPosition = m_globalSamplePosition % m_sweepSizeInSamples;
|
||||
auto epochEndPosition = m_epochSize + m_epochStartPosition;
|
||||
|
||||
// Check epoch end.
|
||||
if (m_globalSamplePosition >= m_epochSize + m_epochStartPosition)
|
||||
if (m_globalSamplePosition >= epochEndPosition)
|
||||
{
|
||||
return true;
|
||||
auto reachedEndOfEpoch = true;
|
||||
auto reachedEndOfSweep = (m_globalSamplePosition >= m_sweepSizeInSamples) && (sweepPosition == 0);
|
||||
return std::make_tuple(reachedEndOfSweep, reachedEndOfEpoch, 0, 0);
|
||||
}
|
||||
|
||||
// Global sample count should not exceed the epoch.
|
||||
globalSampleCount = std::min(globalSampleCount, m_epochSize + m_epochStartPosition - m_globalSamplePosition);
|
||||
globalSampleCount = std::min(globalSampleCount, epochEndPosition - m_globalSamplePosition);
|
||||
|
||||
// Global sample count should also not exceed the sweep.
|
||||
globalSampleCount = std::min(globalSampleCount, (long)m_sweepTotalNumberOfSamples - m_globalSamplePosition % m_sweepTotalNumberOfSamples);
|
||||
globalSampleCount = std::min(globalSampleCount, m_sweepSizeInSamples - sweepPosition);
|
||||
|
||||
if (globalSampleCount == 0)
|
||||
LogicError("Global sample count must not result in zero.");
|
||||
|
@ -197,12 +262,17 @@ bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size
|
|||
std::function<bool(const RandomizedSequenceDescription*)> isLocalSequence =
|
||||
[this](const RandomizedSequenceDescription* s) { return s->m_chunk->m_chunkId % m_config.m_numberOfWorkers == m_config.m_workerRank; };
|
||||
|
||||
size_t actualNumberOfGlobalSamples = m_sequenceRandomizer->GetNextSequenceDescriptions(
|
||||
size_t actualNumberOfGlobalSamples = 0, actualNumberOfLocalSamples = 0;
|
||||
std::tie(actualNumberOfGlobalSamples, actualNumberOfLocalSamples) = m_sequenceRandomizer->GetNextSequenceDescriptions(
|
||||
globalSampleCount,
|
||||
localSampleCount,
|
||||
isLocalSequence,
|
||||
windowRange,
|
||||
result);
|
||||
result,
|
||||
atLeastOneSequenceNeeded);
|
||||
|
||||
if (actualNumberOfLocalSamples > actualNumberOfGlobalSamples)
|
||||
LogicError("Local sample count cannot be greater than the global sample count.");
|
||||
|
||||
if (m_verbosity >= Debug)
|
||||
fprintf(stderr, "BlockRandomizer::GetNextSequenceDescriptions(): getting %" PRIu64 " sequences for %" PRIu64 "/%" PRIu64 " requested local/global samples in sweep %" PRIu64 "\n",
|
||||
|
@ -211,10 +281,15 @@ bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size
|
|||
globalSampleCount,
|
||||
m_sweep);
|
||||
|
||||
// set "reachedEndOfSweep" to true if the minibatch is last in a sweep
|
||||
auto reachedEndOfSweep = (sweepPosition + actualNumberOfGlobalSamples >= m_sweepSizeInSamples);
|
||||
// set "reachedEndOfEpoch" to true if the current batch is last in an epoch.
|
||||
auto reachedEndOfEpoch = (m_globalSamplePosition + actualNumberOfGlobalSamples >= epochEndPosition);
|
||||
|
||||
// Update the global sample position.
|
||||
m_globalSamplePosition += actualNumberOfGlobalSamples;
|
||||
|
||||
// return true if the current batch is last in an epoch.
|
||||
return m_globalSamplePosition >= m_epochSize + m_epochStartPosition;
|
||||
return std::make_tuple(reachedEndOfSweep, reachedEndOfEpoch, actualNumberOfGlobalSamples, actualNumberOfLocalSamples);
|
||||
}
|
||||
|
||||
// Retrieves chunk data based on the window information provided by SequenceRandomizer
|
||||
|
@ -350,9 +425,9 @@ void BlockRandomizer::SetCurrentSamplePosition(size_t currentSamplePosition)
|
|||
|
||||
// Sets sequence cursor to the sequence that corresponds to the epoch start position.
|
||||
// If last epoch ended in the middle of a sequence, the cursor is moved to the next sequence in the sweep.
|
||||
size_t offsetInSweep = currentSamplePosition % m_sweepTotalNumberOfSamples;
|
||||
size_t offsetInSweep = currentSamplePosition % m_sweepSizeInSamples;
|
||||
size_t newOffset = m_sequenceRandomizer->Seek(offsetInSweep, m_sweep);
|
||||
m_globalSamplePosition = m_sweep * m_sweepTotalNumberOfSamples + newOffset;
|
||||
m_globalSamplePosition = m_sweep * m_sweepSizeInSamples + newOffset;
|
||||
|
||||
// Check if we have some data, if not set to the end of epoch.
|
||||
if (m_config.m_workerRank >= m_chunkRandomizer->GetRandomizedChunks().size())
|
||||
|
|
|
@ -77,8 +77,21 @@ private:
|
|||
// Load data for chunks if needed.
|
||||
void LoadDataChunks(const ClosedOpenChunkInterval& windowRange);
|
||||
|
||||
// Get next sequence descriptions that do not exceed global and local sample count.
|
||||
bool GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange);
|
||||
// Load actual sequence data up to the specified global/local sample count
|
||||
// (or at least one sequence when atLeastOneSequenceNeeded is true),
|
||||
// Returns the total number of global and local samples loaded.
|
||||
std::pair<size_t, size_t> LoadSequenceData(size_t globalSampleCount, size_t localSampleCount, Sequences& sequence, bool atLeastOneSequenceNeeded);
|
||||
|
||||
// Gets the next sequence descriptions with the total number of samples not exceeding
|
||||
// the sample count, when atLeastOneSequenceNeeded is false. Otherwise (when atLeastOneSequenceNeeded is true),
|
||||
// returns at least one sequence description even when its length is greater than the required sample count.
|
||||
// Returns a tuple containing "end of sweep", "end of epoch" flags and
|
||||
// the total numbers of global and local samples to be processed.
|
||||
std::tuple<bool, bool, size_t, size_t> GetNextSequenceDescriptions(size_t globalSampleCount,
|
||||
size_t localSampleCount,
|
||||
std::vector<RandomizedSequenceDescription>& result,
|
||||
ClosedOpenChunkInterval& windowRange,
|
||||
bool atLeastOneSequenceNeeded);
|
||||
|
||||
// Prepares a new sweep if needed.
|
||||
void PrepareNewSweepIfNeeded(size_t samplePosition);
|
||||
|
@ -105,7 +118,7 @@ private:
|
|||
size_t m_sweep;
|
||||
|
||||
// Total number of samples in a sweep.
|
||||
size_t m_sweepTotalNumberOfSamples;
|
||||
size_t m_sweepSizeInSamples;
|
||||
|
||||
IDataDeserializerPtr m_deserializer;
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_currentChunkPosition(CHUNKID_MAX),
|
||||
m_globalSamplePosition(0),
|
||||
m_globalSequencePosition(0),
|
||||
m_totalNumberOfSamples(0),
|
||||
m_sweepSizeInSamples(0),
|
||||
m_currentSequencePositionInChunk(0),
|
||||
m_multithreadedGetNextSequences(multithreadedGetNextSequences),
|
||||
m_cleaner(maxNumberOfInvalidSequences)
|
||||
|
@ -41,7 +41,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
RuntimeError("NoRandomizer: Expected input to contain samples, but the number of successfully read samples was 0.");
|
||||
}
|
||||
|
||||
m_totalNumberOfSamples = sampleCount;
|
||||
m_sweepSizeInSamples = sampleCount;
|
||||
}
|
||||
|
||||
ChunkIdType NoRandomizer::GetChunkIndexOf(size_t samplePosition)
|
||||
|
@ -55,7 +55,7 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
m_config = config;
|
||||
|
||||
if (m_config.m_totalEpochSizeInSamples == requestDataSize)
|
||||
m_config.m_totalEpochSizeInSamples = m_totalNumberOfSamples;
|
||||
m_config.m_totalEpochSizeInSamples = m_sweepSizeInSamples;
|
||||
|
||||
SetCurrentSamplePosition(m_config.m_totalEpochSizeInSamples * config.m_epochIndex);
|
||||
}
|
||||
|
@ -83,39 +83,38 @@ void NoRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t
|
|||
assert(globalSampleCount != 0);
|
||||
assert(localSampleCount != 0);
|
||||
|
||||
if (globalSampleCount > std::numeric_limits<int>::max() ||
|
||||
if (globalSampleCount > std::numeric_limits<int>::max() &&
|
||||
localSampleCount > std::numeric_limits<int>::max())
|
||||
RuntimeError("Global and local size of the minibatch cannot exceed max int.");
|
||||
|
||||
assert(m_sequenceWindow.size() != 0);
|
||||
assert(m_chunkDescriptions[m_currentChunkPosition]->m_numberOfSequences > m_currentSequencePositionInChunk);
|
||||
|
||||
int localSamplesLeft = (int)localSampleCount;
|
||||
int globalSamplesLeft = (int)globalSampleCount;
|
||||
size_t numGlobalSamplesLoaded = 0, numLocalSamplesLoaded = 0;
|
||||
|
||||
result.reserve(localSampleCount);
|
||||
result.clear();
|
||||
|
||||
while (globalSamplesLeft > 0 && localSamplesLeft > 0)
|
||||
while (globalSampleCount > numGlobalSamplesLoaded && localSampleCount > numLocalSamplesLoaded)
|
||||
{
|
||||
const SequenceDescription& sequence = m_sequenceWindow[m_currentSequencePositionInChunk];
|
||||
int sequenceLength = (int)sequence.m_numberOfSamples;
|
||||
auto sequenceLength = sequence.m_numberOfSamples;
|
||||
|
||||
// Let's check whether we need to return this sequence or skip it.
|
||||
bool isLocal = m_globalSequencePosition % m_config.m_numberOfWorkers == m_config.m_workerRank;
|
||||
if (result.empty() ||
|
||||
((localSamplesLeft >= sequenceLength) && (globalSamplesLeft >= sequenceLength)))
|
||||
((localSampleCount - numLocalSamplesLoaded >= sequenceLength) && (globalSampleCount - numGlobalSamplesLoaded >= sequenceLength)))
|
||||
{
|
||||
if (isLocal) // Ok good to add it to the result.
|
||||
{
|
||||
result.push_back(sequence);
|
||||
localSamplesLeft -= sequence.m_numberOfSamples;
|
||||
numLocalSamplesLoaded += sequence.m_numberOfSamples;
|
||||
}
|
||||
}
|
||||
else // otherwise there is no room, return what we have.
|
||||
break;
|
||||
|
||||
globalSamplesLeft -= sequence.m_numberOfSamples;
|
||||
numGlobalSamplesLoaded += sequence.m_numberOfSamples;
|
||||
m_globalSamplePosition += sequence.m_numberOfSamples;
|
||||
m_globalSequencePosition++;
|
||||
|
||||
|
@ -141,25 +140,35 @@ Sequences NoRandomizer::GetNextSequences(size_t globalSampleCount, size_t localS
|
|||
if (m_globalSamplePosition >= endOfEpochPosition)
|
||||
{
|
||||
result.m_endOfEpoch = true;
|
||||
result.m_endOfSweep = (m_globalSamplePosition >= m_sweepSizeInSamples) &&
|
||||
(m_globalSamplePosition % m_sweepSizeInSamples == 0);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Check we do not go over epoch.
|
||||
globalSampleCount = std::min(globalSampleCount, endOfEpochPosition - m_globalSamplePosition);
|
||||
|
||||
// Check that we do not go over the sweep.
|
||||
size_t sweepPosition = m_globalSamplePosition % m_totalNumberOfSamples;
|
||||
globalSampleCount = std::min(globalSampleCount, m_totalNumberOfSamples - sweepPosition);
|
||||
if (!m_config.m_allowMinibatchesToCrossSweepBoundaries)
|
||||
{
|
||||
// Cut down the required sample count if we're not allowed to go over the
|
||||
// sweep boundary
|
||||
size_t sweepPosition = m_globalSamplePosition % m_sweepSizeInSamples;
|
||||
globalSampleCount = std::min(globalSampleCount, m_sweepSizeInSamples - sweepPosition);
|
||||
}
|
||||
|
||||
if (globalSampleCount == 0)
|
||||
LogicError("Global sample count must not result in zero.");
|
||||
|
||||
auto sweepIndex = m_globalSamplePosition / m_sweepSizeInSamples;
|
||||
|
||||
m_sequenceBuffer.clear();
|
||||
GetNextSequenceDescriptions(globalSampleCount, localSampleCount, m_sequenceBuffer);
|
||||
|
||||
// m_globalSamplePosition is already shifted in GetNextSequenceDescriptions() by the current minibatch size.
|
||||
// Set the end-of-epoch flag (true when the current batch is last in an epoch).
|
||||
result.m_endOfEpoch = (m_globalSamplePosition >= endOfEpochPosition);
|
||||
result.m_endOfSweep = sweepIndex != m_globalSamplePosition / m_sweepSizeInSamples;
|
||||
|
||||
if (m_sequenceBuffer.size() == 0)
|
||||
{
|
||||
return result;
|
||||
|
@ -229,7 +238,7 @@ void NoRandomizer::SetCurrentSamplePosition(size_t samplePosition)
|
|||
{
|
||||
m_currentSequencePositionInChunk = 0;
|
||||
m_globalSamplePosition = samplePosition;
|
||||
size_t sweepSamplePosition = m_globalSamplePosition % m_totalNumberOfSamples;
|
||||
size_t sweepSamplePosition = m_globalSamplePosition % m_sweepSizeInSamples;
|
||||
|
||||
ChunkIdType chunkIndex = GetChunkIndexOf(sweepSamplePosition);
|
||||
if (chunkIndex != m_currentChunkPosition)
|
||||
|
|
|
@ -88,7 +88,7 @@ private:
|
|||
size_t m_globalSequencePosition;
|
||||
|
||||
// Total number of samples in the sweep.
|
||||
size_t m_totalNumberOfSamples;
|
||||
size_t m_sweepSizeInSamples;
|
||||
|
||||
// Temp buffer to avoid allocations.
|
||||
std::vector<SequenceDescription> m_sequenceBuffer;
|
||||
|
|
|
@ -32,6 +32,11 @@ struct ReaderConfiguration
|
|||
size_t m_workerRank; // Rank of the Open MPI worker, worker rank has to be less than the number of workers
|
||||
size_t m_minibatchSizeInSamples; // Maximum minibatch size for the epoch in samples
|
||||
size_t m_truncationSize; // Truncation size in samples for truncated BPTT mode.
|
||||
|
||||
// This flag indicates whether the minibatches are allowed to overlap the boundary
|
||||
// between sweeps (in which case, they can contain data from different sweeps) or
|
||||
// if they need to be trimmed at the sweep end.
|
||||
bool m_allowMinibatchesToCrossSweepBoundaries{ false };
|
||||
};
|
||||
|
||||
// TODO: Should be deprecated.
|
||||
|
@ -87,6 +92,10 @@ typedef std::shared_ptr<StreamMinibatch> StreamMinibatchPtr;
|
|||
// Represents a single minibatch, that contains information about all streams.
|
||||
struct Minibatch
|
||||
{
|
||||
// Indicates that this minibatch is either adjacent to the data sweep boundary
|
||||
// (-----<minibatch>|---) or crosses the boundary (-----<mini|batch>---).
|
||||
bool m_endOfSweep;
|
||||
|
||||
// Indicates that the end of epoch has been reached.
|
||||
// It is set to true for the last minibatch, there still
|
||||
// can be data in m_data field even if this flag is set.
|
||||
|
@ -95,11 +104,8 @@ struct Minibatch
|
|||
// Minibatch data
|
||||
std::vector<StreamMinibatchPtr> m_data;
|
||||
|
||||
Minibatch() : m_endOfEpoch(false)
|
||||
{
|
||||
}
|
||||
|
||||
Minibatch(bool endOfEpoch) : m_endOfEpoch(endOfEpoch)
|
||||
Minibatch(bool endOfSweep = false, bool endOfEpoch = false)
|
||||
: m_endOfSweep(endOfSweep), m_endOfEpoch(endOfEpoch)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
|
|
@ -28,6 +28,7 @@ ReaderShim<ElemType>::ReaderShim() :
|
|||
m_dataTransferers(2, DataTransfererPtr()),
|
||||
m_currentDataTransferIndex(0),
|
||||
m_endOfEpoch(false),
|
||||
m_endOfSweep(false),
|
||||
m_currentSamplePosition(0),
|
||||
m_reader(nullptr),
|
||||
m_factory(nullptr)
|
||||
|
@ -274,6 +275,7 @@ bool ReaderShim<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
m_currentSamplePosition = m_reader->GetCurrentSamplePosition();
|
||||
|
||||
m_endOfEpoch = result.m_isEndOfEpoch;
|
||||
m_endOfSweep = result.m_isEndOfSweep;
|
||||
if (m_endOfEpoch && !result.m_isDataAvailable)
|
||||
{
|
||||
// No data and end of epoch, simply return.
|
||||
|
@ -357,7 +359,7 @@ typename ReaderShim<ElemType>::PrefetchResult ReaderShim<ElemType>::PrefetchMini
|
|||
|
||||
// If there is no data we can simply return.
|
||||
if (minibatch.m_data.empty())
|
||||
return PrefetchResult{ minibatch.m_endOfEpoch, false };
|
||||
return PrefetchResult{ minibatch.m_endOfSweep, minibatch.m_endOfEpoch, false };
|
||||
|
||||
// Ok we have some data. Let's load it to GPU.
|
||||
// But before we need to make sure that corresponding compute has already finished from the last iteration.
|
||||
|
@ -380,7 +382,7 @@ typename ReaderShim<ElemType>::PrefetchResult ReaderShim<ElemType>::PrefetchMini
|
|||
if (m_dataTransferers[currentDataTransferIndex])
|
||||
m_dataTransferers[currentDataTransferIndex]->RecordCPUToGPUCopy();
|
||||
|
||||
return PrefetchResult{ minibatch.m_endOfEpoch, true };
|
||||
return PrefetchResult{ minibatch.m_endOfSweep, minibatch.m_endOfEpoch, true };
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -100,9 +100,15 @@ public:
|
|||
return m_endOfEpoch;
|
||||
}
|
||||
|
||||
bool IsEndOfSweep() const
|
||||
{
|
||||
return m_endOfSweep;
|
||||
}
|
||||
|
||||
private:
|
||||
struct PrefetchResult
|
||||
{
|
||||
bool m_isEndOfSweep;
|
||||
bool m_isEndOfEpoch;
|
||||
bool m_isDataAvailable;
|
||||
};
|
||||
|
@ -113,6 +119,7 @@ private:
|
|||
ReaderPtr m_reader;
|
||||
ReaderFactory m_factory;
|
||||
bool m_endOfEpoch;
|
||||
bool m_endOfSweep;
|
||||
|
||||
size_t m_numParallelSequences;
|
||||
|
||||
|
|
|
@ -20,8 +20,13 @@ struct Sequences
|
|||
// Indices in the outer vector have to correspond to the stream ids returned from the GetStreamDescriptions().
|
||||
std::vector<std::vector<SequenceDataPtr>> m_data;
|
||||
|
||||
// Indicates whether the returned data comes from a sweep end or
|
||||
// crosses a sweep boundary (and as a result includes sequences
|
||||
// from different sweeps).
|
||||
bool m_endOfSweep{ false };
|
||||
|
||||
// Indicates whether the epoch ends with the data returned.
|
||||
bool m_endOfEpoch = false;
|
||||
bool m_endOfEpoch{ false };
|
||||
};
|
||||
|
||||
class SequenceEnumerator;
|
||||
|
|
|
@ -41,7 +41,7 @@ Minibatch SequencePacker::ReadMinibatch()
|
|||
auto sequences = m_sequenceEnumerator->GetNextSequences(m_globalMinibatchSizeInSamples, m_localMinibatchSizeInSamples);
|
||||
const auto& batch = sequences.m_data;
|
||||
|
||||
Minibatch minibatch(sequences.m_endOfEpoch);
|
||||
Minibatch minibatch(sequences.m_endOfSweep, sequences.m_endOfEpoch);
|
||||
if (batch.empty())
|
||||
return minibatch;
|
||||
|
||||
|
|
|
@ -63,26 +63,28 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
RandomizeNextChunkIfNeeded();
|
||||
}
|
||||
|
||||
// Gets the next randomized sequence descriptions not exceeding the global and local sample count.
|
||||
// Whether the sequence is considered local is defined by the isLocalSequence predicate.
|
||||
// Returns how many global samples have been read.
|
||||
size_t SequenceRandomizer::GetNextSequenceDescriptions(
|
||||
// Gets the next randomized sequence descriptions not exceeding the global and local sample count,
|
||||
// when atLeastOneSequenceNeeded is false. Otherwise (when atLeastOneSequenceNeeded is true),
|
||||
// returns at least one sequence description even when its length is greater than the required sample counts.
|
||||
// Whether a sequence is considered local is defined by the isLocalSequence predicate.
|
||||
// Returns a pair whose first element indicates the number of global samples read,
|
||||
// and second -- the number of local samples read (== sum of number of sample over all elements in the
|
||||
// 'sequences' vector).
|
||||
std::pair<size_t, size_t> SequenceRandomizer::GetNextSequenceDescriptions(
|
||||
size_t globalSampleCount,
|
||||
size_t localSampleCount,
|
||||
const std::function<bool(const RandomizedSequenceDescription*)>& isLocalSequence,
|
||||
ClosedOpenChunkInterval& requiredChunks,
|
||||
std::vector<RandomizedSequenceDescription>& sequences)
|
||||
std::vector<RandomizedSequenceDescription>& sequences,
|
||||
bool atLeastOneSequenceNeeded)
|
||||
{
|
||||
assert(globalSampleCount != 0);
|
||||
assert(localSampleCount != 0);
|
||||
|
||||
if (globalSampleCount > std::numeric_limits<int>::max() ||
|
||||
if (globalSampleCount > std::numeric_limits<int>::max() &&
|
||||
localSampleCount > std::numeric_limits<int>::max())
|
||||
RuntimeError("Global and local size of the minibatch cannot exceed max int.");
|
||||
|
||||
int localSamplesLeft = (int)localSampleCount;
|
||||
int globalSamplesLeft = (int)globalSampleCount;
|
||||
|
||||
// Initialize the range to the current chunk.
|
||||
requiredChunks.m_begin = (ChunkIdType)std::min(m_currentChunkCursor, m_randomizedChunks.size() - 1);
|
||||
requiredChunks.m_end = requiredChunks.m_begin + 1;
|
||||
|
@ -90,9 +92,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
sequences.reserve(localSampleCount);
|
||||
sequences.clear();
|
||||
|
||||
size_t totalSamplesRead = 0;
|
||||
size_t globalSamplesRead = 0, localSamplesRead = 0;
|
||||
while (m_currentChunkCursor < m_randomizedChunks.size() &&
|
||||
globalSamplesLeft > 0 && localSamplesLeft > 0)
|
||||
(localSamplesRead < localSampleCount && globalSamplesRead < globalSampleCount))
|
||||
{
|
||||
size_t sequenceOffsetInsideChunk = m_currentSequenceCursor - m_randomizedChunks[m_currentChunkCursor].m_sequencePositionStart;
|
||||
const RandomizedSequenceDescription* sequence = &m_sequenceWindow[m_currentChunkCursor - m_chunkWindowBegin][sequenceOffsetInsideChunk];
|
||||
|
@ -100,20 +102,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
bool isLocal = isLocalSequence(sequence);
|
||||
|
||||
// Let's check whether we need to return this sequence or skip it.
|
||||
if (sequences.empty() ||
|
||||
((localSamplesLeft >= sequenceLength) && (globalSamplesLeft >= sequenceLength)))
|
||||
if ((sequences.empty() && atLeastOneSequenceNeeded) ||
|
||||
((localSamplesRead + sequenceLength <= localSampleCount) && (globalSamplesRead + sequenceLength <= globalSampleCount)))
|
||||
{
|
||||
if (isLocal) // Ok good to add it to the result.
|
||||
{
|
||||
sequences.push_back(*sequence);
|
||||
localSamplesLeft -= sequenceLength;
|
||||
localSamplesRead += sequenceLength;
|
||||
}
|
||||
// even when the next sequence is not local, somebody else would return it, so
|
||||
// we need to ivalidate the 'atLeastOneSequenceNeeded' flag.
|
||||
atLeastOneSequenceNeeded = false;
|
||||
}
|
||||
else // otherwise there is no room, return what we have.
|
||||
break;
|
||||
|
||||
totalSamplesRead += sequenceLength;
|
||||
globalSamplesLeft -= sequenceLength;
|
||||
globalSamplesRead += sequenceLength;
|
||||
|
||||
// Update the required chunk window.
|
||||
requiredChunks.m_begin = std::min(m_randomizedChunks[m_currentChunkCursor].m_randomizationWindow.m_begin, requiredChunks.m_begin);
|
||||
|
@ -130,7 +134,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
}
|
||||
|
||||
return totalSamplesRead;
|
||||
return { globalSamplesRead, localSamplesRead };
|
||||
}
|
||||
|
||||
// Move the chunk cursor to the next chunk, randomizing more sequences if necessary.
|
||||
|
|
|
@ -45,15 +45,20 @@ public:
|
|||
// If the offset points in the middle of last sequence, the end of the sweep is returned.
|
||||
size_t Seek(size_t sweepSampleOffset, size_t sweep);
|
||||
|
||||
// Gets the next randomized sequence descriptions not exceeding the global and local sample count.
|
||||
// Gets the next randomized sequence descriptions not exceeding the global and local sample count,
|
||||
// when atLeastOneSequenceNeeded is false. Otherwise (when atLeastOneSequenceNeeded is true),
|
||||
// returns at least one sequence description even when its length is greater than the required sample count.
|
||||
// Whether a sequence is considered local is defined by the isLocalSequence predicate.
|
||||
// The return value is how many global samples have been read.
|
||||
size_t GetNextSequenceDescriptions(
|
||||
// Returns a pair whose first element indicates the number of global samples read,
|
||||
// and second -- the number of local samples read (== sum of number of sample over all elements in the
|
||||
// 'sequences' vector).
|
||||
std::pair<size_t, size_t> GetNextSequenceDescriptions(
|
||||
size_t globalSampleCount,
|
||||
size_t localSampleCount,
|
||||
const std::function<bool(const RandomizedSequenceDescription*)>& isLocalSequence,
|
||||
ClosedOpenChunkInterval& requiredChunks,
|
||||
std::vector<RandomizedSequenceDescription>& sequences);
|
||||
std::vector<RandomizedSequenceDescription>& sequences,
|
||||
bool atLeastOneSequenceNeeded = true);
|
||||
|
||||
private:
|
||||
DISABLE_COPY_AND_MOVE(SequenceRandomizer);
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
#define _SCL_SECURE_NO_WARNINGS
|
||||
|
||||
#include <cmath>
|
||||
#include <deque>
|
||||
#include "TruncatedBpttPacker.h"
|
||||
#include "ReaderUtil.h"
|
||||
|
||||
|
@ -38,9 +37,10 @@ public:
|
|||
}
|
||||
|
||||
// Adds a new sequence to the end of the slot.
|
||||
void PushSequence(SequenceDataPtr s)
|
||||
void PushSequence(SequenceDataPtr s, bool endOfSweep)
|
||||
{
|
||||
m_sequences.push_back(s);
|
||||
m_endOfSweepFlags.push_back(endOfSweep);
|
||||
m_length += s->m_numberOfSamples;
|
||||
}
|
||||
|
||||
|
@ -51,13 +51,16 @@ public:
|
|||
}
|
||||
|
||||
// Pops the front sequence at the beginning of the slot.
|
||||
void PopSequence()
|
||||
bool PopSequence()
|
||||
{
|
||||
assert(!m_sequences.empty());
|
||||
m_sampleCursor = 0;
|
||||
m_sampleOffset = 0;
|
||||
m_length -= m_sequences.front()->m_numberOfSamples;
|
||||
m_sequences.pop_front();
|
||||
bool endOfSweepFlag = m_endOfSweepFlags.front();
|
||||
m_endOfSweepFlags.pop_front();
|
||||
return endOfSweepFlag;
|
||||
}
|
||||
|
||||
// Contains the current sample cursor in the first sequence(m_sequences.front()) of the slot.
|
||||
|
@ -72,6 +75,10 @@ public:
|
|||
private:
|
||||
// Prepared sequences.
|
||||
deque<SequenceDataPtr> m_sequences;
|
||||
|
||||
// For each 'in-flight' sequence we keep a flag that indicate whether
|
||||
// the sequence data comes from an the end of a sweep.
|
||||
std::deque<bool> m_endOfSweepFlags;
|
||||
|
||||
// Contains the size of the slot in samples (accumulated over all m_sequences).
|
||||
size_t m_length;
|
||||
|
@ -155,7 +162,7 @@ void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, co
|
|||
|
||||
m_sequenceBufferPerStream.clear();
|
||||
|
||||
// Preparing the buffers.
|
||||
// Preparing the buffers.
|
||||
for (int j = 0; j < m_streamBuffers.size(); ++j)
|
||||
for (int i = 0; i < m_outputStreamDescriptions.size(); ++i)
|
||||
{
|
||||
|
@ -166,25 +173,22 @@ void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, co
|
|||
}
|
||||
}
|
||||
|
||||
// Filling in the initial set of sequences
|
||||
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
|
||||
{
|
||||
ReadSequencesToSlot(slotIndex);
|
||||
}
|
||||
FillOutAvailableSlots();
|
||||
}
|
||||
|
||||
Minibatch TruncatedBPTTPacker::ReadMinibatch()
|
||||
{
|
||||
Minibatch result;
|
||||
FillOutAvailableSlots();
|
||||
|
||||
// Currently all we expect sequences of identical length between different streams,
|
||||
// so it is sufficient to check a single stream only.
|
||||
if (m_sequenceBufferPerStream.front()->NothingToPack())
|
||||
{
|
||||
result.m_endOfEpoch = true;
|
||||
return result;
|
||||
{
|
||||
return Minibatch(/*endOfSweep = */false,/*endOfEpoch = */ true);
|
||||
}
|
||||
|
||||
Minibatch result;
|
||||
|
||||
// Iterating over the streams/slots and packing them into the minibatch.
|
||||
for (size_t streamIndex = 0; streamIndex < m_outputStreamDescriptions.size(); ++streamIndex)
|
||||
{
|
||||
|
@ -192,7 +196,7 @@ Minibatch TruncatedBPTTPacker::ReadMinibatch()
|
|||
size_t sequenceId = 0;
|
||||
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
|
||||
{
|
||||
PackSlot(streamIndex, slotIndex, sequenceId);
|
||||
result.m_endOfSweep |= PackSlot(streamIndex, slotIndex, sequenceId);
|
||||
}
|
||||
|
||||
StreamMinibatchPtr m = make_shared<StreamMinibatch>();
|
||||
|
@ -203,17 +207,18 @@ Minibatch TruncatedBPTTPacker::ReadMinibatch()
|
|||
|
||||
m_currentBufferIndex = (m_currentBufferIndex + 1) % m_numberOfBuffers;
|
||||
|
||||
// Eagerly set the end of epoch flag if all the data have been packed.
|
||||
result.m_endOfEpoch = m_sequenceBufferPerStream.front()->NothingToPack();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Packs a slot of sequences into the minibatch.
|
||||
void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId)
|
||||
bool TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId)
|
||||
{
|
||||
bool containsEndOfSweepSequence = false;
|
||||
auto& slot = m_sequenceBufferPerStream[streamIndex]->m_slots[slotIndex];
|
||||
|
||||
// Fill free space in the slot.
|
||||
ReadSequencesToSlot(slotIndex);
|
||||
|
||||
// Let's see how much samples we need to read.
|
||||
size_t numberOfSamples = min(m_config.m_truncationSize, slot.AvailableNumberOfSamples());
|
||||
if (numberOfSamples == 0)
|
||||
|
@ -223,7 +228,7 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
|
|||
|
||||
// Check that nothing is in the slot any more.
|
||||
assert(slot.IsEmpty());
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t sampleSize = GetSampleSize(m_inputStreamDescriptions[streamIndex]);
|
||||
|
@ -247,7 +252,7 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
|
|||
if (slot.m_sampleCursor >= slot.FrontSequence()->m_numberOfSamples)
|
||||
{
|
||||
// Starting a new sequence. Have to reset current pointers and add it to the minibatch layout.
|
||||
slot.PopSequence();
|
||||
containsEndOfSweepSequence |= slot.PopSequence();
|
||||
|
||||
//Adding next sequence to the minibatch.
|
||||
m_currentLayouts[streamIndex]->AddSequence(
|
||||
|
@ -290,7 +295,7 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
|
|||
// Cleaning up the last sequence we have just read if needed.
|
||||
if (slot.m_sampleCursor >= slot.FrontSequence()->m_numberOfSamples)
|
||||
{
|
||||
slot.PopSequence();
|
||||
containsEndOfSweepSequence |= slot.PopSequence();
|
||||
}
|
||||
|
||||
// Adding the last gap if there is one.
|
||||
|
@ -302,35 +307,59 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
|
|||
numberOfSamples,
|
||||
m_config.m_truncationSize);
|
||||
}
|
||||
|
||||
return containsEndOfSweepSequence;
|
||||
}
|
||||
|
||||
void TruncatedBPTTPacker::FillOutAvailableSlots()
|
||||
{
|
||||
// Filling out any available spaces
|
||||
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
|
||||
{
|
||||
ReadSequencesToSlot(slotIndex);
|
||||
}
|
||||
}
|
||||
|
||||
void TruncatedBPTTPacker::ReadSequencesToSlot(size_t slotIndex)
|
||||
{
|
||||
const auto& slot = m_sequenceBufferPerStream.front()->m_slots[slotIndex];
|
||||
while (m_config.m_truncationSize >= slot.AvailableNumberOfSamples())
|
||||
const auto& firstStreamSlot = m_sequenceBufferPerStream.front()->m_slots[slotIndex];
|
||||
while (m_config.m_truncationSize >= firstStreamSlot.AvailableNumberOfSamples())
|
||||
{
|
||||
// We need a single sequence, potentially we can request (m_truncationSize - slot.AvailableNumberOfSamples())
|
||||
// to be more efficient. In reality the truncation size usually is less the sequence size.
|
||||
// Bptt always operates on a local timeline, so we do not limit the global minibatch count.
|
||||
auto s = m_sequenceEnumerator->GetNextSequences(SIZE_MAX, 1);
|
||||
const auto& sequences = m_sequenceEnumerator->GetNextSequences(SIZE_MAX, 1);
|
||||
|
||||
// assert that number of input streams == number of output streams --
|
||||
// this does not have to be the case in general, but the current
|
||||
// implementation makes this implicit assumption, so let's make it
|
||||
// explicit instead until we can get rid of it altogether.
|
||||
assert(sequences.m_endOfEpoch || sequences.m_data.size() == m_outputStreamDescriptions.size());
|
||||
|
||||
const auto& data = sequences.m_data;
|
||||
|
||||
// Adding sequence to the slot for all streams.
|
||||
for (size_t i = 0; i < s.m_data.size(); ++i)
|
||||
for (size_t streamIndex = 0; streamIndex < data.size(); ++streamIndex)
|
||||
{
|
||||
assert(s.m_data[i].size() == 1);
|
||||
assert(data[streamIndex].size() == 1);
|
||||
|
||||
const auto& streamSequenceDataVector = data[streamIndex];
|
||||
auto& slot = m_sequenceBufferPerStream[streamIndex]->m_slots[slotIndex];
|
||||
|
||||
// Check that all sequences are of the same length.
|
||||
if (s.m_data.front().front()->m_numberOfSamples != s.m_data[i].front()->m_numberOfSamples)
|
||||
if (data.front().front()->m_numberOfSamples != streamSequenceDataVector.front()->m_numberOfSamples)
|
||||
{
|
||||
RuntimeError("For BPTT sequences between different input stream should have the same length.");
|
||||
}
|
||||
|
||||
slot.PushSequence(streamSequenceDataVector.front(), sequences.m_endOfSweep);
|
||||
|
||||
m_sequenceBufferPerStream[i]->m_slots[slotIndex].PushSequence(s.m_data[i].front());
|
||||
assert(firstStreamSlot.AvailableNumberOfSamples() == slot.AvailableNumberOfSamples());
|
||||
}
|
||||
|
||||
if (s.m_endOfEpoch)
|
||||
if (sequences.m_endOfEpoch)
|
||||
{
|
||||
break;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,8 +30,13 @@ public:
|
|||
virtual void SetConfiguration(const ReaderConfiguration& config, const std::vector<MemoryProviderPtr>& memoryProviders) override;
|
||||
|
||||
private:
|
||||
// Iterates over all (m_parallelNumberOfSequences) slots,
|
||||
// pulling in and filling out those slots with new sequence data,
|
||||
// for which AvailableNumberOfSamples (= current size in samples) < m_truncationSize.
|
||||
void FillOutAvailableSlots();
|
||||
|
||||
// Reads sequences to slot with the specified index.
|
||||
// Number of slots = m_parallelNumberOfSequences
|
||||
// Number of slots = m_parallelNumberOfSequences.
|
||||
void ReadSequencesToSlot(size_t slotIndex);
|
||||
|
||||
// Packs a slot into the data buffer.
|
||||
|
@ -39,7 +44,9 @@ private:
|
|||
// For each new input, sequence id is reset to 0, and incremented each time
|
||||
// a sequence is added to the layout. This allows layouts corresponding to different
|
||||
// inputs to have consistent sequence ids.
|
||||
void PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId);
|
||||
// Returns a boolean indicating if a packed data contains a sequence
|
||||
// (i.e., sequence tail) that was read last in a data sweep.
|
||||
bool PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId);
|
||||
|
||||
virtual MBLayoutPtr CreateMBLayout(const StreamBatch& batch)
|
||||
{
|
||||
|
|
|
@ -22,7 +22,7 @@ TestDataDir=$TEST_RUN_DIR/TestData
|
|||
mkdir $TestDataDir
|
||||
cp -R $DataSourceDir/MNIST/v0/Train-28x28_cntk_text.txt $TestDataDir || exit $?
|
||||
cp -R $DataSourceDir/CIFAR/v0/cifar-10-batches-py $TestDataDir || exit $?
|
||||
cp -R $TEST_DIR/../../Simple2d/Data/SimpleDataTrain_cntk_text.txt $TestDataDir || exit $?
|
||||
cp -R $TEST_DIR/../../Simple2d/Data/SimpleDataT*_cntk_text.txt $TestDataDir || exit $?
|
||||
cp -R $TEST_DIR/../../Text/SequenceClassification/Data/Train.ctf $TestDataDir || exit $?
|
||||
cp -R $TEST_DIR/../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf $TestDataDir || exit $?
|
||||
cp -R $TEST_DIR/../../../../Examples/Speech/AN4/Data/* $TestDataDir || exit $?
|
||||
|
|
|
@ -16,7 +16,7 @@ def test_cntk_103_mnist_feedforwardnetwork_noErrors(nb):
|
|||
for output in cell['outputs'] if output.output_type == "error"]
|
||||
assert errors == []
|
||||
|
||||
expectedEvalErrorByDeviceId = { -1: 1.90, 0: 1.85 }
|
||||
expectedEvalErrorByDeviceId = { -1: 1.67, 0: 1.71 }
|
||||
|
||||
def test_cntk_103_mnist_feedforwardnetwork_evalCorrect(nb, device_id):
|
||||
testCell = [cell for cell in nb.cells
|
||||
|
|
|
@ -31,6 +31,6 @@ def test_cntk_202_language_understanding_trainerror(nb):
|
|||
pass
|
||||
except KeyError:
|
||||
pass
|
||||
expectedMetrics = [2.8, 1.9, 2.2, 2.3]
|
||||
expectedMetrics = [2.8, 1.9, 2.2, 2.0]
|
||||
# TODO tighten tolerances
|
||||
assert numpy.allclose(expectedMetrics, metrics, atol=0.2)
|
||||
|
|
|
@ -1266,8 +1266,8 @@ Legacy configuration is used for truncated BPTT mode, please adapt the config to
|
|||
Set current path to: C:/repo/cntk_github6/CNTK/Tests/EndToEndTests/UnitTests/ReaderTests
|
||||
|
||||
Test module "ReaderTests" has passed with:
|
||||
122 test cases out of 122 passed
|
||||
22235579 assertions out of 22235579 passed
|
||||
127 test cases out of 127 passed
|
||||
22629927 assertions out of 22629927 passed
|
||||
|
||||
Test suite "ReaderTestSuite" has passed with:
|
||||
92 test cases out of 92 passed
|
||||
|
@ -1566,11 +1566,8 @@ Test module "ReaderTests" has passed with:
|
|||
2 assertions out of 2 passed
|
||||
|
||||
Test suite "ReaderLibTests" has passed with:
|
||||
18 test cases out of 18 passed
|
||||
2265 assertions out of 2265 passed
|
||||
|
||||
Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with:
|
||||
16 assertions out of 16 passed
|
||||
23 test cases out of 23 passed
|
||||
396613 assertions out of 396613 passed
|
||||
|
||||
Test case "ReaderLibTests/CheckSetCurrentCursorForRandomizers" has passed with:
|
||||
12 assertions out of 12 passed
|
||||
|
@ -1589,8 +1586,23 @@ Test module "ReaderTests" has passed with:
|
|||
|
||||
Test case "ReaderLibTests/BlockRandomizerInstantiate" has passed
|
||||
|
||||
Test case "ReaderLibTests/BlockRandomizerOneEpoch" has passed with:
|
||||
68 assertions out of 68 passed
|
||||
Test case "ReaderLibTests/TestRandomization_FirstEpoch" has passed with:
|
||||
1476 assertions out of 1476 passed
|
||||
|
||||
Test case "ReaderLibTests/TestRandomization_SecondEpoch" has passed with:
|
||||
1476 assertions out of 1476 passed
|
||||
|
||||
Test case "ReaderLibTests/TestRandomization_TwoSweeps" has passed with:
|
||||
3372 assertions out of 3372 passed
|
||||
|
||||
Test case "ReaderLibTests/TestRandomization_TwoSweeps_WithSequences" has passed with:
|
||||
196668 assertions out of 196668 passed
|
||||
|
||||
Test case "ReaderLibTests/TestRandomization_TwoSweeps_AllowToCrossSweepBoundary" has passed with:
|
||||
3204 assertions out of 3204 passed
|
||||
|
||||
Test case "ReaderLibTests/TestRandomization_TwoSweeps_AllowToCrossSweepBoundary_WithSequences" has passed with:
|
||||
189756 assertions out of 189756 passed
|
||||
|
||||
Test case "ReaderLibTests/BlockRandomizerOneEpochWithChunks1" has passed with:
|
||||
68 assertions out of 68 passed
|
||||
|
@ -1598,8 +1610,8 @@ Test module "ReaderTests" has passed with:
|
|||
Test case "ReaderLibTests/BlockRandomizerOneEpochWithChunks2" has passed with:
|
||||
128 assertions out of 128 passed
|
||||
|
||||
Test case "ReaderLibTests/BlockRandomizerChaosMonkey" has passed with:
|
||||
1836 assertions out of 1836 passed
|
||||
Test case "ReaderLibTests/RandomizerChaosMonkey" has passed with:
|
||||
300 assertions out of 300 passed
|
||||
|
||||
Test case "ReaderLibTests/BlockRandomizerOneEpochLegacyRandomization" has passed with:
|
||||
68 assertions out of 68 passed
|
||||
|
@ -1607,6 +1619,9 @@ Test module "ReaderTests" has passed with:
|
|||
Test case "ReaderLibTests/NoRandomizerOneEpoch" has passed with:
|
||||
35 assertions out of 35 passed
|
||||
|
||||
Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with:
|
||||
16 assertions out of 16 passed
|
||||
|
||||
Test case "ReaderLibTests/DefaultCorpusDescriptor" has passed with:
|
||||
2 assertions out of 2 passed
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
|
||||
#include "NoRandomizer.h"
|
||||
#include "DataDeserializer.h"
|
||||
#include "BlockRandomizer.h"
|
||||
|
@ -79,7 +79,7 @@ private:
|
|||
vector<vector<float>> m_sequenceData;
|
||||
|
||||
public:
|
||||
MockDeserializer(size_t numChunks, size_t numSequencesPerChunks, vector<float>& data, uint32_t sequenceLength = 1)
|
||||
MockDeserializer(size_t numChunks, size_t numSequencesPerChunks, const vector<float>& data, uint32_t sequenceLength = 1)
|
||||
: m_numChunks(numChunks),
|
||||
m_numSequencesPerChunk(numSequencesPerChunks),
|
||||
m_sampleLayout(make_shared<TensorShape>(1)),
|
||||
|
@ -170,49 +170,7 @@ void BlockRandomizerInstantiateTest(bool prefetch)
|
|||
{
|
||||
vector<float> data;
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(0, 0, data);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers)
|
||||
{
|
||||
size_t chunkSizeInSamples = 10000;
|
||||
size_t sweepNumberOfSamples = 500000;
|
||||
uint32_t maxSequenceLength = 300;
|
||||
size_t randomizationWindow = chunkSizeInSamples * 5;
|
||||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
auto blockRandomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto noRandomizer = make_shared<NoRandomizer>(deserializer, false);
|
||||
|
||||
auto test = [](SequenceEnumeratorPtr r, size_t epochSize)
|
||||
{
|
||||
auto firstEpoch = ReadFullEpoch(r, epochSize, 0);
|
||||
auto firstCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(firstCursor, firstEpoch.size());
|
||||
|
||||
auto secondEpoch = ReadFullEpoch(r, epochSize, 1);
|
||||
auto secondCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(secondCursor - firstCursor, secondEpoch.size());
|
||||
|
||||
auto thirdEpoch = ReadFullEpoch(r, epochSize, 2);
|
||||
auto thirdCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(thirdCursor - secondCursor, thirdEpoch.size());
|
||||
|
||||
auto anotherSecondEpoch = ReadFullEpoch(r, epochSize, 1);
|
||||
auto anotherSecondCursor = r->GetCurrentSamplePosition();
|
||||
|
||||
BOOST_CHECK_EQUAL(anotherSecondCursor, secondCursor);
|
||||
};
|
||||
|
||||
// Inside sweep
|
||||
size_t epochSize = 50000;
|
||||
test(blockRandomizer, epochSize);
|
||||
test(noRandomizer, epochSize);
|
||||
|
||||
// Between sweeps
|
||||
epochSize = (size_t)(sweepNumberOfSamples / 1.5);
|
||||
test(blockRandomizer, epochSize);
|
||||
test(noRandomizer, epochSize);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch, false);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(CheckSetCurrentCursorForRandomizers)
|
||||
|
@ -223,11 +181,11 @@ BOOST_AUTO_TEST_CASE(CheckSetCurrentCursorForRandomizers)
|
|||
size_t randomizationWindow = chunkSizeInSamples * 5;
|
||||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
auto expectedBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto expectedNo = make_shared<NoRandomizer>(deserializer);
|
||||
auto expectedBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
auto expectedNo = make_shared<NoRandomizer>(deserializer, false);
|
||||
|
||||
auto underTestBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto unterTestNo = make_shared<NoRandomizer>(deserializer);
|
||||
auto underTestBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
auto unterTestNo = make_shared<NoRandomizer>(deserializer, false);
|
||||
|
||||
auto test = [](SequenceEnumeratorPtr expected, SequenceEnumeratorPtr underTest, size_t epochSize)
|
||||
{
|
||||
|
@ -292,7 +250,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToEarlierEpochBetweenSweeps)
|
|||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
// Let's randomize complete sweep, so that we have a baseline.
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
|
||||
// Let's read all sequences from the first three sweeps in the randomized order.
|
||||
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
|
||||
|
@ -330,7 +288,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToEarlierEpochInTheSweep)
|
|||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
// Let's randomize complete sweep, so that we have a baseline.
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
|
||||
// Let's read all sequences from the first three sweeps in the randomized order.
|
||||
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
|
||||
|
@ -361,7 +319,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToSameEpochInTheSweep)
|
|||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
// Let's randomize complete sweep, so that we have a baseline.
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
|
||||
// Let's read all sequences from the first three sweeps in the randomized order.
|
||||
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
|
||||
|
@ -388,7 +346,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToSameEpochInBigRandomizationWindow)
|
|||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
// Let's randomize complete sweep, so that we have a baseline.
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
|
||||
// Let's read all sequences from the first three sweeps in the randomized order.
|
||||
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
|
||||
|
@ -421,45 +379,221 @@ BOOST_AUTO_TEST_CASE(BlockRandomizerInstantiate)
|
|||
BlockRandomizerInstantiateTest(true);
|
||||
}
|
||||
|
||||
void BlockRandomizerOneEpochTest(bool prefetch)
|
||||
void OneEpochRandomizationTest(SequenceEnumerator& randomizer, size_t sweepSize, const EpochConfiguration& epochConfig, const vector<float>& expectedOutput, size_t sequenceLength = 1)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
|
||||
auto epochSize = epochConfig.m_totalEpochSizeInSamples;
|
||||
auto mbSize = epochConfig.m_minibatchSizeInSamples;
|
||||
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch);
|
||||
BOOST_ASSERT(epochSize == expectedOutput.size());
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = 1;
|
||||
epochConfiguration.m_workerRank = 0;
|
||||
epochConfiguration.m_minibatchSizeInSamples = 0;
|
||||
epochConfiguration.m_totalEpochSizeInSamples = data.size();
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
randomizer->StartEpoch(epochConfiguration);
|
||||
randomizer.StartEpoch(epochConfig);
|
||||
|
||||
vector<float> expected { 6, 3, 1, 5, 9, 0, 4, 2, 7, 8 };
|
||||
BOOST_CHECK_EQUAL(data.size(), expected.size());
|
||||
vector<float> actual;
|
||||
for (int i = 0; i < data.size() + 1; i++)
|
||||
for (int totalSamplesRead = 0; totalSamplesRead < epochSize;)
|
||||
{
|
||||
Sequences sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1 - (i / data.size()));
|
||||
if (i < data.size())
|
||||
Sequences sequences = randomizer.GetNextSequences(mbSize, mbSize);
|
||||
BOOST_ASSERT(sequences.m_data.size() == 1); // only one input stream
|
||||
auto& stream = sequences.m_data[0];
|
||||
auto numSampleRead = 0;
|
||||
for (auto& sequence : stream)
|
||||
{
|
||||
auto& data2 = reinterpret_cast<DenseSequenceData&>(*sequences.m_data[0][0]);
|
||||
BOOST_CHECK_EQUAL(data2.m_numberOfSamples, 1u);
|
||||
actual.push_back(*((float*)data2.GetDataBuffer()));
|
||||
auto numSamples = sequence->m_numberOfSamples;
|
||||
numSampleRead += numSamples;
|
||||
auto& data = reinterpret_cast<DenseSequenceData&>(*sequence);
|
||||
actual.reserve(actual.size() + numSamples);
|
||||
std::copy_n(((float*)data.GetDataBuffer()), numSamples, std::back_inserter(actual));
|
||||
}
|
||||
BOOST_CHECK_EQUAL(sequences.m_endOfEpoch, (data.size() <= i + 1));
|
||||
|
||||
auto expectedSize = std::min(epochSize - totalSamplesRead, mbSize);
|
||||
if (!epochConfig.m_allowMinibatchesToCrossSweepBoundaries)
|
||||
{
|
||||
expectedSize = std::min(sweepSize - totalSamplesRead % sweepSize, expectedSize);
|
||||
}
|
||||
|
||||
// at least one sequence is returned in case when mbSize < sequenceLength
|
||||
expectedSize = std::max(expectedSize, sequenceLength);
|
||||
BOOST_REQUIRE(numSampleRead <= std::max(mbSize, sequenceLength));
|
||||
if (sequenceLength == 1)
|
||||
BOOST_REQUIRE(numSampleRead == expectedSize);
|
||||
else
|
||||
BOOST_REQUIRE(expectedSize - numSampleRead < sequenceLength);
|
||||
|
||||
BOOST_REQUIRE(sequences.m_endOfEpoch == (totalSamplesRead + numSampleRead == epochSize));
|
||||
BOOST_REQUIRE(sequences.m_endOfSweep == (totalSamplesRead / sweepSize != (totalSamplesRead + numSampleRead) / sweepSize));
|
||||
|
||||
totalSamplesRead += numSampleRead;
|
||||
}
|
||||
BOOST_CHECK_EQUAL_COLLECTIONS(expected.begin(), expected.end(),
|
||||
|
||||
for (int i = 0; i < 3; i++)
|
||||
{
|
||||
auto numSamples = i + 1;
|
||||
Sequences sequences = randomizer.GetNextSequences(numSamples, numSamples);
|
||||
BOOST_REQUIRE(sequences.m_data.size() == 0);
|
||||
BOOST_REQUIRE(sequences.m_endOfEpoch == true);
|
||||
BOOST_REQUIRE(sequences.m_endOfSweep == (epochSize % sweepSize == 0));
|
||||
}
|
||||
|
||||
BOOST_REQUIRE_EQUAL_COLLECTIONS(expectedOutput.begin(), expectedOutput.end(),
|
||||
actual.begin(), actual.end());
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpoch)
|
||||
void TestRandomization(EpochConfiguration& epochConfiguration, IDataDeserializerPtr deserializer, size_t sweepSize, const vector<float>& expectedRandomized, const vector<float>& expectedNotRandomized, size_t sequenceLength = 1)
|
||||
{
|
||||
BlockRandomizerOneEpochTest(false);
|
||||
BlockRandomizerOneEpochTest(true);
|
||||
BlockRandomizer randomizer1(0, SIZE_MAX, deserializer, /*prefetch =*/ false);
|
||||
BlockRandomizer randomizer2(0, SIZE_MAX, deserializer, /*prefetch =*/ true);
|
||||
NoRandomizer randomizer3(deserializer);
|
||||
|
||||
BlockRandomizer randomizer4(0, SIZE_MAX, deserializer, /*prefetch =*/ false, false, /*multithreadedGetNextSequences =*/ true);
|
||||
BlockRandomizer randomizer5(0, SIZE_MAX, deserializer, /*prefetch =*/ true, false, /*multithreadedGetNextSequences =*/ true);
|
||||
NoRandomizer randomizer6(deserializer, /*multithreadedGetNextSequences =*/ true);
|
||||
|
||||
epochConfiguration.m_numberOfWorkers = 1;
|
||||
epochConfiguration.m_workerRank = 0;
|
||||
epochConfiguration.m_totalEpochSizeInSamples = expectedRandomized.size();
|
||||
|
||||
for (int i = 1; i <= epochConfiguration.m_totalEpochSizeInSamples + 1; i++)
|
||||
{
|
||||
epochConfiguration.m_minibatchSizeInSamples = i;
|
||||
OneEpochRandomizationTest(randomizer1, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
|
||||
OneEpochRandomizationTest(randomizer2, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
|
||||
OneEpochRandomizationTest(randomizer3, sweepSize, epochConfiguration, expectedNotRandomized, sequenceLength);
|
||||
|
||||
OneEpochRandomizationTest(randomizer4, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
|
||||
OneEpochRandomizationTest(randomizer5, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
|
||||
OneEpochRandomizationTest(randomizer6, sweepSize, epochConfiguration, expectedNotRandomized, sequenceLength);
|
||||
}
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestRandomization_FirstEpoch)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8 };
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
|
||||
TestRandomization(epochConfiguration, mockDeserializer, data.size(), expected, data);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestRandomization_SecondEpoch)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
vector<float> expected{ 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_epochIndex = 1;
|
||||
|
||||
TestRandomization(epochConfiguration, mockDeserializer, data.size(), expected, data);
|
||||
}
|
||||
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
|
||||
|
||||
auto sweepSize = data.size();
|
||||
data.reserve(2 * sweepSize);
|
||||
std::copy_n(data.begin(), sweepSize, std::back_inserter(data));
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
|
||||
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expected, data);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps_WithSequences)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
|
||||
|
||||
for (int seqLength = 2; seqLength <= 10; seqLength++)
|
||||
{
|
||||
vector<float> expectedRandomized;
|
||||
vector<float> expectedNotRandomized;
|
||||
for (auto f : expected) {
|
||||
std::fill_n(back_inserter(expectedRandomized), seqLength, f);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 2 * data.size(); i++) {
|
||||
std::fill_n(back_inserter(expectedNotRandomized), seqLength, data[i % data.size()]);
|
||||
}
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data, seqLength);
|
||||
|
||||
auto sweepSize = data.size() * seqLength;
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
|
||||
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expectedRandomized, expectedNotRandomized, seqLength);
|
||||
}
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps_AllowToCrossSweepBoundary)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
|
||||
|
||||
auto sweepSize = data.size();
|
||||
data.reserve(2 * sweepSize);
|
||||
std::copy_n(data.begin(), sweepSize, std::back_inserter(data));
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
epochConfiguration.m_allowMinibatchesToCrossSweepBoundaries = true;
|
||||
|
||||
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expected, data);
|
||||
}
|
||||
|
||||
|
||||
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps_AllowToCrossSweepBoundary_WithSequences)
|
||||
{
|
||||
vector<float> data(10);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
|
||||
|
||||
for (int seqLength = 2; seqLength <= 10; seqLength++)
|
||||
{
|
||||
vector<float> expectedRandomized;
|
||||
vector<float> expectedNotRandomized;
|
||||
for (auto f : expected) {
|
||||
std::fill_n(back_inserter(expectedRandomized), seqLength, f);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 2 * data.size(); i++) {
|
||||
std::fill_n(back_inserter(expectedNotRandomized), seqLength, data[i % data.size()]);
|
||||
}
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data, seqLength);
|
||||
|
||||
auto sweepSize = data.size() * seqLength;
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
epochConfiguration.m_allowMinibatchesToCrossSweepBoundaries = true;
|
||||
|
||||
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expectedRandomized, expectedNotRandomized, seqLength);
|
||||
}
|
||||
}
|
||||
|
||||
void BlockRandomizerOneEpochWithChunks1Test(bool prefetch)
|
||||
|
@ -468,7 +602,7 @@ void BlockRandomizerOneEpochWithChunks1Test(bool prefetch)
|
|||
iota(data.begin(), data.end(), 0.0f);
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
|
||||
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, 4, mockDeserializer, prefetch);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, 4, mockDeserializer, prefetch, false);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = 1;
|
||||
|
@ -510,7 +644,7 @@ void BlockRandomizerOneEpochWithChunks2Test(bool prefetch)
|
|||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(10, 2, data);
|
||||
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, 18, mockDeserializer, prefetch);
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, 18, mockDeserializer, prefetch, false);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = 1;
|
||||
|
@ -548,65 +682,75 @@ BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochWithChunks2)
|
|||
BlockRandomizerOneEpochWithChunks2Test(true);
|
||||
}
|
||||
|
||||
void BlockRandomizerChaosMonkeyTest(bool prefetch)
|
||||
void RandomizerChaosMonkeyTest(SequenceEnumerator& randomizer, size_t sweepSize, int seed)
|
||||
{
|
||||
const int sequenceLength = 3;
|
||||
const int seed = 42;
|
||||
const int numChunks = 100;
|
||||
const int numSequencesPerChunk = 10;
|
||||
const int windowSize = 18;
|
||||
vector<float> data(numChunks * numSequencesPerChunk);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
std::mt19937 rng(seed);
|
||||
|
||||
boost::random::uniform_int_distribution<int> distr(1, 10);
|
||||
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(numChunks, numSequencesPerChunk, data, sequenceLength);
|
||||
|
||||
auto randomizer = make_shared<BlockRandomizer>(0, windowSize, mockDeserializer, prefetch);
|
||||
boost::random::uniform_int_distribution<int> distr(1, 100);
|
||||
|
||||
for (int t = 0; t < 100; t++)
|
||||
{
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = distr(rng);
|
||||
do
|
||||
{
|
||||
epochConfiguration.m_workerRank = distr(rng) - 1;
|
||||
}
|
||||
while (epochConfiguration.m_numberOfWorkers <= epochConfiguration.m_workerRank);
|
||||
epochConfiguration.m_workerRank = distr(rng) % epochConfiguration.m_numberOfWorkers;
|
||||
|
||||
epochConfiguration.m_minibatchSizeInSamples = 0; // don't care
|
||||
epochConfiguration.m_totalEpochSizeInSamples = data.size() / distr(rng);
|
||||
epochConfiguration.m_totalEpochSizeInSamples = sweepSize * distr(rng) / distr(rng);
|
||||
epochConfiguration.m_epochIndex = distr(rng);
|
||||
randomizer->StartEpoch(epochConfiguration);
|
||||
epochConfiguration.m_allowMinibatchesToCrossSweepBoundaries = (distr(rng) % 2 == 0);
|
||||
randomizer.StartEpoch(epochConfiguration);
|
||||
|
||||
auto epochStart = epochConfiguration.m_epochIndex * epochConfiguration.m_totalEpochSizeInSamples;
|
||||
auto epochEnd = epochStart + epochConfiguration.m_totalEpochSizeInSamples;
|
||||
auto numSweeps = epochEnd / sweepSize - epochStart / sweepSize;
|
||||
|
||||
auto sweepCount = 0;
|
||||
int samplesToGet = 0;
|
||||
for (int i = 0; i < epochConfiguration.m_totalEpochSizeInSamples + 1; i += samplesToGet)
|
||||
for (;;)
|
||||
{
|
||||
samplesToGet = distr(rng);
|
||||
Sequences sequences = randomizer->GetNextSequences(samplesToGet, samplesToGet);
|
||||
Sequences sequences = randomizer.GetNextSequences(samplesToGet, samplesToGet);
|
||||
|
||||
if (sequences.m_endOfSweep)
|
||||
sweepCount++;
|
||||
|
||||
// In case end of epoch/decimation/single sequence -> skip the mbSize check.
|
||||
if (sequences.m_endOfEpoch || sequences.m_data.empty() || sequences.m_data.front().size() < 2)
|
||||
if (!(sequences.m_data.empty() || sequences.m_data.size() == 1))
|
||||
{
|
||||
continue;
|
||||
// Check that we do not exceed the minibatch size.
|
||||
size_t count = 0;
|
||||
for (const auto& sequence : sequences.m_data.front())
|
||||
{
|
||||
count += sequence->m_numberOfSamples;
|
||||
}
|
||||
BOOST_REQUIRE_LE(count, samplesToGet);
|
||||
}
|
||||
|
||||
// Check that we do not exceed the minibatch size.
|
||||
size_t count = 0;
|
||||
for (const auto& sequence : sequences.m_data.front())
|
||||
{
|
||||
count += sequence->m_numberOfSamples;
|
||||
}
|
||||
BOOST_CHECK_LE(count, samplesToGet);
|
||||
if (sequences.m_endOfEpoch)
|
||||
break;
|
||||
|
||||
}
|
||||
BOOST_REQUIRE(sweepCount == numSweeps);
|
||||
}
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(BlockRandomizerChaosMonkey)
|
||||
BOOST_AUTO_TEST_CASE(RandomizerChaosMonkey)
|
||||
{
|
||||
BlockRandomizerChaosMonkeyTest(false);
|
||||
BlockRandomizerChaosMonkeyTest(true);
|
||||
const int sequenceLength = 3;
|
||||
const int numChunks = 100;
|
||||
const int numSequencesPerChunk = 10;
|
||||
const int windowSize = 18;
|
||||
vector<float> data(numChunks * numSequencesPerChunk);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(numChunks, numSequencesPerChunk, data, sequenceLength);
|
||||
BlockRandomizer blockRandomizerNoPrefetch(0, windowSize, mockDeserializer, false, false);
|
||||
BlockRandomizer blockRandomizerWithPrefetch(0, windowSize, mockDeserializer, true, false);
|
||||
NoRandomizer norandomizer(mockDeserializer);
|
||||
|
||||
auto sweepSize = data.size() * sequenceLength;
|
||||
|
||||
RandomizerChaosMonkeyTest(blockRandomizerNoPrefetch, sweepSize, 42);
|
||||
RandomizerChaosMonkeyTest(blockRandomizerWithPrefetch, sweepSize, 43);
|
||||
RandomizerChaosMonkeyTest(norandomizer, sweepSize, 44);
|
||||
}
|
||||
|
||||
void BlockRandomizerOneEpochLegacyRandomizationTest(bool prefetch)
|
||||
|
@ -618,7 +762,8 @@ void BlockRandomizerOneEpochLegacyRandomizationTest(bool prefetch)
|
|||
auto randomizer = make_shared<BlockRandomizer>(0,
|
||||
SIZE_MAX,
|
||||
mockDeserializer,
|
||||
prefetch);
|
||||
prefetch,
|
||||
true);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = 1;
|
||||
|
@ -691,6 +836,48 @@ BOOST_AUTO_TEST_CASE(NoRandomizerOneEpoch)
|
|||
actual.begin(), actual.end());
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers)
|
||||
{
|
||||
size_t chunkSizeInSamples = 10000;
|
||||
size_t sweepNumberOfSamples = 500000;
|
||||
uint32_t maxSequenceLength = 300;
|
||||
size_t randomizationWindow = chunkSizeInSamples * 5;
|
||||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
auto blockRandomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
|
||||
auto noRandomizer = make_shared<NoRandomizer>(deserializer, false);
|
||||
|
||||
auto test = [](SequenceEnumeratorPtr r, size_t epochSize)
|
||||
{
|
||||
auto firstEpoch = ReadFullEpoch(r, epochSize, 0);
|
||||
auto firstCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(firstCursor, firstEpoch.size());
|
||||
|
||||
auto secondEpoch = ReadFullEpoch(r, epochSize, 1);
|
||||
auto secondCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(secondCursor - firstCursor, secondEpoch.size());
|
||||
|
||||
auto thirdEpoch = ReadFullEpoch(r, epochSize, 2);
|
||||
auto thirdCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(thirdCursor - secondCursor, thirdEpoch.size());
|
||||
|
||||
auto anotherSecondEpoch = ReadFullEpoch(r, epochSize, 1);
|
||||
auto anotherSecondCursor = r->GetCurrentSamplePosition();
|
||||
|
||||
BOOST_CHECK_EQUAL(anotherSecondCursor, secondCursor);
|
||||
};
|
||||
|
||||
// Inside sweep
|
||||
size_t epochSize = 50000;
|
||||
test(blockRandomizer, epochSize);
|
||||
test(noRandomizer, epochSize);
|
||||
|
||||
// Between sweeps
|
||||
epochSize = (size_t)(sweepNumberOfSamples / 1.5);
|
||||
test(blockRandomizer, epochSize);
|
||||
test(noRandomizer, epochSize);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(DefaultCorpusDescriptor)
|
||||
{
|
||||
const int seed = 13;
|
||||
|
@ -713,8 +900,8 @@ BOOST_AUTO_TEST_CASE(NumericCorpusDescriptor)
|
|||
CorpusDescriptor corpus(true);
|
||||
for (int i = 0; i < 10; ++i)
|
||||
{
|
||||
auto value = distr(rng);
|
||||
BOOST_CHECK_EQUAL(value, corpus.KeyToId(std::to_string(value)));
|
||||
auto value = distr(rng);
|
||||
BOOST_CHECK_EQUAL(value, corpus.KeyToId(std::to_string(value)));
|
||||
}
|
||||
BOOST_CHECK_EXCEPTION(
|
||||
corpus.KeyToId("not a number"),
|
||||
|
|
|
@ -166,7 +166,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
|
|||
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
|
||||
{
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
trainer->TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo].m_data }, { labelsVar, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo] }, { labelsVar, minibatchData[labelStreamInfo] } }, device);
|
||||
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -106,10 +106,6 @@ void TestRMSPropLearner(size_t numParameters, size_t numMinibatches, const Devic
|
|||
|
||||
void TestTrainingParametersSchedule()
|
||||
{
|
||||
VerifyException([]() {
|
||||
LearningRatePerMinibatchSchedule({ 3.0, 2.0, 1.0 }, LearningRateSchedule::EntireSweep);
|
||||
}, "Was able to create not-yet-implemented sweep-based schedule.");
|
||||
|
||||
LearningRatePerSampleSchedule schedule1 = 0.5;
|
||||
assert(schedule1.Unit() == LearningRateSchedule::UnitType::Sample);
|
||||
assert(schedule1[0] == 0.5);
|
||||
|
@ -229,11 +225,76 @@ void TestTrainingParametersSchedule()
|
|||
}
|
||||
|
||||
|
||||
void TestSweepBasedSchedule()
|
||||
{
|
||||
DeviceDescriptor device = DeviceDescriptor::CPUDevice();
|
||||
auto schedule = LearningRatePerSampleSchedule({ 1, 2, 3, 4, 5 }, LearningRateSchedule::FullDataSweep);
|
||||
|
||||
auto learner1 = SGDLearner({}, schedule);
|
||||
assert(1 == learner1->LearningRate());
|
||||
|
||||
|
||||
for (auto i : {2, 3, 4, 5 })
|
||||
{
|
||||
std::unordered_map<Parameter, NDArrayViewPtr> gradients {};
|
||||
learner1->Update(gradients, 1, true);
|
||||
assert(i == learner1->LearningRate());
|
||||
}
|
||||
|
||||
const size_t inputDim = 2;
|
||||
const size_t numOutputClasses = 2;
|
||||
auto minibatchSource = TextFormatMinibatchSource(L"SimpleDataTest_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
|
||||
|
||||
auto sweepSize = 603; // == wc -l SimpleDataTest_cntk_text.txt
|
||||
auto minibatchSize = 400;
|
||||
auto featureStreamInfo = minibatchSource->StreamInfo(L"features");
|
||||
auto labelStreamInfo = minibatchSource->StreamInfo(L"labels");
|
||||
|
||||
auto input = InputVariable({ inputDim }, DataType::Float, L"features");
|
||||
auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels");
|
||||
|
||||
|
||||
auto classifierOutput = FullyConnectedLinearLayer(input, numOutputClasses, device);
|
||||
auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
|
||||
auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
|
||||
auto learner2 = SGDLearner(classifierOutput->Parameters(), schedule);
|
||||
auto trainer = CreateTrainer(classifierOutput, trainingLoss, prediction, { learner2 });
|
||||
|
||||
|
||||
for (auto i = 0; i <= 4000; i+= minibatchSize)
|
||||
{
|
||||
auto sweepIndex1 = i / sweepSize;
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
|
||||
if (minibatchData[featureStreamInfo].sweepEnd != minibatchData[labelStreamInfo].sweepEnd) {
|
||||
ReportFailure("TestSweepBasedSchedule failed: "
|
||||
"different streams have different end of sweep flag values.");
|
||||
}
|
||||
|
||||
auto sweepIndex2 = (i + minibatchSize) / sweepSize;
|
||||
|
||||
if ((sweepIndex1 != sweepIndex2) != minibatchData[labelStreamInfo].sweepEnd) {
|
||||
ReportFailure("TestSweepBasedSchedule failed: "
|
||||
"end of sweep flag value is different from expected.");
|
||||
}
|
||||
|
||||
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto expectedLR = std::min((sweepIndex2 + 1), 5);
|
||||
|
||||
if (expectedLR != learner2->LearningRate()) {
|
||||
ReportFailure("TestSweepBasedSchedule failed: "
|
||||
"learning rate value is different from expected.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void LearnerTests()
|
||||
{
|
||||
fprintf(stderr, "\nLearnerTests..\n");
|
||||
|
||||
TestTrainingParametersSchedule();
|
||||
TestSweepBasedSchedule();
|
||||
|
||||
vector<DeviceDescriptor> devices{DeviceDescriptor::CPUDevice()};
|
||||
|
||||
|
|
|
@ -159,11 +159,11 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
|
|||
auto minibatchData = minibatchSource->GetNextMinibatch(0, minibatchSize, 1, 0);
|
||||
auto minibatchData2 = minibatchSource2->GetNextMinibatch(0, minibatchSize, 1, 0);
|
||||
|
||||
if (minibatchData[featureStreamInfo].m_numSamples != minibatchData2[featureStreamInfo].m_numSamples)
|
||||
if (minibatchData[featureStreamInfo].numberOfSamples != minibatchData2[featureStreamInfo].numberOfSamples)
|
||||
ReportFailure("Data does not match, reads are not deterministic!!!");
|
||||
|
||||
// Because they are supposed to read the same data - adding it only once.
|
||||
totalSamples += minibatchData[featureStreamInfo].m_numSamples;
|
||||
totalSamples += minibatchData[featureStreamInfo].numberOfSamples;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -179,27 +179,27 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
|
|||
// Update the counter
|
||||
size_t accumulative = 0;
|
||||
if (!minibatchData.empty())
|
||||
accumulative += minibatchData[featureStreamInfo].m_numSamples;
|
||||
accumulative += minibatchData[featureStreamInfo].numberOfSamples;
|
||||
if (!minibatchData2.empty())
|
||||
accumulative += minibatchData2[featureStreamInfo].m_numSamples;
|
||||
accumulative += minibatchData2[featureStreamInfo].numberOfSamples;
|
||||
|
||||
totalSamples += accumulative;
|
||||
|
||||
if (expectNoData) // second worker does not have any data.
|
||||
{
|
||||
if (minibatchData[featureStreamInfo].m_numSamples != minibatchSize/2 && totalSamples != numberOfSamplesInSweep)
|
||||
if (minibatchData[featureStreamInfo].numberOfSamples != minibatchSize/2 && totalSamples != numberOfSamplesInSweep)
|
||||
ReportFailure("TestMinibatchSourceWarmStart failed because data did not match."
|
||||
"Expected minibatch size '%d', acutal '%d'. Total number of sample '%d', sweep '%d'.",
|
||||
(int)minibatchSize,
|
||||
(int)minibatchData[featureStreamInfo].m_numSamples,
|
||||
(int)minibatchData[featureStreamInfo].numberOfSamples,
|
||||
(int)totalSamples,
|
||||
(int)numberOfSamplesInSweep);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (accumulative != minibatchSize &&
|
||||
minibatchData[featureStreamInfo].m_numSamples != minibatchSize / 2 &&
|
||||
minibatchData2[featureStreamInfo].m_numSamples != minibatchSize / 2 &&
|
||||
minibatchData[featureStreamInfo].numberOfSamples != minibatchSize / 2 &&
|
||||
minibatchData2[featureStreamInfo].numberOfSamples != minibatchSize / 2 &&
|
||||
totalSamples != numberOfSamplesInSweep)
|
||||
ReportFailure("TestMinibatchSourceWarmStart failed because data did not match."
|
||||
"Expected minibatch size '%d', acutal '%d'. Total number of sample '%d', sweep '%d'.",
|
||||
|
@ -217,8 +217,81 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
|
|||
(int)totalSamples);
|
||||
}
|
||||
|
||||
void TestEndOfSweepFlag(size_t maxSamples, size_t mbSize, bool randomize)
|
||||
{
|
||||
const size_t sweepSize = 603;
|
||||
auto ctfInput = L"SimpleDataTest_cntk_text.txt";
|
||||
std::vector<StreamConfiguration> streamConfig { { L"features", 2 } };
|
||||
auto cpuDevice = DeviceDescriptor::CPUDevice();
|
||||
auto src = TextFormatMinibatchSource(ctfInput, streamConfig, maxSamples, randomize);
|
||||
|
||||
maxSamples = (maxSamples == MinibatchSource::FullDataSweep) ? sweepSize : maxSamples;
|
||||
|
||||
bool reachedEndOfEpoch = false;
|
||||
size_t sampleCount = 0;
|
||||
|
||||
while (sampleCount < maxSamples)
|
||||
{
|
||||
auto& dataMap = src->GetNextMinibatch(mbSize, cpuDevice);
|
||||
|
||||
if (dataMap.size() != streamConfig.size())
|
||||
{
|
||||
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: "
|
||||
"unexpected number of streams in the minibatch (%zu).", dataMap.size());
|
||||
}
|
||||
|
||||
for (auto& streamData : dataMap)
|
||||
{
|
||||
auto numSamplesInMinibatch = streamData.second.numberOfSamples;
|
||||
bool expectedEndOfSweep = ((sampleCount + numSamplesInMinibatch) % sweepSize) == 0;
|
||||
expectedEndOfSweep |= ((sampleCount) / sweepSize) < ((sampleCount + numSamplesInMinibatch) / sweepSize);
|
||||
|
||||
|
||||
reachedEndOfEpoch = (sampleCount + mbSize >= maxSamples);
|
||||
size_t expectedNumSamples = reachedEndOfEpoch ? (maxSamples - sampleCount) : mbSize;
|
||||
|
||||
|
||||
if (streamData.second.sweepEnd != expectedEndOfSweep)
|
||||
{
|
||||
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: end of sweep flag is not set.");
|
||||
}
|
||||
if (streamData.second.numberOfSamples != expectedNumSamples)
|
||||
{
|
||||
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: "
|
||||
"unexpected number of samples in the minibatch (%zu).", streamData.second.numberOfSamples);
|
||||
}
|
||||
if (streamData.second.numberOfSequences != expectedNumSamples)
|
||||
{
|
||||
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: "
|
||||
"unexpected number of sequences in the minibatch (%zu).", streamData.second.numberOfSequences);
|
||||
}
|
||||
}
|
||||
|
||||
sampleCount += mbSize;
|
||||
}
|
||||
|
||||
auto& emptyDataMap = src->GetNextMinibatch(mbSize, cpuDevice);
|
||||
assert(emptyDataMap.empty());
|
||||
}
|
||||
|
||||
void TestThatEndOfSweepFlagIsSetCorrectly()
|
||||
{
|
||||
for (auto randomize : { false, true })
|
||||
{
|
||||
TestEndOfSweepFlag(MinibatchSource::FullDataSweep, 603, randomize);
|
||||
TestEndOfSweepFlag(MinibatchSource::FullDataSweep, 1000, randomize);
|
||||
TestEndOfSweepFlag(MinibatchSource::FullDataSweep, 100, randomize);
|
||||
|
||||
TestEndOfSweepFlag(100, 30, randomize);
|
||||
TestEndOfSweepFlag(2000, 500, randomize);
|
||||
TestEndOfSweepFlag(2412, 301, randomize);
|
||||
}
|
||||
}
|
||||
|
||||
void MinibatchSourceTests()
|
||||
{
|
||||
TestThatEndOfSweepFlagIsSetCorrectly();
|
||||
|
||||
// Test no-randomize minibatch source with small data chunks
|
||||
TestMinibatchSourceWarmStart(64, 128, false, 1024);
|
||||
TestMinibatchSourceWarmStart(64, 0, false, 1024);
|
||||
|
|
|
@ -209,7 +209,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
|
|||
if (minibatchData.empty())
|
||||
break;
|
||||
|
||||
trainer->TrainMinibatch({ { rawInput, minibatchData[rawInputStreamInfo].m_data }, { rawLabels, minibatchData[rawLabelsStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { rawInput, minibatchData[rawInputStreamInfo] }, { rawLabels, minibatchData[rawLabelsStreamInfo] } }, device);
|
||||
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
|
||||
|
||||
if ((i + 1) == numMinibatchesToCheckpointAfter)
|
||||
|
@ -222,7 +222,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
|
|||
if ((i % decodingFrequency) == 0)
|
||||
{
|
||||
std::unordered_map<Variable, ValuePtr> outputs = { { decodingFunction, nullptr }};
|
||||
decodingFunction->Forward({ { decodingFunction->Arguments()[0], minibatchData[rawInputStreamInfo].m_data }, { decodingFunction->Arguments()[1], minibatchData[rawLabelsStreamInfo].m_data } },
|
||||
decodingFunction->Forward({ { decodingFunction->Arguments()[0], minibatchData[rawInputStreamInfo].data }, { decodingFunction->Arguments()[1], minibatchData[rawLabelsStreamInfo].data } },
|
||||
outputs,
|
||||
device);
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool useSparseLa
|
|||
if (minibatchData.empty())
|
||||
break;
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
|
||||
}
|
||||
}
|
||||
|
@ -87,37 +87,37 @@ void TestLearningRateControl(const DeviceDescriptor& device)
|
|||
|
||||
const size_t minibatchSize = 200;
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
|
||||
auto actualMBSize = minibatchData[labelStreamInfo].numberOfSamples;
|
||||
|
||||
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
|
||||
auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
|
||||
auto trainer = CreateTrainer(classifierOutput, trainingLoss, prediction, { learner });
|
||||
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
|
||||
|
||||
const wchar_t* modelFile = L"seq2seq.model";
|
||||
trainer->SaveCheckpoint(modelFile);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto MB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto MB3Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->RestoreFromCheckpoint(modelFile);
|
||||
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto postRestoreMB3Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
|
@ -128,13 +128,13 @@ void TestLearningRateControl(const DeviceDescriptor& device)
|
|||
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->SaveCheckpoint(modelFile);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
postRestoreMB3Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
|
@ -143,13 +143,13 @@ void TestLearningRateControl(const DeviceDescriptor& device)
|
|||
trainer->RestoreFromCheckpoint(modelFile);
|
||||
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
postRestoreMB3Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
|
|
|
@ -434,7 +434,7 @@ void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const
|
|||
|
||||
Dictionary model = classifierOutput1->Serialize();
|
||||
|
||||
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
auto classifierOutput2 = Function::Deserialize(model, device);
|
||||
|
||||
|
@ -458,8 +458,8 @@ void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const
|
|||
|
||||
for (int j = 0; j < 3; ++j)
|
||||
{
|
||||
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer2->TrainMinibatch({ { classifierOutput3->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
trainer2->TrainMinibatch({ { classifierOutput3->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
double mbLoss1 = trainer1->PreviousMinibatchLossAverage();
|
||||
double mbLoss2 = trainer2->PreviousMinibatchLossAverage();
|
||||
|
@ -503,7 +503,7 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
|
|||
|
||||
const size_t minibatchSize = 50;
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
|
||||
auto actualMBSize = minibatchData[labelStreamInfo].numberOfSamples;
|
||||
|
||||
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
|
||||
MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize);
|
||||
|
@ -522,14 +522,14 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
|
|||
throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original.");
|
||||
}
|
||||
|
||||
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
if (AreEqual(function1, function2))
|
||||
{
|
||||
throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained.");
|
||||
}
|
||||
|
||||
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
if (!AreEqual(function1, function2))
|
||||
{
|
||||
|
@ -548,8 +548,8 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
|
|||
|
||||
for (int j = 0; j < 3; ++j)
|
||||
{
|
||||
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
double mbLoss1 = trainer1->PreviousMinibatchLossAverage();
|
||||
double mbLoss2 = trainer2->PreviousMinibatchLossAverage();
|
||||
|
@ -638,36 +638,36 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
|
|||
|
||||
const size_t minibatchSize = 50;
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
|
||||
auto actualMBSize = minibatchData[labelStreamInfo].numberOfSamples;
|
||||
|
||||
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
|
||||
auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
|
||||
auto trainer = CreateTrainer(classifierOutput, trainingLoss, prediction, { learner });
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
const wchar_t* modelFile = L"seq2seq.legacy.model";
|
||||
Internal::SaveAsLegacyModel(classifierOutput, modelFile);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto MB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
classifierOutput->RestoreModel(modelFile);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
auto postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
classifierOutput->RestoreModel(modelFile);
|
||||
Internal::SaveAsLegacyModel(classifierOutput, modelFile);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
classifierOutput->RestoreModel(modelFile);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
|
||||
|
||||
|
@ -684,7 +684,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
|
|||
{
|
||||
trainer->SaveCheckpoint(L"trainer.checkpoint" + std::to_wstring(i));
|
||||
Internal::SaveAsLegacyModel(classifierOutput, modelFile + std::to_wstring(i));
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
expectedLoss.push_back(trainer->PreviousMinibatchLossAverage());
|
||||
}
|
||||
|
||||
|
@ -692,7 +692,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
|
|||
{
|
||||
trainer->RestoreFromCheckpoint(L"trainer.checkpoint" + std::to_wstring(i));
|
||||
classifierOutput->RestoreModel(modelFile + std::to_wstring(i));
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
double loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
|
||||
}
|
||||
|
@ -767,20 +767,20 @@ void TestCheckpointingWithStatefulNodes(const DeviceDescriptor& device)
|
|||
auto featureStreamInfo = minibatchSource->StreamInfo(features);
|
||||
auto labelStreamInfo = minibatchSource->StreamInfo(labels);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
|
||||
vector<double> expectedLoss;
|
||||
for (int i = 0; i < epochSize / minibatchSize; i++)
|
||||
{
|
||||
trainer->SaveCheckpoint(L"stateful_nodes.model" + std::to_wstring(i));
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
expectedLoss.push_back(trainer->PreviousMinibatchLossAverage());
|
||||
}
|
||||
|
||||
for (int i = 0; i < epochSize / minibatchSize; i++)
|
||||
{
|
||||
trainer->RestoreFromCheckpoint(L"stateful_nodes.model" + std::to_wstring(i));
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
double loss = trainer->PreviousMinibatchLossAverage();
|
||||
FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
|
|||
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
|
||||
{
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
|
||||
|
||||
if ((i % trainingCheckpointFrequency) == (trainingCheckpointFrequency - 1))
|
||||
|
@ -128,7 +128,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
|
|||
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
|
||||
{
|
||||
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
|
||||
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,11 +125,11 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
|
|||
break;
|
||||
|
||||
// Make sure our truncation length setting was honored
|
||||
auto actualMaxSequenceLength = minibatchData[featureStreamInfo].m_data->Shape()[featureStreamInfo.m_sampleLayout.Rank()];
|
||||
auto actualMaxSequenceLength = minibatchData[featureStreamInfo].data->Shape()[featureStreamInfo.m_sampleLayout.Rank()];
|
||||
if (actualMaxSequenceLength != truncationLength)
|
||||
ReportFailure("Actual max sequence length (%d) in minibatch data does not equal specified truncation length (%d)", (int)actualMaxSequenceLength, (int)truncationLength);
|
||||
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
|
||||
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
|
||||
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,15 @@
|
|||
|
||||
%rename(momentum_as_time_constant_schedule) CNTK::MomentumAsTimeConstantSchedule;
|
||||
|
||||
// renaming overloads for TrainMinibatch and TestMinibatch that take a map
|
||||
// of Variables and MinibatchData as their first parameter. If this is not done,
|
||||
// the overloads that are legal in C++ will be shadowed and ignored by SWIG.
|
||||
// The naming here is somewhat cumbersome, but it's only intended for internal
|
||||
// consumption in proxy objects.
|
||||
%rename(train_minibatch_overload_for_minibatchdata) CNTK::Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>&, const DeviceDescriptor& = DeviceDescriptor::UseDefaultDevice());
|
||||
%rename(train_minibatch_overload_for_minibatchdata) CNTK::Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>&, std::unordered_map<Variable, ValuePtr>&, const DeviceDescriptor& = DeviceDescriptor::UseDefaultDevice());
|
||||
%rename(test_minibatch_overload_for_minibatchdata) CNTK::Trainer::TestMinibatch(const std::unordered_map<Variable, MinibatchData>&, const DeviceDescriptor& = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
%rename(l1_regularization_weight) CNTK::AdditionalLearningOptions::l1RegularizationWeight;
|
||||
%rename(l2_regularization_weight) CNTK::AdditionalLearningOptions::l2RegularizationWeight;
|
||||
%rename(ndcg_at_1) CNTK::NDCGAt1;
|
||||
|
@ -1087,6 +1096,7 @@ public:
|
|||
%unordered_map_conversion(CNTK::Parameter, const CNTK::NDArrayViewPtr, SWIGTYPE_p_CNTK__Parameter, SWIGTYPE_p_std__shared_ptrT_CNTK__NDArrayView_t)
|
||||
%unordered_map_conversion(CNTK::Parameter, CNTK::NDArrayViewPtr, SWIGTYPE_p_CNTK__Parameter, SWIGTYPE_p_std__shared_ptrT_CNTK__NDArrayView_t)
|
||||
%unordered_map_conversion(CNTK::Variable, CNTK::StreamInformation, SWIGTYPE_p_CNTK__Variable, SWIGTYPE_p_CNTK__StreamInformation)
|
||||
%unordered_map_conversion(CNTK::Variable, CNTK::MinibatchData, SWIGTYPE_p_CNTK__Variable, SWIGTYPE_p_CNTK__MinibatchData)
|
||||
|
||||
%unordered_map_ref_conversion(CNTK::StreamInformation, SWIGTYPE_p_CNTK__StreamInformation, CNTK::MinibatchData, SWIGTYPE_p_CNTK__MinibatchData);
|
||||
%unordered_map_ref_conversion(CNTK::Parameter, SWIGTYPE_p_CNTK__Parameter, CNTK::NDArrayViewPtr, SWIGTYPE_p_std__shared_ptrT_CNTK__NDArrayView);
|
||||
|
|
|
@ -27,28 +27,28 @@ class MinibatchData(cntk_py.MinibatchData, ArrayMixin):
|
|||
'''
|
||||
The number of sequences in this minibatch
|
||||
'''
|
||||
return self.m_num_sequences
|
||||
return self.number_of_sequences
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
'''
|
||||
The number of samples in this minibatch
|
||||
'''
|
||||
return self.m_num_samples
|
||||
return self.number_of_samples
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
'''
|
||||
The value of the minibatch as a NumPy array.
|
||||
'''
|
||||
return value_to_seq(self.m_data)
|
||||
return value_to_seq(self.data)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
'''
|
||||
The shape of the data in this minibatch as tuple.
|
||||
'''
|
||||
return self.m_data.shape().dimensions()
|
||||
return self.data.shape().dimensions()
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
|
@ -57,14 +57,23 @@ class MinibatchData(cntk_py.MinibatchData, ArrayMixin):
|
|||
sequence, `1` marks a sequence element as valid, and `0` marks it as
|
||||
invalid.
|
||||
'''
|
||||
return self.m_data.mask().to_ndarray()
|
||||
return self.data.mask().to_ndarray()
|
||||
|
||||
@property
|
||||
def end_of_sweep(self):
|
||||
'''
|
||||
Indicates whether the data in this minibatch is comes from a sweep end
|
||||
or crosses a sweep boundary (and as a result includes data from
|
||||
different sweeps).
|
||||
'''
|
||||
return self.sweep_end
|
||||
|
||||
@property
|
||||
def is_sparse(self):
|
||||
'''
|
||||
Whether the data in this minibatch is sparse.
|
||||
'''
|
||||
return self.m_data.is_sparse()
|
||||
return self.data.is_sparse()
|
||||
|
||||
def __len__(self):
|
||||
return self.num_sequences
|
||||
|
|
|
@ -9,15 +9,13 @@ import os
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from cntk.io import _is_tensor, sequence_to_cntk_text_format
|
||||
from cntk.io import *
|
||||
|
||||
abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
AA = np.asarray
|
||||
|
||||
def test_text_format(tmpdir):
|
||||
from cntk.io import CTFDeserializer, MinibatchSource, StreamDef, StreamDefs
|
||||
|
||||
mbdata = r'''0 |x 560:1 |y 1 0 0 0 0
|
||||
0 |x 0:1
|
||||
0 |x 0:1
|
||||
|
@ -48,6 +46,9 @@ def test_text_format(tmpdir):
|
|||
features = mb[features_si]
|
||||
# 2 samples, max seq len 4, 1000 dim
|
||||
assert features.shape == (2, 4, input_dim)
|
||||
assert features.end_of_sweep
|
||||
assert features.num_sequences == 2
|
||||
assert features.num_samples == 7
|
||||
assert features.is_sparse
|
||||
# TODO features is sparse and cannot be accessed right now:
|
||||
# *** RuntimeError: DataBuffer/WritableDataBuffer methods can only be called for NDArrayiew objects with dense storage format
|
||||
|
@ -58,6 +59,9 @@ def test_text_format(tmpdir):
|
|||
labels = mb[labels_si]
|
||||
# 2 samples, max seq len 1, 5 dim
|
||||
assert labels.shape == (2, 1, num_output_classes)
|
||||
assert labels.end_of_sweep
|
||||
assert labels.num_sequences == 2
|
||||
assert labels.num_samples == 2
|
||||
assert not labels.is_sparse
|
||||
|
||||
label_data = np.asarray(labels)
|
||||
|
@ -67,8 +71,16 @@ def test_text_format(tmpdir):
|
|||
[[ 0., 1., 0., 0., 0.]]
|
||||
]))
|
||||
|
||||
mb = mb_source.next_minibatch(1)
|
||||
features = mb[features_si]
|
||||
labels = mb[labels_si]
|
||||
|
||||
assert not features.end_of_sweep
|
||||
assert not labels.end_of_sweep
|
||||
assert features.num_samples < 7
|
||||
assert labels.num_samples == 1
|
||||
|
||||
def test_image():
|
||||
from cntk.io import ReaderConfig, ImageDeserializer
|
||||
map_file = "input.txt"
|
||||
mean_file = "mean.txt"
|
||||
epoch_size = 150
|
||||
|
@ -153,7 +165,7 @@ def test_image():
|
|||
assert set(sis.keys()) == { feature_name, label_name }
|
||||
'''
|
||||
|
||||
def test_minibatch(tmpdir):
|
||||
def test_full_sweep_minibatch(tmpdir):
|
||||
|
||||
mbdata = r'''0 |S0 0 |S1 0
|
||||
0 |S0 1 |S1 1
|
||||
|
@ -168,10 +180,10 @@ def test_minibatch(tmpdir):
|
|||
with open(tmpfile, 'w') as f:
|
||||
f.write(mbdata)
|
||||
|
||||
from cntk.io import CTFDeserializer, MinibatchSource, StreamDef, StreamDefs
|
||||
mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=1),
|
||||
labels = StreamDef(field='S1', shape=1))))
|
||||
labels = StreamDef(field='S1', shape=1))),
|
||||
randomize=False, epoch_size=FULL_DATA_SWEEP)
|
||||
|
||||
features_si = mb_source.stream_info('features')
|
||||
labels_si = mb_source.stream_info('labels')
|
||||
|
@ -181,6 +193,7 @@ def test_minibatch(tmpdir):
|
|||
assert mb[labels_si].num_sequences == 2
|
||||
|
||||
features = mb[features_si]
|
||||
assert features.end_of_sweep
|
||||
assert len(features.value) == 2
|
||||
expected_features = \
|
||||
[
|
||||
|
@ -196,6 +209,7 @@ def test_minibatch(tmpdir):
|
|||
[2, 1, 1, 0]])
|
||||
|
||||
labels = mb[labels_si]
|
||||
assert labels.end_of_sweep
|
||||
assert len(labels.value) == 2
|
||||
expected_labels = \
|
||||
[
|
||||
|
@ -209,6 +223,46 @@ def test_minibatch(tmpdir):
|
|||
[[2, 1, 1],
|
||||
[2, 1, 0]])
|
||||
|
||||
def test_large_minibatch(tmpdir):
|
||||
|
||||
mbdata = r'''0 |S0 0 |S1 0
|
||||
0 |S0 1 |S1 1
|
||||
0 |S0 2
|
||||
0 |S0 3 |S1 3
|
||||
0 |S0 4
|
||||
0 |S0 5 |S1 1
|
||||
0 |S0 6 |S1 2
|
||||
'''
|
||||
|
||||
tmpfile = str(tmpdir/'mbtest.txt')
|
||||
with open(tmpfile, 'w') as f:
|
||||
f.write(mbdata)
|
||||
|
||||
mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=1),
|
||||
labels = StreamDef(field='S1', shape=1))),
|
||||
randomize=False)
|
||||
|
||||
features_si = mb_source.stream_info('features')
|
||||
labels_si = mb_source.stream_info('labels')
|
||||
|
||||
mb = mb_source.next_minibatch(1000)
|
||||
features = mb[features_si]
|
||||
labels = mb[labels_si]
|
||||
|
||||
# Actually, the minibatch spans over multiple sweeps,
|
||||
# not sure if this is an artificial situation, but
|
||||
# maybe instead of a boolean flag we should indicate
|
||||
# the largest sweep index the data was taken from.
|
||||
assert features.end_of_sweep
|
||||
assert labels.end_of_sweep
|
||||
|
||||
assert features.num_samples == 1000 - 1000 % 7
|
||||
assert labels.num_samples == 5 * (1000 // 7)
|
||||
|
||||
assert mb[features_si].num_sequences == (1000 // 7)
|
||||
assert mb[labels_si].num_sequences == (1000 // 7)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("idx, alias_tensor_map, expected", [
|
||||
(0, {'A': [object()]}, ValueError),
|
||||
|
@ -250,4 +304,5 @@ def test_sequence_conversion_dense(idx, alias_tensor_map, expected):
|
|||
([AA([1, 2]), AA([])], False),
|
||||
])
|
||||
def test_is_tensor(data, expected):
|
||||
from cntk.io import _is_tensor
|
||||
assert _is_tensor(data) == expected
|
||||
|
|
|
@ -130,7 +130,7 @@ class Learner(cntk_py.Learner):
|
|||
return super(Learner, self).learning_rate()
|
||||
|
||||
@typemap
|
||||
def training_parameter_schedule(schedule, unit, epoch_size=1):
|
||||
def training_parameter_schedule(schedule, unit, epoch_size=None):
|
||||
'''
|
||||
Create a training parameter schedule containing either per-sample (default)
|
||||
or per-minibatch values.
|
||||
|
@ -160,8 +160,13 @@ def training_parameter_schedule(schedule, unit, epoch_size=1):
|
|||
unit (:class:`UnitType`): one of two
|
||||
* ``sample``: the returned schedule contains per-sample values
|
||||
* ``minibatch``: the returned schedule contains per-minibatch values.
|
||||
epoch_size (int): number of samples as a scheduling unit. Parameters in
|
||||
the schedule change their values every ``epoch_size`` samples.
|
||||
epoch_size (optional, int): number of samples as a scheduling unit.
|
||||
Parameters in the schedule change their values every ``epoch_size``
|
||||
samples. If no ``epoch_size`` is provided, this parameter is substituted
|
||||
by the size of the full data sweep, in which case the scheduling unit is
|
||||
the entire data sweep (as indicated by the MinibatchSource) and parameters
|
||||
change their values on the sweep-by-sweep basis specified by the
|
||||
``schedule``.
|
||||
|
||||
Returns:
|
||||
training parameter schedule
|
||||
|
@ -177,7 +182,7 @@ def training_parameter_schedule(schedule, unit, epoch_size=1):
|
|||
return schedule
|
||||
|
||||
if isinstance(schedule, (int, float)):
|
||||
if epoch_size != 1:
|
||||
if epoch_size is not None:
|
||||
raise ValueError('when providing the schedule as a number,'
|
||||
' epoch_size is ignored')
|
||||
if UnitType(unit) is UnitType.sample:
|
||||
|
@ -185,16 +190,18 @@ def training_parameter_schedule(schedule, unit, epoch_size=1):
|
|||
else:
|
||||
return cntk_py.training_parameter_per_minibatch_schedule(schedule)
|
||||
|
||||
args = [schedule] if epoch_size is None else [schedule, epoch_size]
|
||||
|
||||
if isinstance(schedule, list):
|
||||
if UnitType(unit) is UnitType.sample:
|
||||
return cntk_py.training_parameter_per_sample_schedule(schedule, epoch_size)
|
||||
return cntk_py.training_parameter_per_sample_schedule(*args)
|
||||
else:
|
||||
return cntk_py.training_parameter_per_minibatch_schedule(schedule, epoch_size)
|
||||
return cntk_py.training_parameter_per_minibatch_schedule(*args)
|
||||
|
||||
raise ValueError('schedule must be either a float or a list, not %s'%type(schedule))
|
||||
|
||||
@typemap
|
||||
def learning_rate_schedule(lr, unit, epoch_size=1):
|
||||
def learning_rate_schedule(lr, unit, epoch_size=None):
|
||||
'''
|
||||
Create a learning rate schedule (using the same semantics as
|
||||
:func:`training_parameter_schedule`).
|
||||
|
@ -216,7 +223,7 @@ def learning_rate_schedule(lr, unit, epoch_size=1):
|
|||
return training_parameter_schedule(lr, unit, epoch_size)
|
||||
|
||||
@typemap
|
||||
def momentum_schedule(momentum, epoch_size=1):
|
||||
def momentum_schedule(momentum, epoch_size=None):
|
||||
'''
|
||||
Create a per-minibatch momentum schedule (using the same semantics as
|
||||
:func:`training_parameter_schedule` with the `unit=UnitType.minibatch`).
|
||||
|
@ -253,7 +260,7 @@ def momentum_schedule(momentum, epoch_size=1):
|
|||
return training_parameter_schedule(momentum, UnitType.minibatch, epoch_size)
|
||||
|
||||
@typemap
|
||||
def momentum_as_time_constant_schedule(momentum, epoch_size=1):
|
||||
def momentum_as_time_constant_schedule(momentum, epoch_size=None):
|
||||
'''
|
||||
Create a momentum schedule in a minibatch-size agnostic way
|
||||
(using the same semantics as :func:`training_parameter_schedule`
|
||||
|
@ -288,9 +295,14 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
|
|||
return momentum
|
||||
|
||||
if isinstance(momentum, (int, float)):
|
||||
if epoch_size is not None:
|
||||
raise ValueError('when providing the schedule as a number,'
|
||||
' epoch_size is ignored')
|
||||
return cntk_py.momentum_as_time_constant_schedule(momentum)
|
||||
|
||||
if isinstance(momentum, list):
|
||||
return cntk_py.momentum_as_time_constant_schedule(momentum, epoch_size)
|
||||
args = [momentum] if epoch_size is None else [momentum, epoch_size]
|
||||
return cntk_py.momentum_as_time_constant_schedule(*args)
|
||||
|
||||
raise ValueError('momentum must be either a float or a list, not %s'%type(momentum))
|
||||
|
||||
|
|
|
@ -166,15 +166,16 @@ class ArrayMixin(object):
|
|||
@property
|
||||
def __array_interface__(self):
|
||||
try:
|
||||
# This first check is for a Value object. Trying with self.to_ndarray first would lead to
|
||||
# a infinite recursion, since Value has a to_ndarray method
|
||||
np_array = self.data().to_ndarray()
|
||||
# This checks for a MinibatchData object.
|
||||
np_array = self.value
|
||||
except AttributeError:
|
||||
try:
|
||||
np_array = self.to_ndarray()
|
||||
# This checks for a Value object. Trying with self.to_ndarray first would lead to
|
||||
# a infinite recursion, since Value has a to_ndarray method
|
||||
np_array = self.data().to_ndarray()
|
||||
except AttributeError:
|
||||
try:
|
||||
np_array = self.value
|
||||
np_array = self.to_ndarray()
|
||||
except AttributeError:
|
||||
# Ideally an exception would be raised here, but getattr would swallow it
|
||||
# so we return None
|
||||
|
|
|
@ -96,7 +96,7 @@ def test_learner_update():
|
|||
w = parameter(shape=(1,), init=w_init)
|
||||
res = i * w
|
||||
|
||||
learner = sgd(res.parameters, lr=learning_rate_schedule([0.1]*50 + [0.2]*50, UnitType.sample))
|
||||
learner = sgd(res.parameters, lr=learning_rate_schedule([0.1]*50 + [0.2]*50, UnitType.sample, 1))
|
||||
assert learner.learning_rate() == 0.1
|
||||
x = learner.update({w: np.asarray([[2.]], dtype=np.float32)}, 100)
|
||||
assert learner.learning_rate() == 0.2
|
||||
|
@ -110,3 +110,66 @@ def test_training_parameter_schedule():
|
|||
training_parameter_schedule(0.01, unit='not_supported')
|
||||
with pytest.raises(ValueError):
|
||||
training_parameter_schedule(0.01, unit=5)
|
||||
|
||||
def test_sweep_based_schedule(tmpdir, device_id):
|
||||
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
|
||||
from .. import cross_entropy_with_softmax, classification_error, plus, reduce_sum
|
||||
from ..trainer import Trainer
|
||||
|
||||
input_dim = 69
|
||||
|
||||
ctf_data = '''\
|
||||
0 |S0 3:1 |S1 3:1 |# <s>
|
||||
0 |S0 4:1 |# A |S1 32:1 |# ~AH
|
||||
0 |S0 5:1 |# B |S1 36:1 |# ~B
|
||||
0 |S0 4:1 |# A |S1 31:1 |# ~AE
|
||||
0 |S0 7:1 |# D |S1 38:1 |# ~D
|
||||
0 |S0 12:1 |# I |S1 47:1 |# ~IY
|
||||
0 |S0 1:1 |# </s> |S1 1:1 |# </s>
|
||||
2 |S0 60:1 |# <s> |S1 3:1 |# <s>
|
||||
2 |S0 61:1 |# A |S1 32:1 |# ~AH
|
||||
'''
|
||||
ctf_file = str(tmpdir/'2seqtest.txt')
|
||||
with open(ctf_file, 'w') as f:
|
||||
f.write(ctf_data)
|
||||
|
||||
mbs = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
|
||||
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
|
||||
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
|
||||
)), randomize=False)
|
||||
|
||||
in1 = input_variable(shape=(input_dim,))
|
||||
labels = input_variable(shape=(input_dim,))
|
||||
p = parameter(shape=(input_dim,), init=10)
|
||||
z = plus(in1, reduce_sum(p), name='z')
|
||||
ce = cross_entropy_with_softmax(z, labels)
|
||||
errs = classification_error(z, labels)
|
||||
|
||||
lr_per_sample = learning_rate_schedule([0.3, 0.2, 0.1, 0.0], UnitType.sample)
|
||||
learner = sgd(z.parameters, lr_per_sample)
|
||||
trainer = Trainer(z, ce, errs, [learner])
|
||||
|
||||
input_map = {
|
||||
in1 : mbs.streams.features,
|
||||
labels : mbs.streams.labels
|
||||
}
|
||||
|
||||
# fetch minibatch (first sequence)
|
||||
data = mbs.next_minibatch(1, input_map=input_map)
|
||||
trainer.train_minibatch(data)
|
||||
assert learner.learning_rate() == 0.3
|
||||
|
||||
# fetch minibatch (second sequence, sweep ends at this point)
|
||||
data = mbs.next_minibatch(1, input_map=input_map)
|
||||
trainer.train_minibatch(data)
|
||||
assert learner.learning_rate() == 0.2
|
||||
|
||||
# fetch minibatch (both sequences -- entire sweep in one go)
|
||||
data = mbs.next_minibatch(9, input_map=input_map)
|
||||
trainer.train_minibatch(data)
|
||||
assert learner.learning_rate() == 0.1
|
||||
|
||||
# fetch minibatch (multiple sweeps)
|
||||
data = mbs.next_minibatch(30, input_map=input_map)
|
||||
trainer.train_minibatch(data, [z.output])
|
||||
assert learner.learning_rate() == 0.0
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from . import cntk_py
|
||||
from .device import use_default_device
|
||||
from .utils import sanitize_var_map, sanitize_function, typemap, value_to_seq
|
||||
from .io import _py_dict_to_cntk_dict
|
||||
from .io import _py_dict_to_cntk_dict, MinibatchData
|
||||
|
||||
__doc__= '''\
|
||||
A trainer encapsulates the overall training process and employs one or more
|
||||
|
@ -78,18 +78,36 @@ class Trainer(cntk_py.Trainer):
|
|||
device = use_default_device()
|
||||
|
||||
if arguments:
|
||||
arguments = sanitize_var_map(self.model.arguments, arguments)
|
||||
arguments = sanitize_var_map(self.model.arguments, arguments,
|
||||
extract_values_from_minibatch_data = False)
|
||||
|
||||
contains_minibatch_data = False
|
||||
if (len(arguments) > 0):
|
||||
value = next(iter(arguments.values()))
|
||||
contains_minibatch_data = isinstance(value, MinibatchData)
|
||||
|
||||
if outputs:
|
||||
output_map = {v: None for v in outputs}
|
||||
updated = super(Trainer, self).train_minibatch(arguments,
|
||||
|
||||
if contains_minibatch_data:
|
||||
updated = super(Trainer, self).train_minibatch_overload_for_minibatchdata(
|
||||
arguments, output_map, device)
|
||||
else:
|
||||
updated = super(Trainer, self).train_minibatch(arguments,
|
||||
output_map, device)
|
||||
|
||||
for k,v in output_map.items():
|
||||
output_map[k] = value_to_seq(v)
|
||||
|
||||
return updated, output_map
|
||||
else:
|
||||
updated = super(Trainer, self).train_minibatch(arguments, device)
|
||||
|
||||
if contains_minibatch_data:
|
||||
updated = super(Trainer, self).train_minibatch_overload_for_minibatchdata(
|
||||
arguments, device)
|
||||
else:
|
||||
updated = super(Trainer, self).train_minibatch(arguments,
|
||||
device)
|
||||
|
||||
return updated
|
||||
|
||||
|
|
|
@ -315,7 +315,7 @@ def sanitize_function(arg):
|
|||
|
||||
|
||||
def sanitize_var_map(op_arguments, arguments, precision=None,
|
||||
device=None):
|
||||
device=None, extract_values_from_minibatch_data=True):
|
||||
'''
|
||||
Sanitizes a dictionary of `Variable` s to input data such that it can be
|
||||
handed off to the evaluation methods
|
||||
|
@ -361,6 +361,12 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
|
|||
one of 'float' 'float32, 'double', 'float64', or None
|
||||
device (:class:`~cntk.device.DeviceDescriptor`, default None): device
|
||||
this value should be put on
|
||||
extract_values_from_minibatch_data (`bool`, defaults to `True`): specifies
|
||||
if :class:`~cntk.io.MinibatchData` instances in the arguments map are
|
||||
converted to the underlying value (:class:`Value`) instances (default),
|
||||
or if they should remain intact, as they contain additional meta
|
||||
information required by the Trainer (specifically, by the
|
||||
:meth:`~cntk.Trainer.train_minibatch` method).
|
||||
|
||||
Returns:
|
||||
`dict` that maps variables to sanitized batches
|
||||
|
@ -436,9 +442,10 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
|
|||
'sequence begin markers' % (sample_sizes, len(seq_starts)))
|
||||
|
||||
|
||||
if isinstance(batch, MinibatchData):
|
||||
batch = batch.m_data
|
||||
elif not isinstance(batch, cntk_py.Value):
|
||||
if isinstance(batch, MinibatchData) and extract_values_from_minibatch_data:
|
||||
batch = batch.data
|
||||
|
||||
if not (isinstance(batch, MinibatchData) or isinstance(batch, cntk_py.Value)):
|
||||
batch = sanitize_batch(var, batch, seq_starts, device)
|
||||
|
||||
var_map[var] = batch
|
||||
|
|
|
@ -19,7 +19,7 @@ def test_rnn_error(device_id):
|
|||
error, loss = train_sequence_classifier()
|
||||
|
||||
expected_error = 0.333333
|
||||
expected_loss = 1.060453
|
||||
expected_loss = 1.12
|
||||
|
||||
assert np.allclose(error, expected_error, atol=TOLERANCE_ABSOLUTE)
|
||||
assert np.allclose(loss, expected_loss, atol=TOLERANCE_ABSOLUTE)
|
||||
|
|
Загрузка…
Ссылка в новой задаче