Integrate kedeng/addLr into master

Project Philly 2017-04-07 18:29:17 -07:00
Parents: 9330def020 e7c16095c0
Commit: e1f48f4c14
19 changed files with 42 additions and 34 deletions

View file

@@ -84,7 +84,8 @@ def simple_mnist(tensorboard_logdir=None):
progress_writers.append(TensorBoardProgressWriter(freq=10, log_dir=tensorboard_logdir, model=z))
# Instantiate the trainer object to drive the model training
trainer = Trainer(z, (ce, pe), adadelta(z.parameters), progress_writers)
lr = learning_rate_schedule(1, UnitType.sample)
trainer = Trainer(z, (ce, pe), adadelta(z.parameters, lr), progress_writers)
training_session(
trainer=trainer,

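For context, here is how the updated example wires the new learning-rate argument end to end. This is a minimal sketch assuming the CNTK Python names used in the example above; the import paths are an assumption and may differ across CNTK versions.

```python
# Sketch only: assumes the CNTK Python API as changed in this commit.
# Import paths are an assumption; in simple_mnist these names are already in scope.
from cntk import Trainer
from cntk.learners import adadelta, learning_rate_schedule, UnitType

# A constant per-sample rate of 1 reproduces the pre-change behaviour,
# where the AdaDelta step was applied unscaled.
lr = learning_rate_schedule(1, UnitType.sample)
learner = adadelta(z.parameters, lr)                 # z, ce, pe, progress_writers come from simple_mnist
trainer = Trainer(z, (ce, pe), learner, progress_writers)
```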
View file

@@ -4311,6 +4311,7 @@ namespace CNTK
/// Create an instance of the CNTK built-in AdaDelta learner.
///
CNTK_API LearnerPtr AdaDeltaLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
double rho = 0.95,
double epsilon = 1e-8,
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

View file

@@ -531,9 +531,10 @@ namespace CNTK
LearnerAdaDelta::LearnerAdaDelta(
const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
double rho, double epsilon,
AdditionalLearningOptions additionalOptions)
: LearnerBase(parameters, LearningRateSchedule(1, LearningRateSchedule::UnitType::Sample), additionalOptions, /*allocateSmoothGradients*/ false),
: LearnerBase(parameters, learningRateSchedule, additionalOptions, /*allocateSmoothGradients*/ false),
m_rho(rho), m_epsilon(epsilon)
{
for (const auto& parameter : parameters)
@@ -556,7 +557,9 @@ namespace CNTK
{
GET_WRITABLE_MATRICES
smoothedGradientMatrix->AdaDeltaUpdate(*gradientMatrix, *parameterMatrix, (ElementType)m_rho, (ElementType)m_epsilon);
const auto learningRate = LearningRate(trainingSampleCount);
smoothedGradientMatrix->AdaDeltaUpdate(*gradientMatrix, *parameterMatrix, (ElementType)learningRate, (ElementType)m_rho, (ElementType)m_epsilon);
}
/*static*/ const double LearnerFSAdaGrad::s_targetAdagradAvDenom = 1.0;
@@ -760,9 +763,10 @@ namespace CNTK
}
LearnerPtr AdaDeltaLearner(const vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
double rho, double epsilon,
AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
{
return MakeSharedObject<LearnerAdaDelta>(parameters, rho, epsilon, additionalOptions);
return MakeSharedObject<LearnerAdaDelta>(parameters, learningRateSchedule, rho, epsilon, additionalOptions);
}
}

View file

@@ -210,6 +210,7 @@ namespace CNTK
public:
LearnerAdaDelta(
const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
double rho, double epsilon,
AdditionalLearningOptions additionalOptions);

View file

@@ -1390,7 +1390,7 @@ ElemType CPUMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& gradients,
}
template <class ElemType>
void CPUMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
void CPUMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
{
size_t numColsNeeded = 2 * gradients.GetNumCols();
@@ -1418,7 +1418,7 @@ void CPUMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<Ele
ElemType x2 = smoothX2[i];
ElemType deltaX = -sqrt(x2 + epsilon) / sqrt(adaSqr + epsilon) * g;
smoothX2[i] = rho * smoothX2[i] + (1 - rho) * deltaX * deltaX;
val[i] += deltaX;
val[i] += learningRate * deltaX;
}
}
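In plain terms, the loop above is the standard AdaDelta rule with one addition: the computed step is now multiplied by an explicit learning rate before being applied. A minimal NumPy sketch of the same element-wise update (variable names are illustrative, not CNTK's):

```python
import numpy as np

def adadelta_step(val, grad, smooth_ada, smooth_x2, learning_rate, rho=0.95, epsilon=1e-8):
    """One AdaDelta update over flat arrays, mirroring the loop in CPUMatrix<ElemType>::AdaDelta."""
    # Running average of squared gradients.
    smooth_ada[:] = rho * smooth_ada + (1.0 - rho) * grad * grad
    # Step is the ratio of the two RMS estimates, applied against the gradient.
    delta_x = -np.sqrt(smooth_x2 + epsilon) / np.sqrt(smooth_ada + epsilon) * grad
    # Running average of squared steps.
    smooth_x2[:] = rho * smooth_x2 + (1.0 - rho) * delta_x * delta_x
    # New in this commit: scale the step by the learning rate
    # (a rate of 1 reproduces the previous, unscaled behaviour).
    val[:] += learning_rate * delta_x
```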

View file

@@ -107,7 +107,7 @@ public:
ElemType RMS_WGT_MIN,
const bool needAveMultiplier);
void AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon);
void AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);
void Reshape(const size_t numRows, const size_t numCols);

View file

@@ -1395,7 +1395,7 @@ ElemType CPUSparseMatrix<ElemType>::Adagrad(CPUMatrix<ElemType>& c, const bool n
}
template <class ElemType>
void CPUSparseMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
void CPUSparseMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
{
size_t numColsNeeded = 2 * GetNumCols();
@@ -1433,7 +1433,7 @@ void CPUSparseMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemT
ElemType x2 = smoothX2[denseIndex];
ElemType deltaX = -sqrt(x2 + epsilon) / sqrt(adaSqr + epsilon) * g;
smoothX2[denseIndex] = rho * smoothX2[denseIndex] + (1 - rho) * deltaX * deltaX;
val[denseIndex] += deltaX;
val[denseIndex] += learningRate * deltaX;
}
}
}

View file

@@ -231,7 +231,7 @@ public:
public:
void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum = true);
ElemType Adagrad(CPUMatrix<ElemType>& c, const bool needAveMultiplier);
void AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon);
void AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);
public:
CPUSparseMatrix<ElemType>& InplaceTruncateTop(const ElemType threshold);

View file

@@ -1533,7 +1533,7 @@ ElemType GPUMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& gradients,
}
template <class ElemType>
void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
{
size_t numColsNeeded = 2 * gradients.GetNumCols();
@@ -1547,7 +1547,7 @@ void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<Ele
size_t n = gradients.GetNumElements();
int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock;
_adadelta<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(n, gradients.Data(), Data(), Data() + n, functionValues.Data(), rho, epsilon);
_adadelta<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(n, gradients.Data(), Data(), Data() + n, functionValues.Data(), learningRate, rho, epsilon);
}
template <class ElemType>

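The launch configuration above is the usual ceiling division: the smallest number of thread blocks whose combined thread count covers all n gradient elements. A small sketch of that arithmetic (the block size of 512 is illustrative; the code uses GridDim::maxThreadsPerBlock):

```python
def blocks_per_grid(n, threads_per_block=512):
    # Ceiling division without floating point: smallest b with b * threads_per_block >= n.
    return (n + threads_per_block - 1) // threads_per_block

assert blocks_per_grid(1, 512) == 1
assert blocks_per_grid(512, 512) == 1
assert blocks_per_grid(513, 512) == 2
```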
View file

@@ -240,7 +240,7 @@ public:
ElemType RMS_WGT_MIN,
const bool needAveMultiplier);
void AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon);
void AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);
void Reshape(const size_t numRows, const size_t numCols);

View file

@@ -5250,7 +5250,7 @@ __global__ void _adam4BlockSparseCol(CUDA_LONG size,
template <class ElemType>
__global__ void _adadelta(CUDA_LONG size, ElemType* grad, ElemType* smoothAda, ElemType* smoothX2, ElemType* val,
ElemType rho, ElemType epsilon)
ElemType learningRate, ElemType rho, ElemType epsilon)
{
CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_LONG stride = blockDim.x * gridDim.x;
@@ -5271,7 +5271,7 @@ __global__ void _adadelta(CUDA_LONG size, ElemType* grad, ElemType* smoothAda, E
}
smoothX2[idx] = rho * smoothX2[idx] + (1.0f - rho) * deltaX * deltaX;
val[idx] += deltaX;
val[idx] += learningRate * deltaX;
}
}
@@ -5279,7 +5279,7 @@ template <class ElemType>
__global__ void _adadelta4BlockSparseCol(CUDA_LONG size,
ElemType* grad_bsc, const GPUSPARSE_INDEX_TYPE* colOrRow2blockId, const size_t len,
ElemType* smoothAda, ElemType* smoothX2, ElemType* val,
ElemType rho, ElemType epsilon)
ElemType learningRate, ElemType rho, ElemType epsilon)
{
CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_LONG stride = blockDim.x * gridDim.x;
@@ -5300,7 +5300,7 @@ __global__ void _adadelta4BlockSparseCol(CUDA_LONG size,
}
smoothX2[idx] = rho * smoothX2[idx] + (1.0f - rho) * deltaX * deltaX;
val[idx] += deltaX;
val[idx] += learningRate * deltaX;
}
}

View file

@@ -1713,7 +1713,7 @@ ElemType GPUSparseMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& c,
}
template <class ElemType>
void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType rho, ElemType epsilon)
void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
{
if (GetFormat() != MatrixFormat::matrixFormatSparseBlockCol)
{
@@ -1735,7 +1735,7 @@ void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemTy
_adadelta4BlockSparseCol<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(
n, Data(), ColOrRow2BlockId(), GetNumRows(),
c.Data(), c.Data() + n, functionValues.Data(),
rho, epsilon);
learningRate, rho, epsilon);
}
// sparse X dense = dense

View file

@@ -423,7 +423,7 @@ public:
void FSAdagrad(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
ElemType RmsProp(GPUMatrix<ElemType>& c, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
void Adam(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
void AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType rho, ElemType epsilon);
void AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);
static void Multiply(const GPUSparseMatrix<ElemType>& S, const GPUMatrix<ElemType>& D, GPUMatrix<ElemType>& C);
static void Multiply(const GPUMatrix<ElemType>& D, const GPUSparseMatrix<ElemType>& S, GPUMatrix<ElemType>& C);

View file

@@ -1789,15 +1789,15 @@ ElemType Matrix<ElemType>::RmsProp(Matrix<ElemType>& gradients,
template <class ElemType>
void Matrix<ElemType>::AdaDeltaUpdate(Matrix<ElemType>& gradients,
Matrix<ElemType>& functionValues,
ElemType rho, ElemType epsilon)
ElemType learningRate, ElemType rho, ElemType epsilon)
{
DecideAndMoveToRightDevice(*this, gradients);
DISPATCH_MATRIX_ON_FLAG(&gradients, &gradients,
{ return m_CPUMatrix->AdaDelta(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix, rho, epsilon); SetDataLocation(CPU); },
{ return m_GPUMatrix->AdaDelta(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix, rho, epsilon); SetDataLocation(GPU); },
{ return gradients.m_CPUSparseMatrix->AdaDelta(*m_CPUMatrix, *functionValues.m_CPUMatrix, rho, epsilon); SetDataLocation(CPU); },
{ return gradients.m_GPUSparseMatrix->AdaDelta(*m_GPUMatrix, *functionValues.m_GPUMatrix, rho, epsilon); SetDataLocation(GPU); });
{ return m_CPUMatrix->AdaDelta(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix, learningRate, rho, epsilon); SetDataLocation(CPU); },
{ return m_GPUMatrix->AdaDelta(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix, learningRate, rho, epsilon); SetDataLocation(GPU); },
{ return gradients.m_CPUSparseMatrix->AdaDelta(*m_CPUMatrix, *functionValues.m_CPUMatrix, learningRate, rho, epsilon); SetDataLocation(CPU); },
{ return gradients.m_GPUSparseMatrix->AdaDelta(*m_GPUMatrix, *functionValues.m_GPUMatrix, learningRate, rho, epsilon); SetDataLocation(GPU); });
}
template <class ElemType>

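For readers unfamiliar with DISPATCH_MATRIX_ON_FLAG: the macro above routes AdaDeltaUpdate to the CPU dense, GPU dense, CPU sparse, or GPU sparse implementation depending on where the gradient currently lives, and after this change all four paths receive the same explicit learning rate. A rough Python analogue of that dispatch (class and attribute names here are hypothetical, purely illustrative):

```python
def ada_delta_update(smoothed, gradients, parameters, learning_rate, rho, epsilon):
    """Illustrative only: route to the implementation that matches the gradient's
    storage (dense vs. sparse) and device (CPU vs. GPU)."""
    if gradients.is_sparse:
        backend = gradients.cpu_sparse if gradients.on_cpu else gradients.gpu_sparse
    else:
        backend = smoothed.cpu_dense if gradients.on_cpu else smoothed.gpu_dense
    # Every backend now takes the same explicit learning rate.
    backend.ada_delta(gradients, parameters, learning_rate, rho, epsilon)
```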
View file

@@ -217,7 +217,7 @@ public:
ElemType RmsProp(Matrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
void AdaDeltaUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& functionvalues, ElemType rho, ElemType epsilon);
void AdaDeltaUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& functionvalues, ElemType learningRatePerSample, ElemType rho, ElemType epsilon);
void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve = 10000, bool growOnly = true); // by default we only reallocate if need to grow
void Resize(const Matrix<ElemType>& other) // TODO: Should this carry over numNZElemToReserve for sparse matrices?

View file

@@ -276,7 +276,7 @@ ElemType GPUSparseMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>&, ElemType, Elem
}
template<class ElemType>
void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType rho, ElemType epsilon)
void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
{
}
@@ -1110,7 +1110,7 @@ ElemType GPUMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& gradients, ElemType R
}
template <class ElemType>
void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
{
}

View file

@@ -107,8 +107,8 @@ BOOST_FIXTURE_TEST_CASE(AdaDeltaSparse, MatrixLearnerFixture)
// run learner
RunOnDevices([this]()
{
matSG.AdaDeltaUpdate(matG, matM, 0.95f, 1e-8f);
matSGsparse.AdaDeltaUpdate(matGsparseBSC, matMsparse, 0.95f, 1e-8f);
matSG.AdaDeltaUpdate(matG, matM, 0.5f, 0.95f, 1e-8f);
matSGsparse.AdaDeltaUpdate(matGsparseBSC, matMsparse, 0.5f, 0.95f, 1e-8f);
BOOST_CHECK(matSG.IsEqualTo(matSGsparse, c_epsilonFloatE4));
BOOST_CHECK(matM.IsEqualTo(matMsparse, c_epsilonFloatE4));

View file

@@ -558,17 +558,18 @@ def nesterov(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
additional_options)
@typemap
def adadelta(parameters, rho=0.95, epsilon=1e-8,
def adadelta(parameters, lr=learning_rate_schedule(1, UnitType.sample), rho=0.95, epsilon=1e-8,
l1_regularization_weight=0.0, l2_regularization_weight=0.0,
gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
gradient_clipping_with_truncation=True):
'''adadelta(parameters, rho, epsilon, l1_regularization_weight=0, l2_regularization_weight=0, gaussian_noise_injection_std_dev=0, gradient_clipping_threshold_per_sample=np.inf, gradient_clipping_with_truncation=True)
'''adadelta(parameters, lr, rho, epsilon, l1_regularization_weight=0, l2_regularization_weight=0, gaussian_noise_injection_std_dev=0, gradient_clipping_threshold_per_sample=np.inf, gradient_clipping_with_truncation=True)
Creates an AdaDelta learner instance to learn the parameters. See [1] for
more information.
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
rho (float): exponential smooth factor for each minibatch.
epsilon (float): epsilon for sqrt.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
@@ -600,7 +601,7 @@ def adadelta(parameters, rho=0.95, epsilon=1e-8,
additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation
return cntk_py.ada_delta_learner(parameters, rho, epsilon,
return cntk_py.ada_delta_learner(parameters, lr, rho, epsilon,
additional_options)
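Putting the Python-side change together: adadelta now takes a learning-rate schedule as its second argument, defaulting to a constant per-sample rate of 1, which matches the old unscaled behaviour. A minimal usage sketch (import paths are an assumption and may differ between CNTK versions; `model` stands in for any function whose parameters are being trained):

```python
from cntk.learners import adadelta, learning_rate_schedule, UnitType

# 0.5 per sample for the first 100 samples, 0.2 per sample afterwards.
lr = learning_rate_schedule([0.5, 0.2], UnitType.sample, 100)
learner = adadelta(model.parameters, lr, rho=0.95, epsilon=1e-8)

# Omitting lr keeps the previous behaviour (constant per-sample rate of 1).
default_learner = adadelta(model.parameters)
```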

View file

@@ -103,7 +103,7 @@ def test_learner_init():
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)
rmsprop(res.parameters, lr_per_sample, gamma, inc, dec, max, min, True)
adadelta(res.parameters)
adadelta(res.parameters, lr_per_sample)
def test_learner_update():
i = input(shape=(1,), needs_gradient=True, name='a')