Integrate alrezni/v2_sweep5 into master

This commit is contained in:
Project Philly 2017-01-19 05:08:00 -08:00
Parent c6fa6d1ee1 85c6aa8308
Commit 020869ff17
48 changed files: 1159 additions and 387 deletions

@ -1 +1 @@
Subproject commit 8114febb0f491cd0fec4b60e389672cda8ababfb
Subproject commit 4b2396f36b8129d035a0166cd2d1a1e457404249

View File

@ -121,6 +121,7 @@ namespace CNTK
struct MinibatchInfo
{
bool atEndOfData;
bool atEndOfSweep;
size_t numberOfSamples;
NDArrayViewPtr trainingLossValue;
NDArrayViewPtr evalCriterionValue;
@ -3275,7 +3276,7 @@ namespace CNTK
///
/// A special value that can be used for the epochSize to indicate that the schedule is sweep-based.
///
static const size_t EntireSweep = 0;
static const size_t FullDataSweep = 0;
///
/// Create a schedule with a constant parameter value.
@ -3287,7 +3288,7 @@ namespace CNTK
/// schedule[0] is used for the first 'epochSize' samples, schedule[1] -- for the second,
/// and so on. The last value is then used repeatedly until the end of training.
///
CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, UnitType unit, size_t epochSize = 1);
CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, UnitType unit, size_t epochSize = FullDataSweep);
///
/// Create a schedule using the list of key-value pairs, where the key specifies
@ -3298,7 +3299,7 @@ namespace CNTK
/// the first 100 samples, then '0.1' is used for the second 200 samples,
/// after which the value is switched to '0.005'.
///
CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize = 1);
CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize = FullDataSweep);
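As a hedged illustration of the scheduling semantics described above and of the new FullDataSweep default, the sketch below uses the MomentumAsTimeConstantSchedule subclass declared further down in this header; the numeric values are made up:
// Illustrative values only. With an explicit epochSize of 100 samples, the time constant
// 1024 applies to the first 100 samples, 4096 to the next 200 samples, and 16384 thereafter.
CNTK::MomentumAsTimeConstantSchedule explicitEpochs({ { 1, 1024.0 }, { 2, 4096.0 }, { 1, 16384.0 } }, /*epochSize =*/ 100);
// Omitting epochSize now defaults to FullDataSweep, so the pair counts are measured in
// full passes over the training data rather than in samples.
CNTK::MomentumAsTimeConstantSchedule sweepBased({ { 1, 1024.0 }, { 2, 4096.0 }, { 1, 16384.0 } });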
///
/// Returns a value corresponding to the absolute sample (or sweep)
@ -3313,7 +3314,7 @@ namespace CNTK
///
UnitType Unit() const { return m_unit; }
bool IsSweepBased() const { return m_epochSize == EntireSweep; }
bool IsSweepBased() const { return m_epochSize == FullDataSweep; }
CNTK_API virtual ~TrainingParameterSchedule();
@ -3348,12 +3349,15 @@ namespace CNTK
TrainingParameterPerUnitSchedule(T value)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(value, U)
{ }
TrainingParameterPerUnitSchedule(const std::vector<T>& schedule, size_t epochSize = 1)
TrainingParameterPerUnitSchedule(const std::vector<T>& schedule,
size_t epochSize = TrainingParameterSchedule<T>::FullDataSweep)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
{ }
TrainingParameterPerUnitSchedule(const std::vector<std::pair<size_t, T>>& schedule, size_t epochSize = 1)
TrainingParameterPerUnitSchedule(const std::vector<std::pair<size_t, T>>& schedule,
size_t epochSize = TrainingParameterSchedule<T>::FullDataSweep)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
{ }
@ -3401,13 +3405,13 @@ namespace CNTK
ConvertToPerSampleValues();
}
MomentumAsTimeConstantSchedule(const std::vector<double>& schedule, size_t epochSize = 1)
MomentumAsTimeConstantSchedule(const std::vector<double>& schedule, size_t epochSize = FullDataSweep)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
{
ConvertToPerSampleValues();
}
MomentumAsTimeConstantSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = 1)
MomentumAsTimeConstantSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = FullDataSweep)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
{
ConvertToPerSampleValues();
@ -3424,9 +3428,10 @@ namespace CNTK
CNTK_API void ConvertToPerSampleValues();
};
///
/// A collection of additional options that affect parameter updates and
/// are applicable for all standard learners
///
struct AdditionalLearningOptions
{
double l1RegularizationWeight = 0.0;
@ -3452,7 +3457,7 @@ namespace CNTK
// Method to update the parameters associated with this learner. By returning false, this method indicates that
// learning has stopped for all of the parameters associated with this learner
//
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) = 0;
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd = false) = 0;
///
/// Returns the set of parameters associated with this learner.
@ -3610,9 +3615,9 @@ namespace CNTK
return m_communicator;
}
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t minibatchSampleCount) override
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t minibatchSampleCount, bool sweepEnd = false) override
{
MinibatchInfo info{ false, minibatchSampleCount };
MinibatchInfo info{ false, sweepEnd, minibatchSampleCount };
return Update(gradientValues, info);
}
@ -3726,6 +3731,11 @@ namespace CNTK
/// Optimize model parameters using the specified 'arguments' minibatch of training samples.
/// Returns false if all parameter learners indicate end of learning (through their Update method's return value).
///
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// An overload of the TrainMinibatch above that takes a map of variables and their values (as its first argument).
///
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
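A hedged sketch of a training loop wired up with the new MinibatchData-based overload; the trainer, the minibatch source (with its usual GetNextMinibatch accessor), the model variables, and the stream infos are all assumed to have been created elsewhere:
// Hedged sketch: 'trainer' (a Trainer), 'minibatchSource' (a MinibatchSourcePtr), 'features',
// 'labels', 'featureStreamInfo' and 'labelStreamInfo' are assumed to exist already.
auto device = CNTK::DeviceDescriptor::UseDefaultDevice();
for (;;)
{
    const auto& minibatch = minibatchSource->GetNextMinibatch(/*minibatchSizeInSamples =*/ 64, device);
    if (minibatch.empty())
        break;
    // Each MinibatchData value carries its own sweep-end flag, so the trainer can forward it
    // to the learners without extra bookkeeping by the caller.
    std::unordered_map<CNTK::Variable, CNTK::MinibatchData> arguments =
        { { features, minibatch.at(featureStreamInfo) }, { labels, minibatch.at(labelStreamInfo) } };
    if (!trainer.TrainMinibatch(arguments, device))
        break; // all parameter learners reported end of learning
}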
///
@ -3735,12 +3745,22 @@ namespace CNTK
/// for the 'outputs' for which the ValuePtr mapping was left null by the caller.
/// Returns false if all parameter learners indicate end of learning (through their Update method's return value).
///
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// An overload of the TrainMinibatch above that takes a map of variables and their values (as its first argument).
///
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// Test the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer
/// Returns the average evaluation criterion value per sample for the tested minibatch of samples
///
CNTK_API double TestMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// An overload of the TestMinibatch above that takes a map of variables and their values (as its first argument).
///
CNTK_API double TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
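Similarly, a hedged sketch of the evaluation path through the new MinibatchData-based TestMinibatch overload, under the same assumptions as the training sketch above plus a separate test source:
// Hedged sketch: 'testSource', 'features', 'labels', 'featureStreamInfo' and 'labelStreamInfo'
// are assumed to exist; the minibatch size is illustrative.
auto device = CNTK::DeviceDescriptor::UseDefaultDevice();
const auto& testMb = testSource->GetNextMinibatch(/*minibatchSizeInSamples =*/ 512, device);
if (!testMb.empty())
{
    std::unordered_map<CNTK::Variable, CNTK::MinibatchData> testArguments =
        { { features, testMb.at(featureStreamInfo) }, { labels, testMb.at(labelStreamInfo) } };
    double avgEvalError = trainer.TestMinibatch(testArguments, device);
}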
///
@ -3806,8 +3826,8 @@ namespace CNTK
const DeviceDescriptor& computeDevice,
std::unordered_map<Variable, ValuePtr>& parameterGradients);
bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice);
bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice);
bool TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
bool TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice);
void Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState);
@ -3854,11 +3874,34 @@ namespace std {
namespace CNTK
{
///
/// A struct that combines the minibatch meta-data with the actual minibatch data.
/// The former includes the number of sequences and samples in the minibatch,
/// as well as the sweep-end flag, which is set to true to indicate that the minibatch
/// concludes a data sweep (i.e., it is the last minibatch at the end of the sweep).
///
struct MinibatchData
{
size_t m_numSequences;
size_t m_numSamples;
ValuePtr m_data;
MinibatchData() : MinibatchData(nullptr)
{}
// A convenience constructor to allow passing ValuePtr arguments in place
// of the MinibatchData parameter (e.g., in Trainer::TrainMinibatch).
MinibatchData(ValuePtr value) : MinibatchData(value, 0)
{}
MinibatchData(ValuePtr value, size_t numSamples, bool sweepEnd = false)
: MinibatchData(value, numSamples, numSamples, sweepEnd)
{}
MinibatchData(ValuePtr value, size_t numSequences, size_t numSamples, bool sweepEnd)
: data(value), numberOfSequences(numSequences), numberOfSamples(numSamples), sweepEnd(sweepEnd)
{}
ValuePtr data;
size_t numberOfSequences;
size_t numberOfSamples;
bool sweepEnd;
};
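A hedged sketch of how the constructors above compose (the ValuePtr is assumed to come from, e.g., Value::CreateBatch):
// Hedged sketch; 'value' is an already-created ValuePtr.
CNTK::MinibatchData full(value, /*numSequences =*/ 32, /*numSamples =*/ 320, /*sweepEnd =*/ true);
// The two-argument form sets numberOfSequences == numberOfSamples and defaults sweepEnd to false.
CNTK::MinibatchData perSample(value, /*numSamples =*/ 320);
// The single-argument form (both counts default to 0) lets a plain ValuePtr be passed
// implicitly wherever a MinibatchData is expected, e.g., in Trainer::TrainMinibatch.
CNTK::MinibatchData fromValue = value;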
///

View File

@ -160,6 +160,9 @@ namespace CNTK
enum class PrimitiveOpType : unsigned int;
enum class DataType : unsigned int;
struct MinibatchInfo;
struct MinibatchData;
class Serializer;
// Similar to make_shared except that it associates a custom deleter with the shared_ptr to ensure

View File

@ -84,7 +84,7 @@ namespace CNTK
break;
for (auto& currentStreamKV : computedMeanAndInvStdDevs)
CompositeFunction::PopulateComputationNodeValue<float>({ streamToDummyInputVariableMap[currentStreamKV.first], minibatchData[currentStreamKV.first].m_data }, streamToInputNodeMap[currentStreamKV.first], layoutsPopulated);
CompositeFunction::PopulateComputationNodeValue<float>({ streamToDummyInputVariableMap[currentStreamKV.first], minibatchData[currentStreamKV.first].data }, streamToInputNodeMap[currentStreamKV.first], layoutsPopulated);
ComputationNetwork::BumpEvalTimeStamp(allInputNodes);

View File

@ -147,6 +147,6 @@ namespace CNTK
if (info.IsEmpty())
return false;
return m_learner->Update(gradientValues, info.numberOfSamples);
return m_learner->Update(gradientValues, info.numberOfSamples, info.atEndOfSweep);
}
}

View File

@ -225,7 +225,7 @@ namespace CNTK
}
}
/*virtual*/ bool LearnerBase::Update(unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) /*override*/
/*virtual*/ bool LearnerBase::Update(unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd) /*override*/
{
if (LearningRate(trainingSampleCount) == 0.0)
{
@ -233,7 +233,10 @@ namespace CNTK
}
// make sure trainingSampleCount is a valid value
assert(trainingSampleCount > 0);
if (trainingSampleCount == 0)
{
InvalidArgument("Learner::Update(): cannot perform an update with an empty minibatch.");
}
for (const auto& parameter : Parameters())
{
@ -273,7 +276,11 @@ namespace CNTK
}
m_sampleCount += trainingSampleCount;
m_minibatchCount++;
// TODO: sweep count also needs to be updated.
if (sweepEnd)
{
m_sweepCount++;
}
return true;
}

View File

@ -17,7 +17,7 @@ namespace CNTK
class LearnerBase : public Learner
{
public:
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) override final;
virtual bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd = false) override final;
virtual Dictionary CreateCheckpoint() override final;

View File

@ -77,7 +77,7 @@ namespace CNTK
CompositeMinibatchSource::CompositeMinibatchSource(const Dictionary& configuration)
: m_epochEndReached(false),
m_prevMinibatchSize(0),
m_epochSize(MinibatchSource::InfinitelyRepeat),
m_maxNumSamplesToRead(MinibatchSource::InfinitelyRepeat),
m_randomizedWindow(MinibatchSource::DefaultRandomizationWindow),
m_truncationLength(0),
m_numWorkers(1),
@ -136,13 +136,7 @@ namespace CNTK
const wchar_t* epochSizeConfigurationKey = L"epochSize";
if (augmentedConfiguration.Contains(epochSizeConfigurationKey))
m_epochSize = augmentedConfiguration[epochSizeConfigurationKey].Value<size_t>();
if (m_epochSize == MinibatchSource::FullDataSweep)
m_epochSize = Microsoft::MSR::CNTK::requestDataSize;
// Setting a big value, but not the max, in order to avoid bit overflow.
else if (m_epochSize == MinibatchSource::InfinitelyRepeat)
m_epochSize = std::numeric_limits<size_t>::max() / 2;
m_maxNumSamplesToRead = augmentedConfiguration[epochSizeConfigurationKey].Value<size_t>();
const wchar_t* randomizedWindowConfigurationKey = L"randomizationWindow";
if (augmentedConfiguration.Contains(randomizedWindowConfigurationKey))
@ -212,9 +206,24 @@ namespace CNTK
epochConfig.m_workerRank = workerRank;
epochConfig.m_minibatchSizeInSamples = minibatchSizeInSamples;
epochConfig.m_truncationSize = m_truncationLength;
epochConfig.m_allowMinibatchesToCrossSweepBoundaries = true;
if (m_maxNumSamplesToRead == MinibatchSource::FullDataSweep)
{
epochConfig.m_totalEpochSizeInSamples = Microsoft::MSR::CNTK::requestDataSize;
}
else if (m_maxNumSamplesToRead == MinibatchSource::InfinitelyRepeat)
{
// Setting a big value, but not the max, in order to avoid bit overflow.
epochConfig.m_totalEpochSizeInSamples = std::numeric_limits<size_t>::max() / 2;
}
else
{
epochConfig.m_totalEpochSizeInSamples = m_maxNumSamplesToRead;
}
epochConfig.m_totalEpochSizeInSamples = m_epochSize;
epochConfig.m_epochIndex = 0;
m_matrices.clear();
std::unordered_set<InputStreamDescription> inputs;
@ -257,6 +266,7 @@ namespace CNTK
newConfig.m_workerRank = workerRank;
newConfig.m_minibatchSizeInSamples = minibatchSizeInSamples;
newConfig.m_truncationSize = m_truncationLength;
newConfig.m_allowMinibatchesToCrossSweepBoundaries = true;
m_shim->SetConfiguration(newConfig, inputDescriptions);
@ -267,9 +277,12 @@ namespace CNTK
auto hasData = m_shim->GetMinibatch(m_matrices);
m_epochEndReached = m_shim->IsEndOfEpoch();
if (m_epochEndReached && !hasData)
return m_minibatchData;
bool hasReachedSweepEnd = m_shim->IsEndOfSweep();
for (const auto& s: m_streamInfos)
{
auto input = m_matrices.GetInput(s.m_name);
@ -293,7 +306,7 @@ namespace CNTK
size_t numSamples = input.pMBLayout->GetActualNumSamples();
size_t numSequences = input.pMBLayout->GetNumSequences();
m_minibatchData[currentStreamInfo] = { numSequences, numSamples, minibatchValuePtr };
m_minibatchData[currentStreamInfo] = { minibatchValuePtr, numSequences, numSamples, hasReachedSweepEnd };
}
else
LogicError("Input data of type other than DataType::Float is currently unsupported by the CNTK built-in composite MinibatchSource!");

View File

@ -50,7 +50,7 @@ namespace CNTK
size_t m_numWorkers;
size_t m_workerRank;
size_t m_prevMinibatchSize;
size_t m_epochSize;
size_t m_maxNumSamplesToRead;
size_t m_randomizedWindow;
size_t m_truncationLength;
std::unordered_map<StreamInformation, MinibatchData> m_minibatchData;

View File

@ -116,6 +116,29 @@ namespace CNTK
return (numSamplesInDataArrayView - numMaskedSamples);
}
static std::unordered_map<Variable, ValuePtr> GetInputs(const std::unordered_map<Variable, MinibatchData>& arguments)
{
std::unordered_map<Variable, ValuePtr> inputs(arguments.size());
for (const auto& kv : arguments)
{
inputs[kv.first] = kv.second.data;
}
return inputs;
}
static bool IsAtSweepEnd(const std::unordered_map<Variable, MinibatchData>& arguments)
{
return std::any_of(arguments.begin(), arguments.end(), [](const std::pair<const Variable, MinibatchData>& kv)
{
return kv.second.sweepEnd;
});
}
double Trainer::TestMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
return TestMinibatch(GetInputs(arguments), computeDevice);
}
double Trainer::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
if (!m_aggregatedEvaluationFunction)
@ -123,12 +146,26 @@ namespace CNTK
// TODO: Should we refactor this code that is somewhat similar to the prologue of the TrainMinibatch function
std::unordered_map<Variable, ValuePtr> outputs = { { m_aggregatedEvaluationFunction, nullptr }, { m_testSampleCountVar, nullptr } };
m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice);
auto sampleCount = GetSampleCount(m_testSampleCountVar, outputs[m_testSampleCountVar]);
return (GetScalarValue(outputs[m_aggregatedEvaluationFunction]) / sampleCount);
}
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
std::unordered_map<Variable, ValuePtr> outputsToFetch = {};
return TrainMinibatch(arguments, outputsToFetch, computeDevice);
}
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
if (!m_distributed)
return TrainLocalMinibatch(GetInputs(arguments), outputsToFetch, IsAtSweepEnd(arguments), computeDevice);
return TrainDistributedMinibatch(GetInputs(arguments), outputsToFetch, IsAtSweepEnd(arguments), computeDevice);
}
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
std::unordered_map<Variable, ValuePtr> outputsToFetch = {};
@ -138,11 +175,11 @@ namespace CNTK
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
if (!m_distributed)
return TrainLocalMinibatch(arguments, outputsToFetch, computeDevice);
return TrainDistributedMinibatch(arguments, outputsToFetch, computeDevice);
return TrainLocalMinibatch(arguments, outputsToFetch, false, computeDevice);
return TrainDistributedMinibatch(arguments, outputsToFetch, false, computeDevice);
}
bool Trainer::TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
bool Trainer::TrainLocalMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
bool emptyMinibatch = arguments.empty() || (arguments.begin()->second == nullptr);
if (emptyMinibatch) // Nothing to train with.
@ -154,10 +191,10 @@ namespace CNTK
std::unordered_map<Parameter, NDArrayViewPtr> gradients;
for (const auto& parameter : m_combinedTrainingFunction->Parameters())
gradients[parameter] = parameterGradients[parameter]->Data();
return m_parameterLearners->Update(gradients, m_prevMinibatchNumSamples);
return m_parameterLearners->Update(gradients, m_prevMinibatchNumSamples, sweepEnd);
}
bool Trainer::TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
bool Trainer::TrainDistributedMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, bool sweepEnd, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
std::unordered_map<Parameter, NDArrayViewPtr> gradients;
auto modelParameters = m_combinedTrainingFunction->Parameters();
@ -184,7 +221,7 @@ namespace CNTK
evalCriterion = m_prevMinibatchAggregateEvalCriterionValue->Data();
}
MinibatchInfo info { arguments.empty(), m_prevMinibatchNumSamples, trainingLoss, evalCriterion };
MinibatchInfo info{ arguments.empty(), sweepEnd, m_prevMinibatchNumSamples, trainingLoss, evalCriterion };
bool updated = m_parameterLearners->Update(gradients, info);
m_prevMinibatchNumSamples = info.numberOfSamples;

View File

@ -93,7 +93,7 @@ namespace CNTK
if (!minibatchData.empty())
{
for (auto v : m_modelInputToMinibatchSourceStream)
minibatch.insert({ v.first, minibatchData[v.second].m_data });
minibatch.insert({ v.first, minibatchData[v.second].data });
}
OnMinibatchStart();

View File

@ -241,7 +241,7 @@ namespace CNTK
template <typename T>
TrainingParameterSchedule<T>::TrainingParameterSchedule(T value, UnitType unit)
: m_schedule({ make_pair(0, value) }), m_unit(unit), m_epochSize(EntireSweep)
: m_schedule({ make_pair(0, value) }), m_unit(unit), m_epochSize(FullDataSweep)
{
}
@ -268,13 +268,9 @@ namespace CNTK
template <typename T>
void TrainingParameterSchedule<T>::ConstructSchedule(const std::vector<std::pair<size_t, T>>& schedule)
{
if (m_epochSize == EntireSweep)
{
//Sweep based schedules are currently not functional (learners don't have sweep info).
NOT_IMPLEMENTED;
}
const auto epochSize = (m_epochSize == EntireSweep) ? 1 : m_epochSize;
// In case of the FullDataSweep, the scheduling unit is just 1 sweep,
// otherwise, it's the epoch size in samples.
const auto unitSize = (m_epochSize == FullDataSweep) ? 1 : m_epochSize;
if (schedule.size() == 0)
RuntimeError("TrainingParameterSchedule::ConstructSchedule : schedule is empty.");
@ -288,7 +284,7 @@ namespace CNTK
RuntimeError("TrainingParameterSchedule::ConstructSchedule : unit count in the 'schedule' argument cannot be 0.");
unitCount += (pair.first != 0) ? pair.first : 1;
m_schedule[epochSize * unitCount] = pair.second;
m_schedule[unitSize * unitCount] = pair.second;
}
}
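For clarity, a hedged walk-through of the key computation above, with made-up inputs:
// Hedged worked example (made-up inputs):
//   schedule = { {2, 0.005}, {1, 0.001} } and m_epochSize = 100  =>  unitSize = 100
//     pair (2, 0.005): unitCount = 2  ->  m_schedule[200] = 0.005   (first 200 samples)
//     pair (1, 0.001): unitCount = 3  ->  m_schedule[300] = 0.001   (samples 200 and beyond)
//   With m_epochSize == FullDataSweep, unitSize = 1 and the same keys become 2 and 3,
//   i.e., the schedule boundaries are counted in sweeps rather than in samples.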
@ -880,14 +876,14 @@ namespace CNTK
}
}
bool Learners::Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t sampleInMinibatch)
bool Learners::Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t sampleInMinibatch, bool sweepEnd)
{
bool anyUpdatesPerformed = false;
for (auto learner : m_learners)
{
std::unordered_map<Parameter, NDArrayViewPtr> learnerGradients;
GetLearnerGradients(learner, gradientValues, learnerGradients);
anyUpdatesPerformed |= learner->Update(learnerGradients, sampleInMinibatch);
anyUpdatesPerformed |= learner->Update(learnerGradients, sampleInMinibatch, sweepEnd);
}
return anyUpdatesPerformed;
}

View File

@ -501,7 +501,7 @@ namespace CNTK
public:
explicit Learners(const std::vector<LearnerPtr>& learners);
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount);
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd);
bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, MinibatchInfo& minibatchInfo);
std::vector<DictionaryValue> CreateCheckpoint();

View File

@ -29,7 +29,7 @@ BlockRandomizer::BlockRandomizer(
m_epochSize(SIZE_MAX),
m_globalSamplePosition(SIZE_MAX),
m_epochStartPosition(0),
m_sweepTotalNumberOfSamples(0),
m_sweepSizeInSamples(0),
m_chunkRandomizer(std::make_shared<ChunkRandomizer>(deserializer, randomizationRangeInSamples)),
m_multithreadedGetNextSequences(multithreadedGetNextSequence),
m_prefetchedChunk(CHUNKID_MAX),
@ -43,10 +43,10 @@ BlockRandomizer::BlockRandomizer(
m_sequenceRandomizer = std::make_shared<SequenceRandomizer>(verbosity, m_deserializer, m_chunkRandomizer);
// Calculate total number of samples.
m_sweepTotalNumberOfSamples = 0;
m_sweepSizeInSamples = 0;
for (auto const & chunk : m_deserializer->GetChunkDescriptions())
{
m_sweepTotalNumberOfSamples += chunk->m_numberOfSamples;
m_sweepSizeInSamples += chunk->m_numberOfSamples;
}
}
@ -63,7 +63,7 @@ void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
m_config = config;
if (config.m_totalEpochSizeInSamples == requestDataSize)
{
m_epochSize = m_sweepTotalNumberOfSamples;
m_epochSize = m_sweepSizeInSamples;
}
else
{
@ -92,7 +92,7 @@ void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
// Prepares a new sweep if needed.
void BlockRandomizer::PrepareNewSweepIfNeeded(size_t samplePosition)
{
size_t sweep = samplePosition / m_sweepTotalNumberOfSamples;
size_t sweep = samplePosition / m_sweepSizeInSamples;
if (m_sweep != sweep)
{
if (m_verbosity >= Notification)
@ -115,32 +115,94 @@ Sequences BlockRandomizer::GetNextSequences(size_t globalSampleCount, size_t loc
{
// Get next sequence descriptions.
Sequences result;
ClosedOpenChunkInterval windowRange;
m_sequenceBuffer.clear();
result.m_endOfEpoch = GetNextSequenceDescriptions(globalSampleCount, localSampleCount, m_sequenceBuffer, windowRange);
if (m_sequenceBuffer.size() == 0)
size_t numGlobalSamplesLoaded = 0, numLocalSamplesLoaded = 0;
do
{
return result;
assert(globalSampleCount > numGlobalSamplesLoaded && localSampleCount > numLocalSamplesLoaded);
bool atTheSweepBoundary = result.m_endOfSweep;
// In case we continue filling up a minibatch that crosses a sweep boundary,
// make sure that it does not exceed the required number of samples: set the
// atLeastOneSequenceNeeded flag to false.
size_t numGlobalSamples = 0, numLocalSamples = 0;
std::tie(numGlobalSamples, numLocalSamples) =
LoadSequenceData(globalSampleCount - numGlobalSamplesLoaded,
localSampleCount - numLocalSamplesLoaded,
result, !atTheSweepBoundary);
if (atTheSweepBoundary && numGlobalSamples == 0)
{
break;
}
numGlobalSamplesLoaded += numGlobalSamples;
numLocalSamplesLoaded += numLocalSamples;
} while (m_config.m_allowMinibatchesToCrossSweepBoundaries &&
!result.m_endOfEpoch &&
result.m_endOfSweep &&
globalSampleCount > numGlobalSamplesLoaded &&
localSampleCount > numLocalSamplesLoaded);
m_cleaner.Clean(result);
return result;
}
std::pair<size_t, size_t> BlockRandomizer::LoadSequenceData(size_t globalSampleCount, size_t localSampleCount, Sequences& sequences, bool atLeastOneSequenceNeeded)
{
ClosedOpenChunkInterval windowRange;
m_sequenceBuffer.clear();
size_t numGlobalSamples = 0, numLocalSamples = 0; // actual number of samples to load (filled in from the sequence descriptions)
bool endOfSweep, endOfEpoch;
std::tie(endOfSweep, endOfEpoch, numGlobalSamples, numLocalSamples) = GetNextSequenceDescriptions(globalSampleCount, localSampleCount, m_sequenceBuffer, windowRange, atLeastOneSequenceNeeded);
sequences.m_endOfSweep |= endOfSweep;
sequences.m_endOfEpoch |= endOfEpoch;
assert(atLeastOneSequenceNeeded || (numGlobalSamples <= globalSampleCount && numLocalSamples <= localSampleCount));
if (numGlobalSamples == 0)
{
assert(!atLeastOneSequenceNeeded || sequences.m_endOfEpoch);
return {0, 0};
}
// Retrieve new data chunks if required.
LoadDataChunks(windowRange);
result.m_data.resize(m_streams.size(), std::vector<SequenceDataPtr>(m_sequenceBuffer.size()));
auto& data = sequences.m_data;
size_t offset = 0;
if (data.empty())
{
data.resize(m_streams.size(), std::vector<SequenceDataPtr>(m_sequenceBuffer.size()));
}
else
{
// Sequence data is not empty; we're appending new items to the existing
// sequence data vectors.
offset = data.front().size();
for (auto& sequenceDataVector : data)
{
// make sure that all streams contain the same number of sequences
assert(sequenceDataVector.size() == offset);
sequenceDataVector.resize(offset + m_sequenceBuffer.size());
}
}
auto process = [&](int i) -> void {
const auto& description = m_sequenceBuffer[i];
std::vector<SequenceDataPtr> sequence;
std::vector<SequenceDataPtr> sequenceData;
auto it = m_chunks.find(description.m_chunk->m_original->m_id);
if (it == m_chunks.end())
{
LogicError("Invalid chunk requested.");
}
it->second->GetSequence(description.m_id, sequence);
it->second->GetSequence(description.m_id, sequenceData);
for (int j = 0; j < m_streams.size(); ++j)
{
result.m_data[j][i] = sequence[j];
assert(offset + i < data[j].size());
data[j][offset + i] = sequenceData[j];
}
};
@ -158,18 +220,16 @@ Sequences BlockRandomizer::GetNextSequences(size_t globalSampleCount, size_t loc
process(i);
}
m_cleaner.Clean(result);
// Now it is safe to start the new chunk prefetch.
ChunkIdType chunkToPrefetchNext = GetChunkToPrefetch(windowRange);
Prefetch(chunkToPrefetchNext);
return result;
return { numGlobalSamples, numLocalSamples };
}
// Get next sequence descriptions for that worker that do not exceed global and local sample count.
// Returns true if epoch end is reached.
bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange)
std::tuple<bool, bool, size_t, size_t> BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange, bool atLeastOneSequenceNeeded)
{
if (globalSampleCount == 0)
LogicError("Global sample count must not be zero.");
@ -179,17 +239,22 @@ bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size
PrepareNewSweepIfNeeded(m_globalSamplePosition);
auto sweepPosition = m_globalSamplePosition % m_sweepSizeInSamples;
auto epochEndPosition = m_epochSize + m_epochStartPosition;
// Check epoch end.
if (m_globalSamplePosition >= m_epochSize + m_epochStartPosition)
if (m_globalSamplePosition >= epochEndPosition)
{
return true;
auto reachedEndOfEpoch = true;
auto reachedEndOfSweep = (m_globalSamplePosition >= m_sweepSizeInSamples) && (sweepPosition == 0);
return std::make_tuple(reachedEndOfSweep, reachedEndOfEpoch, 0, 0);
}
// Global sample count should not exceed the epoch.
globalSampleCount = std::min(globalSampleCount, m_epochSize + m_epochStartPosition - m_globalSamplePosition);
globalSampleCount = std::min(globalSampleCount, epochEndPosition - m_globalSamplePosition);
// Global sample count should also not exceed the sweep.
globalSampleCount = std::min(globalSampleCount, (long)m_sweepTotalNumberOfSamples - m_globalSamplePosition % m_sweepTotalNumberOfSamples);
globalSampleCount = std::min(globalSampleCount, m_sweepSizeInSamples - sweepPosition);
if (globalSampleCount == 0)
LogicError("Global sample count must not result in zero.");
@ -197,12 +262,17 @@ bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size
std::function<bool(const RandomizedSequenceDescription*)> isLocalSequence =
[this](const RandomizedSequenceDescription* s) { return s->m_chunk->m_chunkId % m_config.m_numberOfWorkers == m_config.m_workerRank; };
size_t actualNumberOfGlobalSamples = m_sequenceRandomizer->GetNextSequenceDescriptions(
size_t actualNumberOfGlobalSamples = 0, actualNumberOfLocalSamples = 0;
std::tie(actualNumberOfGlobalSamples, actualNumberOfLocalSamples) = m_sequenceRandomizer->GetNextSequenceDescriptions(
globalSampleCount,
localSampleCount,
isLocalSequence,
windowRange,
result);
result,
atLeastOneSequenceNeeded);
if (actualNumberOfLocalSamples > actualNumberOfGlobalSamples)
LogicError("Local sample count cannot be greater than the global sample count.");
if (m_verbosity >= Debug)
fprintf(stderr, "BlockRandomizer::GetNextSequenceDescriptions(): getting %" PRIu64 " sequences for %" PRIu64 "/%" PRIu64 " requested local/global samples in sweep %" PRIu64 "\n",
@ -211,10 +281,15 @@ bool BlockRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size
globalSampleCount,
m_sweep);
// set "reachedEndOfSweep" to true if the minibatch is last in a sweep
auto reachedEndOfSweep = (sweepPosition + actualNumberOfGlobalSamples >= m_sweepSizeInSamples);
// set "reachedEndOfEpoch" to true if the current batch is last in an epoch.
auto reachedEndOfEpoch = (m_globalSamplePosition + actualNumberOfGlobalSamples >= epochEndPosition);
// Update the global sample position.
m_globalSamplePosition += actualNumberOfGlobalSamples;
// return true if the current batch is last in an epoch.
return m_globalSamplePosition >= m_epochSize + m_epochStartPosition;
return std::make_tuple(reachedEndOfSweep, reachedEndOfEpoch, actualNumberOfGlobalSamples, actualNumberOfLocalSamples);
}
// Retrieves chunk data based on the window information provided by SequenceRandomizer
@ -350,9 +425,9 @@ void BlockRandomizer::SetCurrentSamplePosition(size_t currentSamplePosition)
// Sets sequence cursor to the sequence that corresponds to the epoch start position.
// If last epoch ended in the middle of a sequence, the cursor is moved to the next sequence in the sweep.
size_t offsetInSweep = currentSamplePosition % m_sweepTotalNumberOfSamples;
size_t offsetInSweep = currentSamplePosition % m_sweepSizeInSamples;
size_t newOffset = m_sequenceRandomizer->Seek(offsetInSweep, m_sweep);
m_globalSamplePosition = m_sweep * m_sweepTotalNumberOfSamples + newOffset;
m_globalSamplePosition = m_sweep * m_sweepSizeInSamples + newOffset;
// Check if we have some data, if not set to the end of epoch.
if (m_config.m_workerRank >= m_chunkRandomizer->GetRandomizedChunks().size())

View File

@ -77,8 +77,21 @@ private:
// Load data for chunks if needed.
void LoadDataChunks(const ClosedOpenChunkInterval& windowRange);
// Get next sequence descriptions that do not exceed global and local sample count.
bool GetNextSequenceDescriptions(size_t globalSampleCount, size_t localSampleCount, std::vector<RandomizedSequenceDescription>& result, ClosedOpenChunkInterval& windowRange);
// Load actual sequence data up to the specified global/local sample count
// (or at least one sequence when atLeastOneSequenceNeeded is true).
// Returns the total number of global and local samples loaded.
std::pair<size_t, size_t> LoadSequenceData(size_t globalSampleCount, size_t localSampleCount, Sequences& sequence, bool atLeastOneSequenceNeeded);
// Gets the next sequence descriptions with the total number of samples not exceeding
// the sample count, when atLeastOneSequenceNeeded is false. Otherwise (when atLeastOneSequenceNeeded is true),
// returns at least one sequence description even when its length is greater than the required sample count.
// Returns a tuple containing "end of sweep", "end of epoch" flags and
// the total numbers of global and local samples to be processed.
std::tuple<bool, bool, size_t, size_t> GetNextSequenceDescriptions(size_t globalSampleCount,
size_t localSampleCount,
std::vector<RandomizedSequenceDescription>& result,
ClosedOpenChunkInterval& windowRange,
bool atLeastOneSequenceNeeded);
// Prepares a new sweep if needed.
void PrepareNewSweepIfNeeded(size_t samplePosition);
@ -105,7 +118,7 @@ private:
size_t m_sweep;
// Total number of samples in a sweep.
size_t m_sweepTotalNumberOfSamples;
size_t m_sweepSizeInSamples;
IDataDeserializerPtr m_deserializer;

View File

@ -17,7 +17,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_currentChunkPosition(CHUNKID_MAX),
m_globalSamplePosition(0),
m_globalSequencePosition(0),
m_totalNumberOfSamples(0),
m_sweepSizeInSamples(0),
m_currentSequencePositionInChunk(0),
m_multithreadedGetNextSequences(multithreadedGetNextSequences),
m_cleaner(maxNumberOfInvalidSequences)
@ -41,7 +41,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
RuntimeError("NoRandomizer: Expected input to contain samples, but the number of successfully read samples was 0.");
}
m_totalNumberOfSamples = sampleCount;
m_sweepSizeInSamples = sampleCount;
}
ChunkIdType NoRandomizer::GetChunkIndexOf(size_t samplePosition)
@ -55,7 +55,7 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
m_config = config;
if (m_config.m_totalEpochSizeInSamples == requestDataSize)
m_config.m_totalEpochSizeInSamples = m_totalNumberOfSamples;
m_config.m_totalEpochSizeInSamples = m_sweepSizeInSamples;
SetCurrentSamplePosition(m_config.m_totalEpochSizeInSamples * config.m_epochIndex);
}
@ -83,39 +83,38 @@ void NoRandomizer::GetNextSequenceDescriptions(size_t globalSampleCount, size_t
assert(globalSampleCount != 0);
assert(localSampleCount != 0);
if (globalSampleCount > std::numeric_limits<int>::max() ||
if (globalSampleCount > std::numeric_limits<int>::max() &&
localSampleCount > std::numeric_limits<int>::max())
RuntimeError("Global and local size of the minibatch cannot exceed max int.");
assert(m_sequenceWindow.size() != 0);
assert(m_chunkDescriptions[m_currentChunkPosition]->m_numberOfSequences > m_currentSequencePositionInChunk);
int localSamplesLeft = (int)localSampleCount;
int globalSamplesLeft = (int)globalSampleCount;
size_t numGlobalSamplesLoaded = 0, numLocalSamplesLoaded = 0;
result.reserve(localSampleCount);
result.clear();
while (globalSamplesLeft > 0 && localSamplesLeft > 0)
while (globalSampleCount > numGlobalSamplesLoaded && localSampleCount > numLocalSamplesLoaded)
{
const SequenceDescription& sequence = m_sequenceWindow[m_currentSequencePositionInChunk];
int sequenceLength = (int)sequence.m_numberOfSamples;
auto sequenceLength = sequence.m_numberOfSamples;
// Let's check whether we need to return this sequence or skip it.
bool isLocal = m_globalSequencePosition % m_config.m_numberOfWorkers == m_config.m_workerRank;
if (result.empty() ||
((localSamplesLeft >= sequenceLength) && (globalSamplesLeft >= sequenceLength)))
((localSampleCount - numLocalSamplesLoaded >= sequenceLength) && (globalSampleCount - numGlobalSamplesLoaded >= sequenceLength)))
{
if (isLocal) // Ok good to add it to the result.
{
result.push_back(sequence);
localSamplesLeft -= sequence.m_numberOfSamples;
numLocalSamplesLoaded += sequence.m_numberOfSamples;
}
}
else // otherwise there is no room, return what we have.
break;
globalSamplesLeft -= sequence.m_numberOfSamples;
numGlobalSamplesLoaded += sequence.m_numberOfSamples;
m_globalSamplePosition += sequence.m_numberOfSamples;
m_globalSequencePosition++;
@ -141,25 +140,35 @@ Sequences NoRandomizer::GetNextSequences(size_t globalSampleCount, size_t localS
if (m_globalSamplePosition >= endOfEpochPosition)
{
result.m_endOfEpoch = true;
result.m_endOfSweep = (m_globalSamplePosition >= m_sweepSizeInSamples) &&
(m_globalSamplePosition % m_sweepSizeInSamples == 0);
return result;
}
// Check we do not go over epoch.
globalSampleCount = std::min(globalSampleCount, endOfEpochPosition - m_globalSamplePosition);
// Check that we do not go over the sweep.
size_t sweepPosition = m_globalSamplePosition % m_totalNumberOfSamples;
globalSampleCount = std::min(globalSampleCount, m_totalNumberOfSamples - sweepPosition);
if (!m_config.m_allowMinibatchesToCrossSweepBoundaries)
{
// Cut down the required sample count if we're not allowed to go over the
// sweep boundary
size_t sweepPosition = m_globalSamplePosition % m_sweepSizeInSamples;
globalSampleCount = std::min(globalSampleCount, m_sweepSizeInSamples - sweepPosition);
}
if (globalSampleCount == 0)
LogicError("Global sample count must not result in zero.");
auto sweepIndex = m_globalSamplePosition / m_sweepSizeInSamples;
m_sequenceBuffer.clear();
GetNextSequenceDescriptions(globalSampleCount, localSampleCount, m_sequenceBuffer);
// m_globalSamplePosition is already shifted in GetNextSequenceDescriptions() by the current minibatch size.
// Set the end-of-epoch flag (true when the current batch is last in an epoch).
result.m_endOfEpoch = (m_globalSamplePosition >= endOfEpochPosition);
result.m_endOfSweep = sweepIndex != m_globalSamplePosition / m_sweepSizeInSamples;
if (m_sequenceBuffer.size() == 0)
{
return result;
@ -229,7 +238,7 @@ void NoRandomizer::SetCurrentSamplePosition(size_t samplePosition)
{
m_currentSequencePositionInChunk = 0;
m_globalSamplePosition = samplePosition;
size_t sweepSamplePosition = m_globalSamplePosition % m_totalNumberOfSamples;
size_t sweepSamplePosition = m_globalSamplePosition % m_sweepSizeInSamples;
ChunkIdType chunkIndex = GetChunkIndexOf(sweepSamplePosition);
if (chunkIndex != m_currentChunkPosition)

View File

@ -88,7 +88,7 @@ private:
size_t m_globalSequencePosition;
// Total number of samples in the sweep.
size_t m_totalNumberOfSamples;
size_t m_sweepSizeInSamples;
// Temp buffer to avoid allocations.
std::vector<SequenceDescription> m_sequenceBuffer;

View File

@ -32,6 +32,11 @@ struct ReaderConfiguration
size_t m_workerRank; // Rank of the Open MPI worker, worker rank has to be less than the number of workers
size_t m_minibatchSizeInSamples; // Maximum minibatch size for the epoch in samples
size_t m_truncationSize; // Truncation size in samples for truncated BPTT mode.
// This flag indicates whether the minibatches are allowed to overlap the boundary
// between sweeps (in which case, they can contain data from different sweeps) or
// whether they need to be trimmed at the sweep end.
bool m_allowMinibatchesToCrossSweepBoundaries{ false };
};
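A hedged sketch of how a reader client might fill in the new flag, mirroring the EpochConfiguration setup that CompositeMinibatchSource performs earlier in this commit (assuming EpochConfiguration extends ReaderConfiguration, as the randomizer code in this commit uses it; the surrounding values are illustrative):
// Illustrative values only, within the reader namespace.
EpochConfiguration config;
config.m_numberOfWorkers = 1;
config.m_workerRank = 0;
config.m_minibatchSizeInSamples = 256;
config.m_truncationSize = 0;                          // no truncated BPTT
config.m_totalEpochSizeInSamples = requestDataSize;   // i.e., exactly one full sweep
config.m_epochIndex = 0;
// Allow a minibatch to contain data from two adjacent sweeps instead of being
// trimmed at the sweep boundary.
config.m_allowMinibatchesToCrossSweepBoundaries = true;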
// TODO: Should be deprecated.
@ -87,6 +92,10 @@ typedef std::shared_ptr<StreamMinibatch> StreamMinibatchPtr;
// Represents a single minibatch, that contains information about all streams.
struct Minibatch
{
// Indicates that this minibatch is either adjacent to the data sweep boundary
// (-----<minibatch>|---) or crosses the boundary (-----<mini|batch>---).
bool m_endOfSweep;
// Indicates that the end of epoch has been reached.
// It is set to true for the last minibatch, there still
// can be data in m_data field even if this flag is set.
@ -95,11 +104,8 @@ struct Minibatch
// Minibatch data
std::vector<StreamMinibatchPtr> m_data;
Minibatch() : m_endOfEpoch(false)
{
}
Minibatch(bool endOfEpoch) : m_endOfEpoch(endOfEpoch)
Minibatch(bool endOfSweep = false, bool endOfEpoch = false)
: m_endOfSweep(endOfSweep), m_endOfEpoch(endOfEpoch)
{
}
};

View File

@ -28,6 +28,7 @@ ReaderShim<ElemType>::ReaderShim() :
m_dataTransferers(2, DataTransfererPtr()),
m_currentDataTransferIndex(0),
m_endOfEpoch(false),
m_endOfSweep(false),
m_currentSamplePosition(0),
m_reader(nullptr),
m_factory(nullptr)
@ -274,6 +275,7 @@ bool ReaderShim<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
m_currentSamplePosition = m_reader->GetCurrentSamplePosition();
m_endOfEpoch = result.m_isEndOfEpoch;
m_endOfSweep = result.m_isEndOfSweep;
if (m_endOfEpoch && !result.m_isDataAvailable)
{
// No data and end of epoch, simply return.
@ -357,7 +359,7 @@ typename ReaderShim<ElemType>::PrefetchResult ReaderShim<ElemType>::PrefetchMini
// If there is no data we can simply return.
if (minibatch.m_data.empty())
return PrefetchResult{ minibatch.m_endOfEpoch, false };
return PrefetchResult{ minibatch.m_endOfSweep, minibatch.m_endOfEpoch, false };
// Ok we have some data. Let's load it to GPU.
// But before we need to make sure that corresponding compute has already finished from the last iteration.
@ -380,7 +382,7 @@ typename ReaderShim<ElemType>::PrefetchResult ReaderShim<ElemType>::PrefetchMini
if (m_dataTransferers[currentDataTransferIndex])
m_dataTransferers[currentDataTransferIndex]->RecordCPUToGPUCopy();
return PrefetchResult{ minibatch.m_endOfEpoch, true };
return PrefetchResult{ minibatch.m_endOfSweep, minibatch.m_endOfEpoch, true };
}

View File

@ -100,9 +100,15 @@ public:
return m_endOfEpoch;
}
bool IsEndOfSweep() const
{
return m_endOfSweep;
}
private:
struct PrefetchResult
{
bool m_isEndOfSweep;
bool m_isEndOfEpoch;
bool m_isDataAvailable;
};
@ -113,6 +119,7 @@ private:
ReaderPtr m_reader;
ReaderFactory m_factory;
bool m_endOfEpoch;
bool m_endOfSweep;
size_t m_numParallelSequences;

View File

@ -20,8 +20,13 @@ struct Sequences
// Indices in the outer vector have to correspond to the stream ids returned from the GetStreamDescriptions().
std::vector<std::vector<SequenceDataPtr>> m_data;
// Indicates whether the returned data comes from a sweep end or
// crosses a sweep boundary (and as a result includes sequences
// from different sweeps).
bool m_endOfSweep{ false };
// Indicates whether the epoch ends with the data returned.
bool m_endOfEpoch = false;
bool m_endOfEpoch{ false };
};
class SequenceEnumerator;

View File

@ -41,7 +41,7 @@ Minibatch SequencePacker::ReadMinibatch()
auto sequences = m_sequenceEnumerator->GetNextSequences(m_globalMinibatchSizeInSamples, m_localMinibatchSizeInSamples);
const auto& batch = sequences.m_data;
Minibatch minibatch(sequences.m_endOfEpoch);
Minibatch minibatch(sequences.m_endOfSweep, sequences.m_endOfEpoch);
if (batch.empty())
return minibatch;

View File

@ -63,26 +63,28 @@ namespace Microsoft { namespace MSR { namespace CNTK {
RandomizeNextChunkIfNeeded();
}
// Gets the next randomized sequence descriptions not exceeding the global and local sample count.
// Whether the sequence is considered local is defined by the isLocalSequence predicate.
// Returns how many global samples have been read.
size_t SequenceRandomizer::GetNextSequenceDescriptions(
// Gets the next randomized sequence descriptions not exceeding the global and local sample count,
// when atLeastOneSequenceNeeded is false. Otherwise (when atLeastOneSequenceNeeded is true),
// returns at least one sequence description even when its length is greater than the required sample counts.
// Whether a sequence is considered local is defined by the isLocalSequence predicate.
// Returns a pair whose first element indicates the number of global samples read,
// and the second -- the number of local samples read (== the sum of the sample counts
// of all elements in the 'sequences' vector).
std::pair<size_t, size_t> SequenceRandomizer::GetNextSequenceDescriptions(
size_t globalSampleCount,
size_t localSampleCount,
const std::function<bool(const RandomizedSequenceDescription*)>& isLocalSequence,
ClosedOpenChunkInterval& requiredChunks,
std::vector<RandomizedSequenceDescription>& sequences)
std::vector<RandomizedSequenceDescription>& sequences,
bool atLeastOneSequenceNeeded)
{
assert(globalSampleCount != 0);
assert(localSampleCount != 0);
if (globalSampleCount > std::numeric_limits<int>::max() ||
if (globalSampleCount > std::numeric_limits<int>::max() &&
localSampleCount > std::numeric_limits<int>::max())
RuntimeError("Global and local size of the minibatch cannot exceed max int.");
int localSamplesLeft = (int)localSampleCount;
int globalSamplesLeft = (int)globalSampleCount;
// Initialize the range to the current chunk.
requiredChunks.m_begin = (ChunkIdType)std::min(m_currentChunkCursor, m_randomizedChunks.size() - 1);
requiredChunks.m_end = requiredChunks.m_begin + 1;
@ -90,9 +92,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
sequences.reserve(localSampleCount);
sequences.clear();
size_t totalSamplesRead = 0;
size_t globalSamplesRead = 0, localSamplesRead = 0;
while (m_currentChunkCursor < m_randomizedChunks.size() &&
globalSamplesLeft > 0 && localSamplesLeft > 0)
(localSamplesRead < localSampleCount && globalSamplesRead < globalSampleCount))
{
size_t sequenceOffsetInsideChunk = m_currentSequenceCursor - m_randomizedChunks[m_currentChunkCursor].m_sequencePositionStart;
const RandomizedSequenceDescription* sequence = &m_sequenceWindow[m_currentChunkCursor - m_chunkWindowBegin][sequenceOffsetInsideChunk];
@ -100,20 +102,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
bool isLocal = isLocalSequence(sequence);
// Let's check whether we need to return this sequence or skip it.
if (sequences.empty() ||
((localSamplesLeft >= sequenceLength) && (globalSamplesLeft >= sequenceLength)))
if ((sequences.empty() && atLeastOneSequenceNeeded) ||
((localSamplesRead + sequenceLength <= localSampleCount) && (globalSamplesRead + sequenceLength <= globalSampleCount)))
{
if (isLocal) // Ok good to add it to the result.
{
sequences.push_back(*sequence);
localSamplesLeft -= sequenceLength;
localSamplesRead += sequenceLength;
}
// Even when the next sequence is not local, somebody else will return it, so
// we need to invalidate the 'atLeastOneSequenceNeeded' flag.
atLeastOneSequenceNeeded = false;
}
else // otherwise there is no room, return what we have.
break;
totalSamplesRead += sequenceLength;
globalSamplesLeft -= sequenceLength;
globalSamplesRead += sequenceLength;
// Update the required chunk window.
requiredChunks.m_begin = std::min(m_randomizedChunks[m_currentChunkCursor].m_randomizationWindow.m_begin, requiredChunks.m_begin);
@ -130,7 +134,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
return totalSamplesRead;
return { globalSamplesRead, localSamplesRead };
}
// Move the chunk cursor to the next chunk, randomizing more sequences if necessary.

View File

@ -45,15 +45,20 @@ public:
// If the offset points into the middle of the last sequence, the end of the sweep is returned.
size_t Seek(size_t sweepSampleOffset, size_t sweep);
// Gets the next randomized sequence descriptions not exceeding the global and local sample count.
// Gets the next randomized sequence descriptions not exceeding the global and local sample count,
// when atLeastOneSequenceNeeded is false. Otherwise (when atLeastOneSequenceNeeded is true),
// returns at least one sequence description even when its length is greater than the required sample count.
// Whether a sequence is considered local is defined by the isLocalSequence predicate.
// The return value is how many global samples have been read.
size_t GetNextSequenceDescriptions(
// Returns a pair whose first element indicates the number of global samples read,
// and the second -- the number of local samples read (== the sum of the sample counts
// of all elements in the 'sequences' vector).
std::pair<size_t, size_t> GetNextSequenceDescriptions(
size_t globalSampleCount,
size_t localSampleCount,
const std::function<bool(const RandomizedSequenceDescription*)>& isLocalSequence,
ClosedOpenChunkInterval& requiredChunks,
std::vector<RandomizedSequenceDescription>& sequences);
std::vector<RandomizedSequenceDescription>& sequences,
bool atLeastOneSequenceNeeded = true);
private:
DISABLE_COPY_AND_MOVE(SequenceRandomizer);

View File

@ -7,7 +7,6 @@
#define _SCL_SECURE_NO_WARNINGS
#include <cmath>
#include <deque>
#include "TruncatedBpttPacker.h"
#include "ReaderUtil.h"
@ -38,9 +37,10 @@ public:
}
// Adds a new sequence to the end of the slot.
void PushSequence(SequenceDataPtr s)
void PushSequence(SequenceDataPtr s, bool endOfSweep)
{
m_sequences.push_back(s);
m_endOfSweepFlags.push_back(endOfSweep);
m_length += s->m_numberOfSamples;
}
@ -51,13 +51,16 @@ public:
}
// Pops the front sequence at the beginning of the slot.
void PopSequence()
bool PopSequence()
{
assert(!m_sequences.empty());
m_sampleCursor = 0;
m_sampleOffset = 0;
m_length -= m_sequences.front()->m_numberOfSamples;
m_sequences.pop_front();
bool endOfSweepFlag = m_endOfSweepFlags.front();
m_endOfSweepFlags.pop_front();
return endOfSweepFlag;
}
// Contains the current sample cursor in the first sequence(m_sequences.front()) of the slot.
@ -72,6 +75,10 @@ public:
private:
// Prepared sequences.
deque<SequenceDataPtr> m_sequences;
// For each 'in-flight' sequence we keep a flag that indicates whether
// the sequence data comes from the end of a sweep.
std::deque<bool> m_endOfSweepFlags;
// Contains the size of the slot in samples (accumulated over all m_sequences).
size_t m_length;
@ -155,7 +162,7 @@ void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, co
m_sequenceBufferPerStream.clear();
// Preparing the buffers.
// Preparing the buffers.
for (int j = 0; j < m_streamBuffers.size(); ++j)
for (int i = 0; i < m_outputStreamDescriptions.size(); ++i)
{
@ -166,25 +173,22 @@ void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, co
}
}
// Filling in the initial set of sequences
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
{
ReadSequencesToSlot(slotIndex);
}
FillOutAvailableSlots();
}
Minibatch TruncatedBPTTPacker::ReadMinibatch()
{
Minibatch result;
FillOutAvailableSlots();
// Currently we expect sequences of identical length across the different streams,
// so it is sufficient to check a single stream only.
if (m_sequenceBufferPerStream.front()->NothingToPack())
{
result.m_endOfEpoch = true;
return result;
{
return Minibatch(/*endOfSweep =*/ false, /*endOfEpoch =*/ true);
}
Minibatch result;
// Iterating over the streams/slots and packing them into the minibatch.
for (size_t streamIndex = 0; streamIndex < m_outputStreamDescriptions.size(); ++streamIndex)
{
@ -192,7 +196,7 @@ Minibatch TruncatedBPTTPacker::ReadMinibatch()
size_t sequenceId = 0;
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
{
PackSlot(streamIndex, slotIndex, sequenceId);
result.m_endOfSweep |= PackSlot(streamIndex, slotIndex, sequenceId);
}
StreamMinibatchPtr m = make_shared<StreamMinibatch>();
@ -203,17 +207,18 @@ Minibatch TruncatedBPTTPacker::ReadMinibatch()
m_currentBufferIndex = (m_currentBufferIndex + 1) % m_numberOfBuffers;
// Eagerly set the end of epoch flag if all the data have been packed.
result.m_endOfEpoch = m_sequenceBufferPerStream.front()->NothingToPack();
return result;
}
// Packs a slot of sequences into the minibatch.
void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId)
bool TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId)
{
bool containsEndOfSweepSequence = false;
auto& slot = m_sequenceBufferPerStream[streamIndex]->m_slots[slotIndex];
// Fill free space in the slot.
ReadSequencesToSlot(slotIndex);
// Let's see how much samples we need to read.
size_t numberOfSamples = min(m_config.m_truncationSize, slot.AvailableNumberOfSamples());
if (numberOfSamples == 0)
@ -223,7 +228,7 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
// Check that nothing is in the slot any more.
assert(slot.IsEmpty());
return;
return false;
}
size_t sampleSize = GetSampleSize(m_inputStreamDescriptions[streamIndex]);
@ -247,7 +252,7 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
if (slot.m_sampleCursor >= slot.FrontSequence()->m_numberOfSamples)
{
// Starting a new sequence. Have to reset current pointers and add it to the minibatch layout.
slot.PopSequence();
containsEndOfSweepSequence |= slot.PopSequence();
//Adding next sequence to the minibatch.
m_currentLayouts[streamIndex]->AddSequence(
@ -290,7 +295,7 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
// Cleaning up the last sequence we have just read if needed.
if (slot.m_sampleCursor >= slot.FrontSequence()->m_numberOfSamples)
{
slot.PopSequence();
containsEndOfSweepSequence |= slot.PopSequence();
}
// Adding the last gap if there is one.
@ -302,35 +307,59 @@ void TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t&
numberOfSamples,
m_config.m_truncationSize);
}
return containsEndOfSweepSequence;
}
void TruncatedBPTTPacker::FillOutAvailableSlots()
{
// Filling out any available spaces
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
{
ReadSequencesToSlot(slotIndex);
}
}
void TruncatedBPTTPacker::ReadSequencesToSlot(size_t slotIndex)
{
const auto& slot = m_sequenceBufferPerStream.front()->m_slots[slotIndex];
while (m_config.m_truncationSize >= slot.AvailableNumberOfSamples())
const auto& firstStreamSlot = m_sequenceBufferPerStream.front()->m_slots[slotIndex];
while (m_config.m_truncationSize >= firstStreamSlot.AvailableNumberOfSamples())
{
// We need a single sequence, potentially we can request (m_truncationSize - slot.AvailableNumberOfSamples())
// to be more efficient. In reality, the truncation size is usually less than the sequence size.
// BPTT always operates on a local timeline, so we do not limit the global minibatch count.
auto s = m_sequenceEnumerator->GetNextSequences(SIZE_MAX, 1);
const auto& sequences = m_sequenceEnumerator->GetNextSequences(SIZE_MAX, 1);
// assert that number of input streams == number of output streams --
// this does not have to be the case in general, but the current
// implementation makes this implicit assumption, so let's make it
// explicit instead until we can get rid of it altogether.
assert(sequences.m_endOfEpoch || sequences.m_data.size() == m_outputStreamDescriptions.size());
const auto& data = sequences.m_data;
// Adding sequence to the slot for all streams.
for (size_t i = 0; i < s.m_data.size(); ++i)
for (size_t streamIndex = 0; streamIndex < data.size(); ++streamIndex)
{
assert(s.m_data[i].size() == 1);
assert(data[streamIndex].size() == 1);
const auto& streamSequenceDataVector = data[streamIndex];
auto& slot = m_sequenceBufferPerStream[streamIndex]->m_slots[slotIndex];
// Check that all sequences are of the same length.
if (s.m_data.front().front()->m_numberOfSamples != s.m_data[i].front()->m_numberOfSamples)
if (data.front().front()->m_numberOfSamples != streamSequenceDataVector.front()->m_numberOfSamples)
{
RuntimeError("For BPTT sequences between different input stream should have the same length.");
}
slot.PushSequence(streamSequenceDataVector.front(), sequences.m_endOfSweep);
m_sequenceBufferPerStream[i]->m_slots[slotIndex].PushSequence(s.m_data[i].front());
assert(firstStreamSlot.AvailableNumberOfSamples() == slot.AvailableNumberOfSamples());
}
if (s.m_endOfEpoch)
if (sequences.m_endOfEpoch)
{
break;
return;
}
}
}
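The boolean that PackSlot now returns (accumulated from PopSequence above) lets the packer flag a minibatch as ending a sweep as soon as any slot consumed the sequence that was read last in the sweep. A minimal standalone sketch of that aggregation, with hypothetical names (SlotResult, MinibatchEndsSweep) standing in for the actual packer types:

// Standalone sketch: OR per-slot end-of-sweep flags into one minibatch-level flag,
// the same way the packer combines the PackSlot results across slots and streams.
#include <cassert>
#include <vector>

struct SlotResult
{
    bool containsEndOfSweepSequence; // what PackSlot would report for one slot
};

bool MinibatchEndsSweep(const std::vector<SlotResult>& packedSlots)
{
    bool endOfSweep = false;
    for (const auto& slot : packedSlots)
        endOfSweep |= slot.containsEndOfSweepSequence;
    return endOfSweep;
}

int main()
{
    assert(!MinibatchEndsSweep({ { false }, { false } }));
    assert(MinibatchEndsSweep({ { false }, { true }, { false } }));
    return 0;
}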

View File

@ -30,8 +30,13 @@ public:
virtual void SetConfiguration(const ReaderConfiguration& config, const std::vector<MemoryProviderPtr>& memoryProviders) override;
private:
// Iterates over all (m_parallelNumberOfSequences) slots,
// pulling in and filling out those slots with new sequence data,
// for which AvailableNumberOfSamples (= current size in samples) < m_truncationSize.
void FillOutAvailableSlots();
// Reads sequences to slot with the specified index.
// Number of slots = m_parallelNumberOfSequences
// Number of slots = m_parallelNumberOfSequences.
void ReadSequencesToSlot(size_t slotIndex);
// Packs a slot into the data buffer.
@ -39,7 +44,9 @@ private:
// For each new input, sequence id is reset to 0, and incremented each time
// a sequence is added to the layout. This allows layouts corresponding to different
// inputs to have consistent sequence ids.
void PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId);
// Returns a boolean indicating whether the packed data contains a sequence
// (i.e., a sequence tail) that was read last in a data sweep.
bool PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId);
virtual MBLayoutPtr CreateMBLayout(const StreamBatch& batch)
{

View File

@ -22,7 +22,7 @@ TestDataDir=$TEST_RUN_DIR/TestData
mkdir $TestDataDir
cp -R $DataSourceDir/MNIST/v0/Train-28x28_cntk_text.txt $TestDataDir || exit $?
cp -R $DataSourceDir/CIFAR/v0/cifar-10-batches-py $TestDataDir || exit $?
cp -R $TEST_DIR/../../Simple2d/Data/SimpleDataTrain_cntk_text.txt $TestDataDir || exit $?
cp -R $TEST_DIR/../../Simple2d/Data/SimpleDataT*_cntk_text.txt $TestDataDir || exit $?
cp -R $TEST_DIR/../../Text/SequenceClassification/Data/Train.ctf $TestDataDir || exit $?
cp -R $TEST_DIR/../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf $TestDataDir || exit $?
cp -R $TEST_DIR/../../../../Examples/Speech/AN4/Data/* $TestDataDir || exit $?

View File

@ -16,7 +16,7 @@ def test_cntk_103_mnist_feedforwardnetwork_noErrors(nb):
for output in cell['outputs'] if output.output_type == "error"]
assert errors == []
expectedEvalErrorByDeviceId = { -1: 1.90, 0: 1.85 }
expectedEvalErrorByDeviceId = { -1: 1.67, 0: 1.71 }
def test_cntk_103_mnist_feedforwardnetwork_evalCorrect(nb, device_id):
testCell = [cell for cell in nb.cells

View File

@ -31,6 +31,6 @@ def test_cntk_202_language_understanding_trainerror(nb):
pass
except KeyError:
pass
expectedMetrics = [2.8, 1.9, 2.2, 2.3]
expectedMetrics = [2.8, 1.9, 2.2, 2.0]
# TODO tighten tolerances
assert numpy.allclose(expectedMetrics, metrics, atol=0.2)

View File

@ -1266,8 +1266,8 @@ Legacy configuration is used for truncated BPTT mode, please adapt the config to
Set current path to: C:/repo/cntk_github6/CNTK/Tests/EndToEndTests/UnitTests/ReaderTests
Test module "ReaderTests" has passed with:
122 test cases out of 122 passed
22235579 assertions out of 22235579 passed
127 test cases out of 127 passed
22629927 assertions out of 22629927 passed
Test suite "ReaderTestSuite" has passed with:
92 test cases out of 92 passed
@ -1566,11 +1566,8 @@ Test module "ReaderTests" has passed with:
2 assertions out of 2 passed
Test suite "ReaderLibTests" has passed with:
18 test cases out of 18 passed
2265 assertions out of 2265 passed
Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with:
16 assertions out of 16 passed
23 test cases out of 23 passed
396613 assertions out of 396613 passed
Test case "ReaderLibTests/CheckSetCurrentCursorForRandomizers" has passed with:
12 assertions out of 12 passed
@ -1589,8 +1586,23 @@ Test module "ReaderTests" has passed with:
Test case "ReaderLibTests/BlockRandomizerInstantiate" has passed
Test case "ReaderLibTests/BlockRandomizerOneEpoch" has passed with:
68 assertions out of 68 passed
Test case "ReaderLibTests/TestRandomization_FirstEpoch" has passed with:
1476 assertions out of 1476 passed
Test case "ReaderLibTests/TestRandomization_SecondEpoch" has passed with:
1476 assertions out of 1476 passed
Test case "ReaderLibTests/TestRandomization_TwoSweeps" has passed with:
3372 assertions out of 3372 passed
Test case "ReaderLibTests/TestRandomization_TwoSweeps_WithSequences" has passed with:
196668 assertions out of 196668 passed
Test case "ReaderLibTests/TestRandomization_TwoSweeps_AllowToCrossSweepBoundary" has passed with:
3204 assertions out of 3204 passed
Test case "ReaderLibTests/TestRandomization_TwoSweeps_AllowToCrossSweepBoundary_WithSequences" has passed with:
189756 assertions out of 189756 passed
Test case "ReaderLibTests/BlockRandomizerOneEpochWithChunks1" has passed with:
68 assertions out of 68 passed
@ -1598,8 +1610,8 @@ Test module "ReaderTests" has passed with:
Test case "ReaderLibTests/BlockRandomizerOneEpochWithChunks2" has passed with:
128 assertions out of 128 passed
Test case "ReaderLibTests/BlockRandomizerChaosMonkey" has passed with:
1836 assertions out of 1836 passed
Test case "ReaderLibTests/RandomizerChaosMonkey" has passed with:
300 assertions out of 300 passed
Test case "ReaderLibTests/BlockRandomizerOneEpochLegacyRandomization" has passed with:
68 assertions out of 68 passed
@ -1607,6 +1619,9 @@ Test module "ReaderTests" has passed with:
Test case "ReaderLibTests/NoRandomizerOneEpoch" has passed with:
35 assertions out of 35 passed
Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with:
16 assertions out of 16 passed
Test case "ReaderLibTests/DefaultCorpusDescriptor" has passed with:
2 assertions out of 2 passed

View File

@ -2,10 +2,10 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include <numeric>
#include <random>
#include "NoRandomizer.h"
#include "DataDeserializer.h"
#include "BlockRandomizer.h"
@ -79,7 +79,7 @@ private:
vector<vector<float>> m_sequenceData;
public:
MockDeserializer(size_t numChunks, size_t numSequencesPerChunks, vector<float>& data, uint32_t sequenceLength = 1)
MockDeserializer(size_t numChunks, size_t numSequencesPerChunks, const vector<float>& data, uint32_t sequenceLength = 1)
: m_numChunks(numChunks),
m_numSequencesPerChunk(numSequencesPerChunks),
m_sampleLayout(make_shared<TensorShape>(1)),
@ -170,49 +170,7 @@ void BlockRandomizerInstantiateTest(bool prefetch)
{
vector<float> data;
auto mockDeserializer = make_shared<MockDeserializer>(0, 0, data);
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch);
}
BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers)
{
size_t chunkSizeInSamples = 10000;
size_t sweepNumberOfSamples = 500000;
uint32_t maxSequenceLength = 300;
size_t randomizationWindow = chunkSizeInSamples * 5;
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
auto blockRandomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto noRandomizer = make_shared<NoRandomizer>(deserializer, false);
auto test = [](SequenceEnumeratorPtr r, size_t epochSize)
{
auto firstEpoch = ReadFullEpoch(r, epochSize, 0);
auto firstCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(firstCursor, firstEpoch.size());
auto secondEpoch = ReadFullEpoch(r, epochSize, 1);
auto secondCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(secondCursor - firstCursor, secondEpoch.size());
auto thirdEpoch = ReadFullEpoch(r, epochSize, 2);
auto thirdCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(thirdCursor - secondCursor, thirdEpoch.size());
auto anotherSecondEpoch = ReadFullEpoch(r, epochSize, 1);
auto anotherSecondCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(anotherSecondCursor, secondCursor);
};
// Inside sweep
size_t epochSize = 50000;
test(blockRandomizer, epochSize);
test(noRandomizer, epochSize);
// Between sweeps
epochSize = (size_t)(sweepNumberOfSamples / 1.5);
test(blockRandomizer, epochSize);
test(noRandomizer, epochSize);
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch, false);
}
BOOST_AUTO_TEST_CASE(CheckSetCurrentCursorForRandomizers)
@ -223,11 +181,11 @@ BOOST_AUTO_TEST_CASE(CheckSetCurrentCursorForRandomizers)
size_t randomizationWindow = chunkSizeInSamples * 5;
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
auto expectedBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto expectedNo = make_shared<NoRandomizer>(deserializer);
auto expectedBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
auto expectedNo = make_shared<NoRandomizer>(deserializer, false);
auto underTestBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto unterTestNo = make_shared<NoRandomizer>(deserializer);
auto underTestBlock = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
auto unterTestNo = make_shared<NoRandomizer>(deserializer, false);
auto test = [](SequenceEnumeratorPtr expected, SequenceEnumeratorPtr underTest, size_t epochSize)
{
@ -292,7 +250,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToEarlierEpochBetweenSweeps)
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
// Let's randomize complete sweep, so that we have a baseline.
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
// Let's read all sequences from the first three sweeps in the randomized order.
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
@ -330,7 +288,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToEarlierEpochInTheSweep)
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
// Let's randomize complete sweep, so that we have a baseline.
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
// Let's read all sequences from the first three sweeps in the randomized order.
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
@ -361,7 +319,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToSameEpochInTheSweep)
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
// Let's randomize complete sweep, so that we have a baseline.
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
// Let's read all sequences from the first three sweeps in the randomized order.
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
@ -388,7 +346,7 @@ BOOST_AUTO_TEST_CASE(RandRollbackToSameEpochInBigRandomizationWindow)
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
// Let's randomize complete sweep, so that we have a baseline.
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true);
auto randomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
// Let's read all sequences from the first three sweeps in the randomized order.
auto firstSweep = ReadFullSweep(randomizer, 0, sweepNumberOfSamples);
@ -421,45 +379,221 @@ BOOST_AUTO_TEST_CASE(BlockRandomizerInstantiate)
BlockRandomizerInstantiateTest(true);
}
void BlockRandomizerOneEpochTest(bool prefetch)
void OneEpochRandomizationTest(SequenceEnumerator& randomizer, size_t sweepSize, const EpochConfiguration& epochConfig, const vector<float>& expectedOutput, size_t sequenceLength = 1)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
auto epochSize = epochConfig.m_totalEpochSizeInSamples;
auto mbSize = epochConfig.m_minibatchSizeInSamples;
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch);
BOOST_ASSERT(epochSize == expectedOutput.size());
EpochConfiguration epochConfiguration;
epochConfiguration.m_numberOfWorkers = 1;
epochConfiguration.m_workerRank = 0;
epochConfiguration.m_minibatchSizeInSamples = 0;
epochConfiguration.m_totalEpochSizeInSamples = data.size();
epochConfiguration.m_epochIndex = 0;
randomizer->StartEpoch(epochConfiguration);
randomizer.StartEpoch(epochConfig);
vector<float> expected { 6, 3, 1, 5, 9, 0, 4, 2, 7, 8 };
BOOST_CHECK_EQUAL(data.size(), expected.size());
vector<float> actual;
for (int i = 0; i < data.size() + 1; i++)
for (int totalSamplesRead = 0; totalSamplesRead < epochSize;)
{
Sequences sequences = randomizer->GetNextSequences(1, 1);
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1 - (i / data.size()));
if (i < data.size())
Sequences sequences = randomizer.GetNextSequences(mbSize, mbSize);
BOOST_ASSERT(sequences.m_data.size() == 1); // only one input stream
auto& stream = sequences.m_data[0];
auto numSampleRead = 0;
for (auto& sequence : stream)
{
auto& data2 = reinterpret_cast<DenseSequenceData&>(*sequences.m_data[0][0]);
BOOST_CHECK_EQUAL(data2.m_numberOfSamples, 1u);
actual.push_back(*((float*)data2.GetDataBuffer()));
auto numSamples = sequence->m_numberOfSamples;
numSampleRead += numSamples;
auto& data = reinterpret_cast<DenseSequenceData&>(*sequence);
actual.reserve(actual.size() + numSamples);
std::copy_n(((float*)data.GetDataBuffer()), numSamples, std::back_inserter(actual));
}
BOOST_CHECK_EQUAL(sequences.m_endOfEpoch, (data.size() <= i + 1));
auto expectedSize = std::min(epochSize - totalSamplesRead, mbSize);
if (!epochConfig.m_allowMinibatchesToCrossSweepBoundaries)
{
expectedSize = std::min(sweepSize - totalSamplesRead % sweepSize, expectedSize);
}
// At least one sequence is returned when mbSize < sequenceLength.
expectedSize = std::max(expectedSize, sequenceLength);
BOOST_REQUIRE(numSampleRead <= std::max(mbSize, sequenceLength));
if (sequenceLength == 1)
BOOST_REQUIRE(numSampleRead == expectedSize);
else
BOOST_REQUIRE(expectedSize - numSampleRead < sequenceLength);
BOOST_REQUIRE(sequences.m_endOfEpoch == (totalSamplesRead + numSampleRead == epochSize));
BOOST_REQUIRE(sequences.m_endOfSweep == (totalSamplesRead / sweepSize != (totalSamplesRead + numSampleRead) / sweepSize));
totalSamplesRead += numSampleRead;
}
BOOST_CHECK_EQUAL_COLLECTIONS(expected.begin(), expected.end(),
for (int i = 0; i < 3; i++)
{
auto numSamples = i + 1;
Sequences sequences = randomizer.GetNextSequences(numSamples, numSamples);
BOOST_REQUIRE(sequences.m_data.size() == 0);
BOOST_REQUIRE(sequences.m_endOfEpoch == true);
BOOST_REQUIRE(sequences.m_endOfSweep == (epochSize % sweepSize == 0));
}
BOOST_REQUIRE_EQUAL_COLLECTIONS(expectedOutput.begin(), expectedOutput.end(),
actual.begin(), actual.end());
}
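The expected-size computation above combines three constraints: the samples still left in the epoch, the distance to the next sweep boundary when minibatches are not allowed to cross sweeps, and the guarantee that at least one whole sequence is returned even if it is longer than the requested minibatch. A standalone sketch of that arithmetic, with assumed numbers not taken from the test data:

// Standalone sketch of the expected-size arithmetic checked by the test above.
#include <algorithm>
#include <cassert>
#include <cstddef>

size_t ExpectedMinibatchSize(size_t epochSize, size_t sweepSize, size_t mbSize,
                             size_t totalSamplesRead, size_t sequenceLength,
                             bool allowCrossSweep)
{
    size_t expected = std::min(epochSize - totalSamplesRead, mbSize);
    if (!allowCrossSweep)
        expected = std::min(sweepSize - totalSamplesRead % sweepSize, expected);
    // At least one sequence is returned even if mbSize < sequenceLength.
    return std::max(expected, sequenceLength);
}

int main()
{
    // 10-sample sweep, 20-sample epoch, 7-sample minibatches, single-sample sequences:
    // the first request yields 7 samples, the second is capped at 3 by the sweep
    // boundary unless minibatches may cross it.
    assert(ExpectedMinibatchSize(20, 10, 7, 0, 1, false) == 7);
    assert(ExpectedMinibatchSize(20, 10, 7, 7, 1, false) == 3);
    assert(ExpectedMinibatchSize(20, 10, 7, 7, 1, true) == 7);
    return 0;
}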
BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpoch)
void TestRandomization(EpochConfiguration& epochConfiguration, IDataDeserializerPtr deserializer, size_t sweepSize, const vector<float>& expectedRandomized, const vector<float>& expectedNotRandomized, size_t sequenceLength = 1)
{
BlockRandomizerOneEpochTest(false);
BlockRandomizerOneEpochTest(true);
BlockRandomizer randomizer1(0, SIZE_MAX, deserializer, /*prefetch =*/ false);
BlockRandomizer randomizer2(0, SIZE_MAX, deserializer, /*prefetch =*/ true);
NoRandomizer randomizer3(deserializer);
BlockRandomizer randomizer4(0, SIZE_MAX, deserializer, /*prefetch =*/ false, false, /*multithreadedGetNextSequences =*/ true);
BlockRandomizer randomizer5(0, SIZE_MAX, deserializer, /*prefetch =*/ true, false, /*multithreadedGetNextSequences =*/ true);
NoRandomizer randomizer6(deserializer, /*multithreadedGetNextSequences =*/ true);
epochConfiguration.m_numberOfWorkers = 1;
epochConfiguration.m_workerRank = 0;
epochConfiguration.m_totalEpochSizeInSamples = expectedRandomized.size();
for (int i = 1; i <= epochConfiguration.m_totalEpochSizeInSamples + 1; i++)
{
epochConfiguration.m_minibatchSizeInSamples = i;
OneEpochRandomizationTest(randomizer1, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
OneEpochRandomizationTest(randomizer2, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
OneEpochRandomizationTest(randomizer3, sweepSize, epochConfiguration, expectedNotRandomized, sequenceLength);
OneEpochRandomizationTest(randomizer4, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
OneEpochRandomizationTest(randomizer5, sweepSize, epochConfiguration, expectedRandomized, sequenceLength);
OneEpochRandomizationTest(randomizer6, sweepSize, epochConfiguration, expectedNotRandomized, sequenceLength);
}
}
BOOST_AUTO_TEST_CASE(TestRandomization_FirstEpoch)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8 };
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
EpochConfiguration epochConfiguration;
epochConfiguration.m_epochIndex = 0;
TestRandomization(epochConfiguration, mockDeserializer, data.size(), expected, data);
}
BOOST_AUTO_TEST_CASE(TestRandomization_SecondEpoch)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
vector<float> expected{ 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
EpochConfiguration epochConfiguration;
epochConfiguration.m_epochIndex = 1;
TestRandomization(epochConfiguration, mockDeserializer, data.size(), expected, data);
}
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
auto sweepSize = data.size();
data.reserve(2 * sweepSize);
std::copy_n(data.begin(), sweepSize, std::back_inserter(data));
EpochConfiguration epochConfiguration;
epochConfiguration.m_epochIndex = 0;
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expected, data);
}
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps_WithSequences)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
for (int seqLength = 2; seqLength <= 10; seqLength++)
{
vector<float> expectedRandomized;
vector<float> expectedNotRandomized;
for (auto f : expected) {
std::fill_n(back_inserter(expectedRandomized), seqLength, f);
}
for (int i = 0; i < 2 * data.size(); i++) {
std::fill_n(back_inserter(expectedNotRandomized), seqLength, data[i % data.size()]);
}
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data, seqLength);
auto sweepSize = data.size() * seqLength;
EpochConfiguration epochConfiguration;
epochConfiguration.m_epochIndex = 0;
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expectedRandomized, expectedNotRandomized, seqLength);
}
}
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps_AllowToCrossSweepBoundary)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
auto sweepSize = data.size();
data.reserve(2 * sweepSize);
std::copy_n(data.begin(), sweepSize, std::back_inserter(data));
EpochConfiguration epochConfiguration;
epochConfiguration.m_epochIndex = 0;
epochConfiguration.m_allowMinibatchesToCrossSweepBoundaries = true;
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expected, data);
}
BOOST_AUTO_TEST_CASE(TestRandomization_TwoSweeps_AllowToCrossSweepBoundary_WithSequences)
{
vector<float> data(10);
iota(data.begin(), data.end(), 0.0f);
vector<float> expected{ 6, 3, 1, 5, 9, 0, 4, 2, 7, 8, 3, 0, 8, 4, 7, 5, 2, 9, 1, 6 };
for (int seqLength = 2; seqLength <= 10; seqLength++)
{
vector<float> expectedRandomized;
vector<float> expectedNotRandomized;
for (auto f : expected) {
std::fill_n(back_inserter(expectedRandomized), seqLength, f);
}
for (int i = 0; i < 2 * data.size(); i++) {
std::fill_n(back_inserter(expectedNotRandomized), seqLength, data[i % data.size()]);
}
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data, seqLength);
auto sweepSize = data.size() * seqLength;
EpochConfiguration epochConfiguration;
epochConfiguration.m_epochIndex = 0;
epochConfiguration.m_allowMinibatchesToCrossSweepBoundaries = true;
TestRandomization(epochConfiguration, mockDeserializer, sweepSize, expectedRandomized, expectedNotRandomized, seqLength);
}
}
void BlockRandomizerOneEpochWithChunks1Test(bool prefetch)
@ -468,7 +602,7 @@ void BlockRandomizerOneEpochWithChunks1Test(bool prefetch)
iota(data.begin(), data.end(), 0.0f);
auto mockDeserializer = make_shared<MockDeserializer>(5, 2, data);
auto randomizer = make_shared<BlockRandomizer>(0, 4, mockDeserializer, prefetch);
auto randomizer = make_shared<BlockRandomizer>(0, 4, mockDeserializer, prefetch, false);
EpochConfiguration epochConfiguration;
epochConfiguration.m_numberOfWorkers = 1;
@ -510,7 +644,7 @@ void BlockRandomizerOneEpochWithChunks2Test(bool prefetch)
auto mockDeserializer = make_shared<MockDeserializer>(10, 2, data);
auto randomizer = make_shared<BlockRandomizer>(0, 18, mockDeserializer, prefetch);
auto randomizer = make_shared<BlockRandomizer>(0, 18, mockDeserializer, prefetch, false);
EpochConfiguration epochConfiguration;
epochConfiguration.m_numberOfWorkers = 1;
@ -548,65 +682,75 @@ BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochWithChunks2)
BlockRandomizerOneEpochWithChunks2Test(true);
}
void BlockRandomizerChaosMonkeyTest(bool prefetch)
void RandomizerChaosMonkeyTest(SequenceEnumerator& randomizer, size_t sweepSize, int seed)
{
const int sequenceLength = 3;
const int seed = 42;
const int numChunks = 100;
const int numSequencesPerChunk = 10;
const int windowSize = 18;
vector<float> data(numChunks * numSequencesPerChunk);
iota(data.begin(), data.end(), 0.0f);
std::mt19937 rng(seed);
boost::random::uniform_int_distribution<int> distr(1, 10);
auto mockDeserializer = make_shared<MockDeserializer>(numChunks, numSequencesPerChunk, data, sequenceLength);
auto randomizer = make_shared<BlockRandomizer>(0, windowSize, mockDeserializer, prefetch);
boost::random::uniform_int_distribution<int> distr(1, 100);
for (int t = 0; t < 100; t++)
{
EpochConfiguration epochConfiguration;
epochConfiguration.m_numberOfWorkers = distr(rng);
do
{
epochConfiguration.m_workerRank = distr(rng) - 1;
}
while (epochConfiguration.m_numberOfWorkers <= epochConfiguration.m_workerRank);
epochConfiguration.m_workerRank = distr(rng) % epochConfiguration.m_numberOfWorkers;
epochConfiguration.m_minibatchSizeInSamples = 0; // don't care
epochConfiguration.m_totalEpochSizeInSamples = data.size() / distr(rng);
epochConfiguration.m_totalEpochSizeInSamples = sweepSize * distr(rng) / distr(rng);
epochConfiguration.m_epochIndex = distr(rng);
randomizer->StartEpoch(epochConfiguration);
epochConfiguration.m_allowMinibatchesToCrossSweepBoundaries = (distr(rng) % 2 == 0);
randomizer.StartEpoch(epochConfiguration);
auto epochStart = epochConfiguration.m_epochIndex * epochConfiguration.m_totalEpochSizeInSamples;
auto epochEnd = epochStart + epochConfiguration.m_totalEpochSizeInSamples;
auto numSweeps = epochEnd / sweepSize - epochStart / sweepSize;
auto sweepCount = 0;
int samplesToGet = 0;
for (int i = 0; i < epochConfiguration.m_totalEpochSizeInSamples + 1; i += samplesToGet)
for (;;)
{
samplesToGet = distr(rng);
Sequences sequences = randomizer->GetNextSequences(samplesToGet, samplesToGet);
Sequences sequences = randomizer.GetNextSequences(samplesToGet, samplesToGet);
if (sequences.m_endOfSweep)
sweepCount++;
// In case end of epoch/decimation/single sequence -> skip the mbSize check.
if (sequences.m_endOfEpoch || sequences.m_data.empty() || sequences.m_data.front().size() < 2)
if (!(sequences.m_data.empty() || sequences.m_data.size() == 1))
{
continue;
// Check that we do not exceed the minibatch size.
size_t count = 0;
for (const auto& sequence : sequences.m_data.front())
{
count += sequence->m_numberOfSamples;
}
BOOST_REQUIRE_LE(count, samplesToGet);
}
// Check that we do not exceed the minibatch size.
size_t count = 0;
for (const auto& sequence : sequences.m_data.front())
{
count += sequence->m_numberOfSamples;
}
BOOST_CHECK_LE(count, samplesToGet);
if (sequences.m_endOfEpoch)
break;
}
BOOST_REQUIRE(sweepCount == numSweeps);
}
}
BOOST_AUTO_TEST_CASE(BlockRandomizerChaosMonkey)
BOOST_AUTO_TEST_CASE(RandomizerChaosMonkey)
{
BlockRandomizerChaosMonkeyTest(false);
BlockRandomizerChaosMonkeyTest(true);
const int sequenceLength = 3;
const int numChunks = 100;
const int numSequencesPerChunk = 10;
const int windowSize = 18;
vector<float> data(numChunks * numSequencesPerChunk);
iota(data.begin(), data.end(), 0.0f);
auto mockDeserializer = make_shared<MockDeserializer>(numChunks, numSequencesPerChunk, data, sequenceLength);
BlockRandomizer blockRandomizerNoPrefetch(0, windowSize, mockDeserializer, false, false);
BlockRandomizer blockRandomizerWithPrefetch(0, windowSize, mockDeserializer, true, false);
NoRandomizer norandomizer(mockDeserializer);
auto sweepSize = data.size() * sequenceLength;
RandomizerChaosMonkeyTest(blockRandomizerNoPrefetch, sweepSize, 42);
RandomizerChaosMonkeyTest(blockRandomizerWithPrefetch, sweepSize, 43);
RandomizerChaosMonkeyTest(norandomizer, sweepSize, 44);
}
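The sweep-count check in the chaos-monkey loop relies on integer division: an epoch covering samples [epochStart, epochEnd) crosses epochEnd / sweepSize - epochStart / sweepSize sweep boundaries, and the randomizer is expected to report m_endOfSweep exactly that many times. A standalone sketch with assumed numbers:

// Standalone sketch of the sweep-counting arithmetic used by the chaos-monkey test.
#include <cassert>
#include <cstddef>

int main()
{
    const size_t sweepSize = 3000;  // assumed: 100 chunks * 10 sequences * 3 samples each
    const size_t epochSize = 2000;  // assumed epoch size in samples
    const size_t epochIndex = 2;

    size_t epochStart = epochIndex * epochSize;                      // 4000
    size_t epochEnd = epochStart + epochSize;                        // 6000
    size_t numSweeps = epochEnd / sweepSize - epochStart / sweepSize;

    // Samples [4000, 6000) finish exactly one sweep (the one ending at 6000),
    // so m_endOfSweep should be observed exactly once during this epoch.
    assert(numSweeps == 1);
    return 0;
}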
void BlockRandomizerOneEpochLegacyRandomizationTest(bool prefetch)
@ -618,7 +762,8 @@ void BlockRandomizerOneEpochLegacyRandomizationTest(bool prefetch)
auto randomizer = make_shared<BlockRandomizer>(0,
SIZE_MAX,
mockDeserializer,
prefetch);
prefetch,
true);
EpochConfiguration epochConfiguration;
epochConfiguration.m_numberOfWorkers = 1;
@ -691,6 +836,48 @@ BOOST_AUTO_TEST_CASE(NoRandomizerOneEpoch)
actual.begin(), actual.end());
}
BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers)
{
size_t chunkSizeInSamples = 10000;
size_t sweepNumberOfSamples = 500000;
uint32_t maxSequenceLength = 300;
size_t randomizationWindow = chunkSizeInSamples * 5;
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
auto blockRandomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, false);
auto noRandomizer = make_shared<NoRandomizer>(deserializer, false);
auto test = [](SequenceEnumeratorPtr r, size_t epochSize)
{
auto firstEpoch = ReadFullEpoch(r, epochSize, 0);
auto firstCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(firstCursor, firstEpoch.size());
auto secondEpoch = ReadFullEpoch(r, epochSize, 1);
auto secondCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(secondCursor - firstCursor, secondEpoch.size());
auto thirdEpoch = ReadFullEpoch(r, epochSize, 2);
auto thirdCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(thirdCursor - secondCursor, thirdEpoch.size());
auto anotherSecondEpoch = ReadFullEpoch(r, epochSize, 1);
auto anotherSecondCursor = r->GetCurrentSamplePosition();
BOOST_CHECK_EQUAL(anotherSecondCursor, secondCursor);
};
// Inside sweep
size_t epochSize = 50000;
test(blockRandomizer, epochSize);
test(noRandomizer, epochSize);
// Between sweeps
epochSize = (size_t)(sweepNumberOfSamples / 1.5);
test(blockRandomizer, epochSize);
test(noRandomizer, epochSize);
}
BOOST_AUTO_TEST_CASE(DefaultCorpusDescriptor)
{
const int seed = 13;
@ -713,8 +900,8 @@ BOOST_AUTO_TEST_CASE(NumericCorpusDescriptor)
CorpusDescriptor corpus(true);
for (int i = 0; i < 10; ++i)
{
auto value = distr(rng);
BOOST_CHECK_EQUAL(value, corpus.KeyToId(std::to_string(value)));
auto value = distr(rng);
BOOST_CHECK_EQUAL(value, corpus.KeyToId(std::to_string(value)));
}
BOOST_CHECK_EXCEPTION(
corpus.KeyToId("not a number"),

View File

@ -166,7 +166,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer->TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo].m_data }, { labelsVar, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo] }, { labelsVar, minibatchData[labelStreamInfo] } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}
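This file (and the other test drivers further down) switches from passing the underlying Value objects (minibatchData[...].m_data) to passing MinibatchData directly, via the TrainMinibatch overload that takes a map of Variable to MinibatchData (see the SWIG renames for it later in the diff). A minimal sketch of a training loop written against that overload; it assumes the surrounding setup (trainer, input, labels, device, minibatchSize, numMinibatchesToTrain) from the existing tests, and the stream names and dimensions are illustrative:

// Sketch only: feeding MinibatchData straight into TrainMinibatch.
auto minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt",
    { { L"features", 2 }, { L"labels", 2 } });
auto featureStreamInfo = minibatchSource->StreamInfo(L"features");
auto labelStreamInfo = minibatchSource->StreamInfo(L"labels");

for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
    auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
    // MinibatchData is passed as-is; the trainer can use both the data and the
    // sweepEnd flag, which sweep-based learning rate schedules depend on.
    trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] },
                              { labels, minibatchData[labelStreamInfo] } }, device);
}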

View File

@ -106,10 +106,6 @@ void TestRMSPropLearner(size_t numParameters, size_t numMinibatches, const Devic
void TestTrainingParametersSchedule()
{
VerifyException([]() {
LearningRatePerMinibatchSchedule({ 3.0, 2.0, 1.0 }, LearningRateSchedule::EntireSweep);
}, "Was able to create not-yet-implemented sweep-based schedule.");
LearningRatePerSampleSchedule schedule1 = 0.5;
assert(schedule1.Unit() == LearningRateSchedule::UnitType::Sample);
assert(schedule1[0] == 0.5);
@ -229,11 +225,76 @@ void TestTrainingParametersSchedule()
}
void TestSweepBasedSchedule()
{
DeviceDescriptor device = DeviceDescriptor::CPUDevice();
auto schedule = LearningRatePerSampleSchedule({ 1, 2, 3, 4, 5 }, LearningRateSchedule::FullDataSweep);
auto learner1 = SGDLearner({}, schedule);
assert(1 == learner1->LearningRate());
for (auto i : {2, 3, 4, 5 })
{
std::unordered_map<Parameter, NDArrayViewPtr> gradients {};
learner1->Update(gradients, 1, true);
assert(i == learner1->LearningRate());
}
const size_t inputDim = 2;
const size_t numOutputClasses = 2;
auto minibatchSource = TextFormatMinibatchSource(L"SimpleDataTest_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
auto sweepSize = 603; // == wc -l SimpleDataTest_cntk_text.txt
auto minibatchSize = 400;
auto featureStreamInfo = minibatchSource->StreamInfo(L"features");
auto labelStreamInfo = minibatchSource->StreamInfo(L"labels");
auto input = InputVariable({ inputDim }, DataType::Float, L"features");
auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels");
auto classifierOutput = FullyConnectedLinearLayer(input, numOutputClasses, device);
auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
auto learner2 = SGDLearner(classifierOutput->Parameters(), schedule);
auto trainer = CreateTrainer(classifierOutput, trainingLoss, prediction, { learner2 });
for (auto i = 0; i <= 4000; i+= minibatchSize)
{
auto sweepIndex1 = i / sweepSize;
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
if (minibatchData[featureStreamInfo].sweepEnd != minibatchData[labelStreamInfo].sweepEnd) {
ReportFailure("TestSweepBasedSchedule failed: "
"different streams have different end of sweep flag values.");
}
auto sweepIndex2 = (i + minibatchSize) / sweepSize;
if ((sweepIndex1 != sweepIndex2) != minibatchData[labelStreamInfo].sweepEnd) {
ReportFailure("TestSweepBasedSchedule failed: "
"end of sweep flag value is different from expected.");
}
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto expectedLR = std::min((sweepIndex2 + 1), 5);
if (expectedLR != learner2->LearningRate()) {
ReportFailure("TestSweepBasedSchedule failed: "
"learning rate value is different from expected.");
}
}
}
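The loop above verifies that a FullDataSweep schedule advances only when a sweep completes: with a 603-sample sweep and 400-sample minibatches, the learning rate stays at 1 for the first minibatch and moves to 2 once the second minibatch crosses the first sweep boundary. A standalone sketch of just that arithmetic:

// Standalone sketch: how a per-sweep schedule { 1, 2, 3, 4, 5 } advances when
// 400-sample minibatches are drawn from a 603-sample sweep.
#include <algorithm>
#include <cassert>
#include <cstddef>

int main()
{
    const size_t sweepSize = 603;
    const size_t minibatchSize = 400;
    const size_t scheduleLength = 5;   // values 1..5, one per completed sweep

    for (size_t i = 0; i <= 4000; i += minibatchSize)
    {
        size_t sweepsCompleted = (i + minibatchSize) / sweepSize;
        size_t expectedLR = std::min(sweepsCompleted + 1, scheduleLength);

        if (i == 0)
            assert(expectedLR == 1);   // 400 samples seen: still inside the first sweep
        if (i == 400)
            assert(expectedLR == 2);   // 800 samples seen: first sweep (603) completed
        if (i == 2800)
            assert(expectedLR == 5);   // schedule exhausted: the last value is reused
    }
    return 0;
}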
void LearnerTests()
{
fprintf(stderr, "\nLearnerTests..\n");
TestTrainingParametersSchedule();
TestSweepBasedSchedule();
vector<DeviceDescriptor> devices{DeviceDescriptor::CPUDevice()};

View File

@ -159,11 +159,11 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
auto minibatchData = minibatchSource->GetNextMinibatch(0, minibatchSize, 1, 0);
auto minibatchData2 = minibatchSource2->GetNextMinibatch(0, minibatchSize, 1, 0);
if (minibatchData[featureStreamInfo].m_numSamples != minibatchData2[featureStreamInfo].m_numSamples)
if (minibatchData[featureStreamInfo].numberOfSamples != minibatchData2[featureStreamInfo].numberOfSamples)
ReportFailure("Data does not match, reads are not deterministic!!!");
// Because they are supposed to read the same data - adding it only once.
totalSamples += minibatchData[featureStreamInfo].m_numSamples;
totalSamples += minibatchData[featureStreamInfo].numberOfSamples;
}
else
{
@ -179,27 +179,27 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
// Update the counter
size_t accumulative = 0;
if (!minibatchData.empty())
accumulative += minibatchData[featureStreamInfo].m_numSamples;
accumulative += minibatchData[featureStreamInfo].numberOfSamples;
if (!minibatchData2.empty())
accumulative += minibatchData2[featureStreamInfo].m_numSamples;
accumulative += minibatchData2[featureStreamInfo].numberOfSamples;
totalSamples += accumulative;
if (expectNoData) // second worker does not have any data.
{
if (minibatchData[featureStreamInfo].m_numSamples != minibatchSize/2 && totalSamples != numberOfSamplesInSweep)
if (minibatchData[featureStreamInfo].numberOfSamples != minibatchSize/2 && totalSamples != numberOfSamplesInSweep)
ReportFailure("TestMinibatchSourceWarmStart failed because data did not match."
"Expected minibatch size '%d', acutal '%d'. Total number of sample '%d', sweep '%d'.",
(int)minibatchSize,
(int)minibatchData[featureStreamInfo].m_numSamples,
(int)minibatchData[featureStreamInfo].numberOfSamples,
(int)totalSamples,
(int)numberOfSamplesInSweep);
}
else
{
if (accumulative != minibatchSize &&
minibatchData[featureStreamInfo].m_numSamples != minibatchSize / 2 &&
minibatchData2[featureStreamInfo].m_numSamples != minibatchSize / 2 &&
minibatchData[featureStreamInfo].numberOfSamples != minibatchSize / 2 &&
minibatchData2[featureStreamInfo].numberOfSamples != minibatchSize / 2 &&
totalSamples != numberOfSamplesInSweep)
ReportFailure("TestMinibatchSourceWarmStart failed because data did not match."
"Expected minibatch size '%d', acutal '%d'. Total number of sample '%d', sweep '%d'.",
@ -217,8 +217,81 @@ void TestMinibatchSourceWarmStart(size_t minibatchSize, size_t warmStartSamples,
(int)totalSamples);
}
void TestEndOfSweepFlag(size_t maxSamples, size_t mbSize, bool randomize)
{
const size_t sweepSize = 603;
auto ctfInput = L"SimpleDataTest_cntk_text.txt";
std::vector<StreamConfiguration> streamConfig { { L"features", 2 } };
auto cpuDevice = DeviceDescriptor::CPUDevice();
auto src = TextFormatMinibatchSource(ctfInput, streamConfig, maxSamples, randomize);
maxSamples = (maxSamples == MinibatchSource::FullDataSweep) ? sweepSize : maxSamples;
bool reachedEndOfEpoch = false;
size_t sampleCount = 0;
while (sampleCount < maxSamples)
{
auto& dataMap = src->GetNextMinibatch(mbSize, cpuDevice);
if (dataMap.size() != streamConfig.size())
{
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: "
"unexpected number of streams in the minibatch (%zu).", dataMap.size());
}
for (auto& streamData : dataMap)
{
auto numSamplesInMinibatch = streamData.second.numberOfSamples;
bool expectedEndOfSweep = ((sampleCount + numSamplesInMinibatch) % sweepSize) == 0;
expectedEndOfSweep |= ((sampleCount) / sweepSize) < ((sampleCount + numSamplesInMinibatch) / sweepSize);
reachedEndOfEpoch = (sampleCount + mbSize >= maxSamples);
size_t expectedNumSamples = reachedEndOfEpoch ? (maxSamples - sampleCount) : mbSize;
if (streamData.second.sweepEnd != expectedEndOfSweep)
{
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: end of sweep flag is not set.");
}
if (streamData.second.numberOfSamples != expectedNumSamples)
{
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: "
"unexpected number of samples in the minibatch (%zu).", streamData.second.numberOfSamples);
}
if (streamData.second.numberOfSequences != expectedNumSamples)
{
ReportFailure("TestThatEndOfSweepFlagIsSetCorrectly failed: "
"unexpected number of sequences in the minibatch (%zu).", streamData.second.numberOfSequences);
}
}
sampleCount += mbSize;
}
auto& emptyDataMap = src->GetNextMinibatch(mbSize, cpuDevice);
assert(emptyDataMap.empty());
}
void TestThatEndOfSweepFlagIsSetCorrectly()
{
for (auto randomize : { false, true })
{
TestEndOfSweepFlag(MinibatchSource::FullDataSweep, 603, randomize);
TestEndOfSweepFlag(MinibatchSource::FullDataSweep, 1000, randomize);
TestEndOfSweepFlag(MinibatchSource::FullDataSweep, 100, randomize);
TestEndOfSweepFlag(100, 30, randomize);
TestEndOfSweepFlag(2000, 500, randomize);
TestEndOfSweepFlag(2412, 301, randomize);
}
}
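The expected end-of-sweep flag above fires when a minibatch either finishes exactly on a sweep boundary or straddles one. A standalone sketch of the predicate, with the 603-sample sweep the test assumes:

// Standalone sketch of the end-of-sweep predicate checked by the test above.
#include <cassert>
#include <cstddef>

bool EndsSweep(size_t sampleCount, size_t numSamplesInMinibatch, size_t sweepSize)
{
    bool endOfSweep = ((sampleCount + numSamplesInMinibatch) % sweepSize) == 0;
    endOfSweep |= (sampleCount / sweepSize) < ((sampleCount + numSamplesInMinibatch) / sweepSize);
    return endOfSweep;
}

int main()
{
    const size_t sweepSize = 603;
    assert(!EndsSweep(0, 100, sweepSize));   // samples [0, 100): well inside the sweep
    assert(EndsSweep(503, 100, sweepSize));  // samples [503, 603): ends exactly on the boundary
    assert(EndsSweep(600, 100, sweepSize));  // samples [600, 700): straddles the boundary at 603
    return 0;
}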
void MinibatchSourceTests()
{
TestThatEndOfSweepFlagIsSetCorrectly();
// Test no-randomize minibatch source with small data chunks
TestMinibatchSourceWarmStart(64, 128, false, 1024);
TestMinibatchSourceWarmStart(64, 0, false, 1024);

View File

@ -209,7 +209,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
if (minibatchData.empty())
break;
trainer->TrainMinibatch({ { rawInput, minibatchData[rawInputStreamInfo].m_data }, { rawLabels, minibatchData[rawLabelsStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { rawInput, minibatchData[rawInputStreamInfo] }, { rawLabels, minibatchData[rawLabelsStreamInfo] } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
if ((i + 1) == numMinibatchesToCheckpointAfter)
@ -222,7 +222,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
if ((i % decodingFrequency) == 0)
{
std::unordered_map<Variable, ValuePtr> outputs = { { decodingFunction, nullptr }};
decodingFunction->Forward({ { decodingFunction->Arguments()[0], minibatchData[rawInputStreamInfo].m_data }, { decodingFunction->Arguments()[1], minibatchData[rawLabelsStreamInfo].m_data } },
decodingFunction->Forward({ { decodingFunction->Arguments()[0], minibatchData[rawInputStreamInfo].data }, { decodingFunction->Arguments()[1], minibatchData[rawLabelsStreamInfo].data } },
outputs,
device);
}

View File

@ -59,7 +59,7 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool useSparseLa
if (minibatchData.empty())
break;
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}
@ -87,37 +87,37 @@ void TestLearningRateControl(const DeviceDescriptor& device)
const size_t minibatchSize = 200;
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
auto actualMBSize = minibatchData[labelStreamInfo].numberOfSamples;
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
auto trainer = CreateTrainer(classifierOutput, trainingLoss, prediction, { learner });
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
const wchar_t* modelFile = L"seq2seq.model";
trainer->SaveCheckpoint(modelFile);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto MB2Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto MB3Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
trainer->RestoreFromCheckpoint(modelFile);
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto postRestoreMB3Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
@ -128,13 +128,13 @@ void TestLearningRateControl(const DeviceDescriptor& device)
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer->SaveCheckpoint(modelFile);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
postRestoreMB3Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
@ -143,13 +143,13 @@ void TestLearningRateControl(const DeviceDescriptor& device)
trainer->RestoreFromCheckpoint(modelFile);
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
postRestoreMB3Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");

View File

@ -434,7 +434,7 @@ void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const
Dictionary model = classifierOutput1->Serialize();
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto classifierOutput2 = Function::Deserialize(model, device);
@ -458,8 +458,8 @@ void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const
for (int j = 0; j < 3; ++j)
{
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer2->TrainMinibatch({ { classifierOutput3->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer1->TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
trainer2->TrainMinibatch({ { classifierOutput3->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
double mbLoss1 = trainer1->PreviousMinibatchLossAverage();
double mbLoss2 = trainer2->PreviousMinibatchLossAverage();
@ -503,7 +503,7 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
const size_t minibatchSize = 50;
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
auto actualMBSize = minibatchData[labelStreamInfo].numberOfSamples;
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize);
@ -522,14 +522,14 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original.");
}
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
if (AreEqual(function1, function2))
{
throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained.");
}
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
if (!AreEqual(function1, function2))
{
@ -548,8 +548,8 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
for (int j = 0; j < 3; ++j)
{
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer1->TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
trainer2->TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
double mbLoss1 = trainer1->PreviousMinibatchLossAverage();
double mbLoss2 = trainer2->PreviousMinibatchLossAverage();
@ -638,36 +638,36 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
const size_t minibatchSize = 50;
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
auto actualMBSize = minibatchData[labelStreamInfo].numberOfSamples;
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
auto trainer = CreateTrainer(classifierOutput, trainingLoss, prediction, { learner });
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
const wchar_t* modelFile = L"seq2seq.legacy.model";
Internal::SaveAsLegacyModel(classifierOutput, modelFile);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto MB2Loss = trainer->PreviousMinibatchLossAverage();
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
classifierOutput->RestoreModel(modelFile);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
auto postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
classifierOutput->RestoreModel(modelFile);
Internal::SaveAsLegacyModel(classifierOutput, modelFile);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
classifierOutput->RestoreModel(modelFile);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
postRestoreMB2Loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
@ -684,7 +684,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
{
trainer->SaveCheckpoint(L"trainer.checkpoint" + std::to_wstring(i));
Internal::SaveAsLegacyModel(classifierOutput, modelFile + std::to_wstring(i));
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
expectedLoss.push_back(trainer->PreviousMinibatchLossAverage());
}
@ -692,7 +692,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
{
trainer->RestoreFromCheckpoint(L"trainer.checkpoint" + std::to_wstring(i));
classifierOutput->RestoreModel(modelFile + std::to_wstring(i));
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
double loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
}
@ -767,20 +767,20 @@ void TestCheckpointingWithStatefulNodes(const DeviceDescriptor& device)
auto featureStreamInfo = minibatchSource->StreamInfo(features);
auto labelStreamInfo = minibatchSource->StreamInfo(labels);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
vector<double> expectedLoss;
for (int i = 0; i < epochSize / minibatchSize; i++)
{
trainer->SaveCheckpoint(L"stateful_nodes.model" + std::to_wstring(i));
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
expectedLoss.push_back(trainer->PreviousMinibatchLossAverage());
}
for (int i = 0; i < epochSize / minibatchSize; i++)
{
trainer->RestoreFromCheckpoint(L"stateful_nodes.model" + std::to_wstring(i));
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
double loss = trainer->PreviousMinibatchLossAverage();
FloatingPointCompare(loss, expectedLoss[i], "Post checkpoint restoration training loss does not match expectation");
}

View File

@ -67,7 +67,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
if ((i % trainingCheckpointFrequency) == (trainingCheckpointFrequency - 1))
@ -128,7 +128,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { input, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}

View File

@ -125,11 +125,11 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
break;
// Make sure our truncation length setting was honored
auto actualMaxSequenceLength = minibatchData[featureStreamInfo].m_data->Shape()[featureStreamInfo.m_sampleLayout.Rank()];
auto actualMaxSequenceLength = minibatchData[featureStreamInfo].data->Shape()[featureStreamInfo.m_sampleLayout.Rank()];
if (actualMaxSequenceLength != truncationLength)
ReportFailure("Actual max sequence length (%d) in minibatch data does not equal specified truncation length (%d)", (int)actualMaxSequenceLength, (int)truncationLength);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
trainer->TrainMinibatch({ { features, minibatchData[featureStreamInfo] }, { labels, minibatchData[labelStreamInfo] } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}

View File

@ -24,6 +24,15 @@
%rename(momentum_as_time_constant_schedule) CNTK::MomentumAsTimeConstantSchedule;
// renaming overloads for TrainMinibatch and TestMinibatch that take a map
// of Variables and MinibatchData as their first parameter. If this is not done,
// the overloads that are legal in C++ will be shadowed and ignored by SWIG.
// The naming here is somewhat cumbersome, but it's only intended for internal
// consumption in proxy objects.
%rename(train_minibatch_overload_for_minibatchdata) CNTK::Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>&, const DeviceDescriptor& = DeviceDescriptor::UseDefaultDevice());
%rename(train_minibatch_overload_for_minibatchdata) CNTK::Trainer::TrainMinibatch(const std::unordered_map<Variable, MinibatchData>&, std::unordered_map<Variable, ValuePtr>&, const DeviceDescriptor& = DeviceDescriptor::UseDefaultDevice());
%rename(test_minibatch_overload_for_minibatchdata) CNTK::Trainer::TestMinibatch(const std::unordered_map<Variable, MinibatchData>&, const DeviceDescriptor& = DeviceDescriptor::UseDefaultDevice());
%rename(l1_regularization_weight) CNTK::AdditionalLearningOptions::l1RegularizationWeight;
%rename(l2_regularization_weight) CNTK::AdditionalLearningOptions::l2RegularizationWeight;
%rename(ndcg_at_1) CNTK::NDCGAt1;
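As an illustrative sketch of what these renames buy on the Python side (the variable names below are hypothetical placeholders; the actual dispatch lives in the trainer.py change later in this commit):

    # Without the renames, SWIG would expose only one train_minibatch overload and the
    # MinibatchData-based one would be shadowed. With them, both are reachable on the proxy:
    updated = trainer.train_minibatch(value_args, device)                          # Variable -> Value map
    updated = trainer.train_minibatch_overload_for_minibatchdata(mb_args, device)  # Variable -> MinibatchData map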
@ -1087,6 +1096,7 @@ public:
%unordered_map_conversion(CNTK::Parameter, const CNTK::NDArrayViewPtr, SWIGTYPE_p_CNTK__Parameter, SWIGTYPE_p_std__shared_ptrT_CNTK__NDArrayView_t)
%unordered_map_conversion(CNTK::Parameter, CNTK::NDArrayViewPtr, SWIGTYPE_p_CNTK__Parameter, SWIGTYPE_p_std__shared_ptrT_CNTK__NDArrayView_t)
%unordered_map_conversion(CNTK::Variable, CNTK::StreamInformation, SWIGTYPE_p_CNTK__Variable, SWIGTYPE_p_CNTK__StreamInformation)
%unordered_map_conversion(CNTK::Variable, CNTK::MinibatchData, SWIGTYPE_p_CNTK__Variable, SWIGTYPE_p_CNTK__MinibatchData)
%unordered_map_ref_conversion(CNTK::StreamInformation, SWIGTYPE_p_CNTK__StreamInformation, CNTK::MinibatchData, SWIGTYPE_p_CNTK__MinibatchData);
%unordered_map_ref_conversion(CNTK::Parameter, SWIGTYPE_p_CNTK__Parameter, CNTK::NDArrayViewPtr, SWIGTYPE_p_std__shared_ptrT_CNTK__NDArrayView);

View File

@ -27,28 +27,28 @@ class MinibatchData(cntk_py.MinibatchData, ArrayMixin):
'''
The number of sequences in this minibatch
'''
return self.m_num_sequences
return self.number_of_sequences
@property
def num_samples(self):
'''
The number of samples in this minibatch
'''
return self.m_num_samples
return self.number_of_samples
@property
def value(self):
'''
The value of the minibatch as a NumPy array.
'''
return value_to_seq(self.m_data)
return value_to_seq(self.data)
@property
def shape(self):
'''
The shape of the data in this minibatch as tuple.
'''
return self.m_data.shape().dimensions()
return self.data.shape().dimensions()
@property
def mask(self):
@ -57,14 +57,23 @@ class MinibatchData(cntk_py.MinibatchData, ArrayMixin):
sequence, `1` marks a sequence element as valid, and `0` marks it as
invalid.
'''
return self.m_data.mask().to_ndarray()
return self.data.mask().to_ndarray()
@property
def end_of_sweep(self):
'''
Indicates whether the data in this minibatch comes from the end of a sweep
or crosses a sweep boundary (and as a result includes data from
different sweeps).
'''
return self.sweep_end
@property
def is_sparse(self):
'''
Whether the data in this minibatch is sparse.
'''
return self.m_data.is_sparse()
return self.data.is_sparse()
def __len__(self):
return self.num_sequences
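
A minimal usage sketch for the new ``end_of_sweep`` property (the CTF file name and stream layout are made up; the reader construction mirrors the tests later in this commit):

    from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs

    # Hypothetical single-stream CTF file; randomization is off so sweeps are read in order.
    mb_source = MinibatchSource(CTFDeserializer('data.ctf', StreamDefs(
        features=StreamDef(field='S0', shape=1))), randomize=False)
    features_si = mb_source.stream_info('features')

    mb = mb_source.next_minibatch(32)
    features = mb[features_si]
    if features.end_of_sweep:
        # This minibatch reaches (or crosses) a sweep boundary.
        print('sweep boundary after %d samples' % features.num_samples)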

View File

@ -9,15 +9,13 @@ import os
import numpy as np
import pytest
from cntk.io import _is_tensor, sequence_to_cntk_text_format
from cntk.io import *
abs_path = os.path.dirname(os.path.abspath(__file__))
AA = np.asarray
def test_text_format(tmpdir):
from cntk.io import CTFDeserializer, MinibatchSource, StreamDef, StreamDefs
mbdata = r'''0 |x 560:1 |y 1 0 0 0 0
0 |x 0:1
0 |x 0:1
@ -48,6 +46,9 @@ def test_text_format(tmpdir):
features = mb[features_si]
# 2 sequences, max seq len 4, 1000 dim
assert features.shape == (2, 4, input_dim)
assert features.end_of_sweep
assert features.num_sequences == 2
assert features.num_samples == 7
assert features.is_sparse
# TODO features is sparse and cannot be accessed right now:
# *** RuntimeError: DataBuffer/WritableDataBuffer methods can only be called for NDArrayiew objects with dense storage format
@ -58,6 +59,9 @@ def test_text_format(tmpdir):
labels = mb[labels_si]
# 2 samples, max seq len 1, 5 dim
assert labels.shape == (2, 1, num_output_classes)
assert labels.end_of_sweep
assert labels.num_sequences == 2
assert labels.num_samples == 2
assert not labels.is_sparse
label_data = np.asarray(labels)
@ -67,8 +71,16 @@ def test_text_format(tmpdir):
[[ 0., 1., 0., 0., 0.]]
]))
mb = mb_source.next_minibatch(1)
features = mb[features_si]
labels = mb[labels_si]
assert not features.end_of_sweep
assert not labels.end_of_sweep
assert features.num_samples < 7
assert labels.num_samples == 1
def test_image():
from cntk.io import ReaderConfig, ImageDeserializer
map_file = "input.txt"
mean_file = "mean.txt"
epoch_size = 150
@ -153,7 +165,7 @@ def test_image():
assert set(sis.keys()) == { feature_name, label_name }
'''
def test_minibatch(tmpdir):
def test_full_sweep_minibatch(tmpdir):
mbdata = r'''0 |S0 0 |S1 0
0 |S0 1 |S1 1
@ -168,10 +180,10 @@ def test_minibatch(tmpdir):
with open(tmpfile, 'w') as f:
f.write(mbdata)
from cntk.io import CTFDeserializer, MinibatchSource, StreamDef, StreamDefs
mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
features = StreamDef(field='S0', shape=1),
labels = StreamDef(field='S1', shape=1))))
labels = StreamDef(field='S1', shape=1))),
randomize=False, epoch_size=FULL_DATA_SWEEP)
features_si = mb_source.stream_info('features')
labels_si = mb_source.stream_info('labels')
@ -181,6 +193,7 @@ def test_minibatch(tmpdir):
assert mb[labels_si].num_sequences == 2
features = mb[features_si]
assert features.end_of_sweep
assert len(features.value) == 2
expected_features = \
[
@ -196,6 +209,7 @@ def test_minibatch(tmpdir):
[2, 1, 1, 0]])
labels = mb[labels_si]
assert labels.end_of_sweep
assert len(labels.value) == 2
expected_labels = \
[
@ -209,6 +223,46 @@ def test_minibatch(tmpdir):
[[2, 1, 1],
[2, 1, 0]])
def test_large_minibatch(tmpdir):
mbdata = r'''0 |S0 0 |S1 0
0 |S0 1 |S1 1
0 |S0 2
0 |S0 3 |S1 3
0 |S0 4
0 |S0 5 |S1 1
0 |S0 6 |S1 2
'''
tmpfile = str(tmpdir/'mbtest.txt')
with open(tmpfile, 'w') as f:
f.write(mbdata)
mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
features = StreamDef(field='S0', shape=1),
labels = StreamDef(field='S1', shape=1))),
randomize=False)
features_si = mb_source.stream_info('features')
labels_si = mb_source.stream_info('labels')
mb = mb_source.next_minibatch(1000)
features = mb[features_si]
labels = mb[labels_si]
# The minibatch spans multiple sweeps; it is unclear whether this is an
# artificial situation, but perhaps instead of a boolean flag we should
# report the largest sweep index the data was taken from.
assert features.end_of_sweep
assert labels.end_of_sweep
assert features.num_samples == 1000 - 1000 % 7
assert labels.num_samples == 5 * (1000 // 7)
assert mb[features_si].num_sequences == (1000 // 7)
assert mb[labels_si].num_sequences == (1000 // 7)
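
The expected counts follow from the sweep layout: every sweep of this file yields 7 feature samples but only 5 label samples (two lines carry no |S1 entry), and the 1000-sample request is served in whole sequences, which here coincide with whole sweeps. A quick check of the arithmetic:

    sweeps = 1000 // 7                 # 142 complete sweeps fit into the request
    feature_samples = 1000 - 1000 % 7  # 994 = 142 * 7 feature samples
    label_samples = 5 * sweeps         # 710 label samples (5 per sweep)
    assert feature_samples == sweeps * 7
    assert label_samples == 710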
@pytest.mark.parametrize("idx, alias_tensor_map, expected", [
(0, {'A': [object()]}, ValueError),
@ -250,4 +304,5 @@ def test_sequence_conversion_dense(idx, alias_tensor_map, expected):
([AA([1, 2]), AA([])], False),
])
def test_is_tensor(data, expected):
from cntk.io import _is_tensor
assert _is_tensor(data) == expected

View File

@ -130,7 +130,7 @@ class Learner(cntk_py.Learner):
return super(Learner, self).learning_rate()
@typemap
def training_parameter_schedule(schedule, unit, epoch_size=1):
def training_parameter_schedule(schedule, unit, epoch_size=None):
'''
Create a training parameter schedule containing either per-sample (default)
or per-minibatch values.
@ -160,8 +160,13 @@ def training_parameter_schedule(schedule, unit, epoch_size=1):
unit (:class:`UnitType`): one of two
* ``sample``: the returned schedule contains per-sample values
* ``minibatch``: the returned schedule contains per-minibatch values.
epoch_size (int): number of samples as a scheduling unit. Parameters in
the schedule change their values every ``epoch_size`` samples.
epoch_size (optional, int): number of samples as a scheduling unit.
Parameters in the schedule change their values every ``epoch_size``
samples. If no ``epoch_size`` is provided, it defaults to the size of
the full data sweep, in which case the scheduling unit is the entire
data sweep (as indicated by the MinibatchSource) and parameters change
their values on a sweep-by-sweep basis as specified by the ``schedule``.
Returns:
training parameter schedule
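
For example, a sketch assuming the module-level functions defined in this file are imported from ``cntk.learner`` (the values are illustrative):

    from cntk.learner import learning_rate_schedule, UnitType

    # Explicit epoch_size: the value changes every 100 samples.
    lr_by_samples = learning_rate_schedule([0.05, 0.01, 0.001], UnitType.sample, 100)

    # No epoch_size: the scheduling unit is the full data sweep, so 0.05 is used
    # during the first sweep, 0.01 during the second, and 0.001 thereafter.
    lr_by_sweeps = learning_rate_schedule([0.05, 0.01, 0.001], UnitType.sample)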
@ -177,7 +182,7 @@ def training_parameter_schedule(schedule, unit, epoch_size=1):
return schedule
if isinstance(schedule, (int, float)):
if epoch_size != 1:
if epoch_size is not None:
raise ValueError('when providing the schedule as a number,'
' epoch_size is ignored')
if UnitType(unit) is UnitType.sample:
@ -185,16 +190,18 @@ def training_parameter_schedule(schedule, unit, epoch_size=1):
else:
return cntk_py.training_parameter_per_minibatch_schedule(schedule)
args = [schedule] if epoch_size is None else [schedule, epoch_size]
if isinstance(schedule, list):
if UnitType(unit) is UnitType.sample:
return cntk_py.training_parameter_per_sample_schedule(schedule, epoch_size)
return cntk_py.training_parameter_per_sample_schedule(*args)
else:
return cntk_py.training_parameter_per_minibatch_schedule(schedule, epoch_size)
return cntk_py.training_parameter_per_minibatch_schedule(*args)
raise ValueError('schedule must be either a float or a list, not %s'%type(schedule))
@typemap
def learning_rate_schedule(lr, unit, epoch_size=1):
def learning_rate_schedule(lr, unit, epoch_size=None):
'''
Create a learning rate schedule (using the same semantics as
:func:`training_parameter_schedule`).
@ -216,7 +223,7 @@ def learning_rate_schedule(lr, unit, epoch_size=1):
return training_parameter_schedule(lr, unit, epoch_size)
@typemap
def momentum_schedule(momentum, epoch_size=1):
def momentum_schedule(momentum, epoch_size=None):
'''
Create a per-minibatch momentum schedule (using the same semantics as
:func:`training_parameter_schedule` with the `unit=UnitType.minibatch`).
@ -253,7 +260,7 @@ def momentum_schedule(momentum, epoch_size=1):
return training_parameter_schedule(momentum, UnitType.minibatch, epoch_size)
@typemap
def momentum_as_time_constant_schedule(momentum, epoch_size=1):
def momentum_as_time_constant_schedule(momentum, epoch_size=None):
'''
Create a momentum schedule in a minibatch-size agnostic way
(using the same semantics as :func:`training_parameter_schedule`
@ -288,9 +295,14 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
return momentum
if isinstance(momentum, (int, float)):
if epoch_size is not None:
raise ValueError('when providing the schedule as a number,'
' epoch_size is ignored')
return cntk_py.momentum_as_time_constant_schedule(momentum)
if isinstance(momentum, list):
return cntk_py.momentum_as_time_constant_schedule(momentum, epoch_size)
args = [momentum] if epoch_size is None else [momentum, epoch_size]
return cntk_py.momentum_as_time_constant_schedule(*args)
raise ValueError('momentum must be either a float or a list, not %s'%type(momentum))
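
Analogously for the momentum-as-time-constant variant (a sketch; the time-constant values are only examples):

    from cntk.learner import momentum_as_time_constant_schedule

    # A single number: epoch_size must be omitted (providing one raises ValueError).
    m_const = momentum_as_time_constant_schedule(1100)

    # A list without epoch_size: the time constant changes once per data sweep.
    m_by_sweeps = momentum_as_time_constant_schedule([900, 1100, 1400])

    # A list with epoch_size: the time constant changes every 1000 samples.
    m_by_samples = momentum_as_time_constant_schedule([900, 1100, 1400], 1000)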

View File

@ -166,15 +166,16 @@ class ArrayMixin(object):
@property
def __array_interface__(self):
try:
# This first check is for a Value object. Trying with self.to_ndarray first would lead to
# an infinite recursion, since Value has a to_ndarray method
np_array = self.data().to_ndarray()
# This checks for a MinibatchData object.
np_array = self.value
except AttributeError:
try:
np_array = self.to_ndarray()
# This checks for a Value object. Trying with self.to_ndarray first would lead to
# an infinite recursion, since Value has a to_ndarray method
np_array = self.data().to_ndarray()
except AttributeError:
try:
np_array = self.value
np_array = self.to_ndarray()
except AttributeError:
# Ideally an exception would be raised here, but getattr would swallow it
# so we return None

View File

@ -96,7 +96,7 @@ def test_learner_update():
w = parameter(shape=(1,), init=w_init)
res = i * w
learner = sgd(res.parameters, lr=learning_rate_schedule([0.1]*50 + [0.2]*50, UnitType.sample))
learner = sgd(res.parameters, lr=learning_rate_schedule([0.1]*50 + [0.2]*50, UnitType.sample, 1))
assert learner.learning_rate() == 0.1
x = learner.update({w: np.asarray([[2.]], dtype=np.float32)}, 100)
assert learner.learning_rate() == 0.2
@ -110,3 +110,66 @@ def test_training_parameter_schedule():
training_parameter_schedule(0.01, unit='not_supported')
with pytest.raises(ValueError):
training_parameter_schedule(0.01, unit=5)
def test_sweep_based_schedule(tmpdir, device_id):
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
from .. import cross_entropy_with_softmax, classification_error, plus, reduce_sum
from ..trainer import Trainer
input_dim = 69
ctf_data = '''\
0 |S0 3:1 |S1 3:1 |# <s>
0 |S0 4:1 |# A |S1 32:1 |# ~AH
0 |S0 5:1 |# B |S1 36:1 |# ~B
0 |S0 4:1 |# A |S1 31:1 |# ~AE
0 |S0 7:1 |# D |S1 38:1 |# ~D
0 |S0 12:1 |# I |S1 47:1 |# ~IY
0 |S0 1:1 |# </s> |S1 1:1 |# </s>
2 |S0 60:1 |# <s> |S1 3:1 |# <s>
2 |S0 61:1 |# A |S1 32:1 |# ~AH
'''
ctf_file = str(tmpdir/'2seqtest.txt')
with open(ctf_file, 'w') as f:
f.write(ctf_data)
mbs = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
features = StreamDef(field='S0', shape=input_dim, is_sparse=True),
labels = StreamDef(field='S1', shape=input_dim, is_sparse=True)
)), randomize=False)
in1 = input_variable(shape=(input_dim,))
labels = input_variable(shape=(input_dim,))
p = parameter(shape=(input_dim,), init=10)
z = plus(in1, reduce_sum(p), name='z')
ce = cross_entropy_with_softmax(z, labels)
errs = classification_error(z, labels)
lr_per_sample = learning_rate_schedule([0.3, 0.2, 0.1, 0.0], UnitType.sample)
learner = sgd(z.parameters, lr_per_sample)
trainer = Trainer(z, ce, errs, [learner])
input_map = {
in1 : mbs.streams.features,
labels : mbs.streams.labels
}
# fetch minibatch (first sequence)
data = mbs.next_minibatch(1, input_map=input_map)
trainer.train_minibatch(data)
assert learner.learning_rate() == 0.3
# fetch minibatch (second sequence, sweep ends at this point)
data = mbs.next_minibatch(1, input_map=input_map)
trainer.train_minibatch(data)
assert learner.learning_rate() == 0.2
# fetch minibatch (both sequences -- entire sweep in one go)
data = mbs.next_minibatch(9, input_map=input_map)
trainer.train_minibatch(data)
assert learner.learning_rate() == 0.1
# fetch minibatch (multiple sweeps)
data = mbs.next_minibatch(30, input_map=input_map)
trainer.train_minibatch(data, [z.output])
assert learner.learning_rate() == 0.0

View File

@ -7,7 +7,7 @@
from . import cntk_py
from .device import use_default_device
from .utils import sanitize_var_map, sanitize_function, typemap, value_to_seq
from .io import _py_dict_to_cntk_dict
from .io import _py_dict_to_cntk_dict, MinibatchData
__doc__= '''\
A trainer encapsulates the overall training process and employs one or more
@ -78,18 +78,36 @@ class Trainer(cntk_py.Trainer):
device = use_default_device()
if arguments:
arguments = sanitize_var_map(self.model.arguments, arguments)
arguments = sanitize_var_map(self.model.arguments, arguments,
extract_values_from_minibatch_data = False)
contains_minibatch_data = False
if (len(arguments) > 0):
value = next(iter(arguments.values()))
contains_minibatch_data = isinstance(value, MinibatchData)
if outputs:
output_map = {v: None for v in outputs}
updated = super(Trainer, self).train_minibatch(arguments,
if contains_minibatch_data:
updated = super(Trainer, self).train_minibatch_overload_for_minibatchdata(
arguments, output_map, device)
else:
updated = super(Trainer, self).train_minibatch(arguments,
output_map, device)
for k,v in output_map.items():
output_map[k] = value_to_seq(v)
return updated, output_map
else:
updated = super(Trainer, self).train_minibatch(arguments, device)
if contains_minibatch_data:
updated = super(Trainer, self).train_minibatch_overload_for_minibatchdata(
arguments, device)
else:
updated = super(Trainer, self).train_minibatch(arguments,
device)
return updated
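
With this dispatch in place, the dictionary returned by ``next_minibatch`` can be passed to ``train_minibatch`` directly, so sweep metadata reaches the learners. A sketch reusing the names from the sweep-schedule test above; ``features_value`` and ``labels_value`` are hypothetical plain ``Value``/NumPy inputs:

    # MinibatchData path: end-of-sweep information is preserved for sweep-based schedules.
    data = mbs.next_minibatch(64, input_map={in1: mbs.streams.features,
                                             labels: mbs.streams.labels})
    trainer.train_minibatch(data)

    # Plain path: behaves exactly as before this change.
    trainer.train_minibatch({in1: features_value, labels: labels_value})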

View File

@ -315,7 +315,7 @@ def sanitize_function(arg):
def sanitize_var_map(op_arguments, arguments, precision=None,
device=None):
device=None, extract_values_from_minibatch_data=True):
'''
Sanitizes a dictionary of `Variable` s to input data such that it can be
handed off to the evaluation methods
@ -361,6 +361,12 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
one of 'float' 'float32, 'double', 'float64', or None
device (:class:`~cntk.device.DeviceDescriptor`, default None): device
this value should be put on
extract_values_from_minibatch_data (`bool`, defaults to `True`): specifies
whether :class:`~cntk.io.MinibatchData` instances in the arguments map are
converted to their underlying :class:`Value` instances (default), or
whether they should remain intact, as they carry additional meta
information required by the Trainer (specifically, by the
:meth:`~cntk.Trainer.train_minibatch` method).
Returns:
`dict` that maps variables to sanitized batches
@ -436,9 +442,10 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
'sequence begin markers' % (sample_sizes, len(seq_starts)))
if isinstance(batch, MinibatchData):
batch = batch.m_data
elif not isinstance(batch, cntk_py.Value):
if isinstance(batch, MinibatchData) and extract_values_from_minibatch_data:
batch = batch.data
if not (isinstance(batch, MinibatchData) or isinstance(batch, cntk_py.Value)):
batch = sanitize_batch(var, batch, seq_starts, device)
var_map[var] = batch
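
A minimal illustration of the new flag (``z`` and ``data`` are assumed to come from a setup like the trainer test above; only the keyword argument matters here):

    from cntk.utils import sanitize_var_map

    # Keep MinibatchData objects intact so the Trainer can read sweep metadata.
    var_map = sanitize_var_map(z.arguments, data, extract_values_from_minibatch_data=False)

    # Default behaviour: MinibatchData is unwrapped to its underlying Value.
    var_map = sanitize_var_map(z.arguments, data)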

View File

@ -19,7 +19,7 @@ def test_rnn_error(device_id):
error, loss = train_sequence_classifier()
expected_error = 0.333333
expected_loss = 1.060453
expected_loss = 1.12
assert np.allclose(error, expected_error, atol=TOLERANCE_ABSOLUTE)
assert np.allclose(loss, expected_loss, atol=TOLERANCE_ABSOLUTE)