Integrate kedeng/addLr into master

Commit e1f48f4c14
@@ -84,7 +84,8 @@ def simple_mnist(tensorboard_logdir=None):
         progress_writers.append(TensorBoardProgressWriter(freq=10, log_dir=tensorboard_logdir, model=z))

     # Instantiate the trainer object to drive the model training
-    trainer = Trainer(z, (ce, pe), adadelta(z.parameters), progress_writers)
+    lr = learning_rate_schedule(1, UnitType.sample)
+    trainer = Trainer(z, (ce, pe), adadelta(z.parameters, lr), progress_writers)

     training_session(
         trainer=trainer,
@@ -4311,6 +4311,7 @@ namespace CNTK
     /// Create an instance of the CNTK built-in AdaDelta learner.
     ///
     CNTK_API LearnerPtr AdaDeltaLearner(const std::vector<Parameter>& parameters,
+                                        const LearningRateSchedule& learningRateSchedule,
                                         double rho = 0.95,
                                         double epsilon = 1e-8,
                                         AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
@@ -531,9 +531,10 @@ namespace CNTK

     LearnerAdaDelta::LearnerAdaDelta(
         const std::vector<Parameter>& parameters,
+        const LearningRateSchedule& learningRateSchedule,
         double rho, double epsilon,
         AdditionalLearningOptions additionalOptions)
-        : LearnerBase(parameters, LearningRateSchedule(1, LearningRateSchedule::UnitType::Sample), additionalOptions, /*allocateSmoothGradients*/ false),
+        : LearnerBase(parameters, learningRateSchedule, additionalOptions, /*allocateSmoothGradients*/ false),
         m_rho(rho), m_epsilon(epsilon)
     {
         for (const auto& parameter : parameters)
@@ -556,7 +557,9 @@ namespace CNTK
     {
         GET_WRITABLE_MATRICES

-        smoothedGradientMatrix->AdaDeltaUpdate(*gradientMatrix, *parameterMatrix, (ElementType)m_rho, (ElementType)m_epsilon);
+        const auto learningRate = LearningRate(trainingSampleCount);
+
+        smoothedGradientMatrix->AdaDeltaUpdate(*gradientMatrix, *parameterMatrix, (ElementType)learningRate, (ElementType)m_rho, (ElementType)m_epsilon);
     }

     /*static*/ const double LearnerFSAdaGrad::s_targetAdagradAvDenom = 1.0;
@@ -760,9 +763,10 @@ namespace CNTK
     }

     LearnerPtr AdaDeltaLearner(const vector<Parameter>& parameters,
+                               const LearningRateSchedule& learningRateSchedule,
                                double rho, double epsilon,
                                AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
     {
-        return MakeSharedObject<LearnerAdaDelta>(parameters, rho, epsilon, additionalOptions);
+        return MakeSharedObject<LearnerAdaDelta>(parameters, learningRateSchedule, rho, epsilon, additionalOptions);
     }
 }
@@ -210,6 +210,7 @@ namespace CNTK
     public:
         LearnerAdaDelta(
             const std::vector<Parameter>& parameters,
+            const LearningRateSchedule& learningRateSchedule,
             double rho, double epsilon,
             AdditionalLearningOptions additionalOptions);

@@ -1390,7 +1390,7 @@ ElemType CPUMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& gradients,
 }

 template <class ElemType>
-void CPUMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
+void CPUMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     size_t numColsNeeded = 2 * gradients.GetNumCols();

@@ -1418,7 +1418,7 @@ void CPUMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<Ele
         ElemType x2 = smoothX2[i];
         ElemType deltaX = -sqrt(x2 + epsilon) / sqrt(adaSqr + epsilon) * g;
         smoothX2[i] = rho * smoothX2[i] + (1 - rho) * deltaX * deltaX;
-        val[i] += deltaX;
+        val[i] += learningRate * deltaX;
     }
 }

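For reference, the per-element update these CPU (and the GPU) kernels now perform can be sketched in a few lines of NumPy. This is an illustrative sketch only: the names are made up, and the squared-gradient accumulation (smoothAda) is assumed to follow the standard AdaDelta rule, since that part lies outside the hunks shown above. The only behavioural change in this commit is the final learningRate scaling of the step.

    import numpy as np

    def adadelta_step(val, grad, smooth_ada, smooth_x2, learning_rate=1.0, rho=0.95, epsilon=1e-8):
        # Running average of squared gradients (standard AdaDelta; not part of the hunks above).
        smooth_ada[:] = rho * smooth_ada + (1 - rho) * grad * grad
        # Same step as the kernels: gradient scaled by the ratio of running RMS values.
        delta_x = -np.sqrt(smooth_x2 + epsilon) / np.sqrt(smooth_ada + epsilon) * grad
        smooth_x2[:] = rho * smooth_x2 + (1 - rho) * delta_x * delta_x
        # New in this commit: the step is additionally scaled by the learning rate
        # (previously: val += delta_x).
        val += learning_rate * delta_x
        return val

With learning_rate = 1.0 per sample, which is the default schedule used by the Python wrapper below, this reduces to the previous behaviour.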
@@ -107,7 +107,7 @@ public:
                   ElemType RMS_WGT_MIN,
                   const bool needAveMultiplier);

-    void AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon);
+    void AdaDelta(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);

     void Reshape(const size_t numRows, const size_t numCols);

@@ -1395,7 +1395,7 @@ ElemType CPUSparseMatrix<ElemType>::Adagrad(CPUMatrix<ElemType>& c, const bool n
 }

 template <class ElemType>
-void CPUSparseMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
+void CPUSparseMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     size_t numColsNeeded = 2 * GetNumCols();

@@ -1433,7 +1433,7 @@ void CPUSparseMatrix<ElemType>::AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemT
         ElemType x2 = smoothX2[denseIndex];
         ElemType deltaX = -sqrt(x2 + epsilon) / sqrt(adaSqr + epsilon) * g;
         smoothX2[denseIndex] = rho * smoothX2[denseIndex] + (1 - rho) * deltaX * deltaX;
-        val[denseIndex] += deltaX;
+        val[denseIndex] += learningRate * deltaX;
     }
 }
 }
@@ -231,7 +231,7 @@ public:
 public:
     void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum = true);
     ElemType Adagrad(CPUMatrix<ElemType>& c, const bool needAveMultiplier);
-    void AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon);
+    void AdaDelta(CPUMatrix<ElemType>& c, CPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);

 public:
     CPUSparseMatrix<ElemType>& InplaceTruncateTop(const ElemType threshold);
@@ -1533,7 +1533,7 @@ ElemType GPUMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& gradients,
 }

 template <class ElemType>
-void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
+void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     size_t numColsNeeded = 2 * gradients.GetNumCols();

@@ -1547,7 +1547,7 @@ void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<Ele

     size_t n = gradients.GetNumElements();
     int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock;
-    _adadelta<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(n, gradients.Data(), Data(), Data() + n, functionValues.Data(), rho, epsilon);
+    _adadelta<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(n, gradients.Data(), Data(), Data() + n, functionValues.Data(), learningRate, rho, epsilon);
 }

 template <class ElemType>
@@ -240,7 +240,7 @@ public:
                  ElemType RMS_WGT_MIN,
                  const bool needAveMultiplier);

-    void AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon);
+    void AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);

     void Reshape(const size_t numRows, const size_t numCols);

@@ -5250,7 +5250,7 @@ __global__ void _adam4BlockSparseCol(CUDA_LONG size,

 template <class ElemType>
 __global__ void _adadelta(CUDA_LONG size, ElemType* grad, ElemType* smoothAda, ElemType* smoothX2, ElemType* val,
-    ElemType rho, ElemType epsilon)
+    ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
     CUDA_LONG stride = blockDim.x * gridDim.x;
@@ -5271,7 +5271,7 @@ __global__ void _adadelta(CUDA_LONG size, ElemType* grad, ElemType* smoothAda, E
         }

         smoothX2[idx] = rho * smoothX2[idx] + (1.0f - rho) * deltaX * deltaX;
-        val[idx] += deltaX;
+        val[idx] += learningRate * deltaX;
     }
 }

@@ -5279,7 +5279,7 @@ template <class ElemType>
 __global__ void _adadelta4BlockSparseCol(CUDA_LONG size,
     ElemType* grad_bsc, const GPUSPARSE_INDEX_TYPE* colOrRow2blockId, const size_t len,
     ElemType* smoothAda, ElemType* smoothX2, ElemType* val,
-    ElemType rho, ElemType epsilon)
+    ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
     CUDA_LONG stride = blockDim.x * gridDim.x;
@@ -5300,7 +5300,7 @@ __global__ void _adadelta4BlockSparseCol(CUDA_LONG size,
         }

         smoothX2[idx] = rho * smoothX2[idx] + (1.0f - rho) * deltaX * deltaX;
-        val[idx] += deltaX;
+        val[idx] += learningRate * deltaX;
     }
 }

@@ -1713,7 +1713,7 @@ ElemType GPUSparseMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& c,
 }

 template <class ElemType>
-void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType rho, ElemType epsilon)
+void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     if (GetFormat() != MatrixFormat::matrixFormatSparseBlockCol)
     {
@@ -1735,7 +1735,7 @@ void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemTy
     _adadelta4BlockSparseCol<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(
         n, Data(), ColOrRow2BlockId(), GetNumRows(),
         c.Data(), c.Data() + n, functionValues.Data(),
-        rho, epsilon);
+        learningRate, rho, epsilon);
 }

 // sparse X dense = dense
@@ -423,7 +423,7 @@ public:
     void FSAdagrad(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
     ElemType RmsProp(GPUMatrix<ElemType>& c, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
     void Adam(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
-    void AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType rho, ElemType epsilon);
+    void AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType learningRate, ElemType rho, ElemType epsilon);

     static void Multiply(const GPUSparseMatrix<ElemType>& S, const GPUMatrix<ElemType>& D, GPUMatrix<ElemType>& C);
     static void Multiply(const GPUMatrix<ElemType>& D, const GPUSparseMatrix<ElemType>& S, GPUMatrix<ElemType>& C);
@@ -1789,15 +1789,15 @@ ElemType Matrix<ElemType>::RmsProp(Matrix<ElemType>& gradients,
 template <class ElemType>
 void Matrix<ElemType>::AdaDeltaUpdate(Matrix<ElemType>& gradients,
                                       Matrix<ElemType>& functionValues,
-                                      ElemType rho, ElemType epsilon)
+                                      ElemType learningRate, ElemType rho, ElemType epsilon)
 {
     DecideAndMoveToRightDevice(*this, gradients);

     DISPATCH_MATRIX_ON_FLAG(&gradients, &gradients,
-        { return m_CPUMatrix->AdaDelta(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix, rho, epsilon); SetDataLocation(CPU); },
-        { return m_GPUMatrix->AdaDelta(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix, rho, epsilon); SetDataLocation(GPU); },
-        { return gradients.m_CPUSparseMatrix->AdaDelta(*m_CPUMatrix, *functionValues.m_CPUMatrix, rho, epsilon); SetDataLocation(CPU); },
-        { return gradients.m_GPUSparseMatrix->AdaDelta(*m_GPUMatrix, *functionValues.m_GPUMatrix, rho, epsilon); SetDataLocation(GPU); });
+        { return m_CPUMatrix->AdaDelta(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix, learningRate, rho, epsilon); SetDataLocation(CPU); },
+        { return m_GPUMatrix->AdaDelta(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix, learningRate, rho, epsilon); SetDataLocation(GPU); },
+        { return gradients.m_CPUSparseMatrix->AdaDelta(*m_CPUMatrix, *functionValues.m_CPUMatrix, learningRate, rho, epsilon); SetDataLocation(CPU); },
+        { return gradients.m_GPUSparseMatrix->AdaDelta(*m_GPUMatrix, *functionValues.m_GPUMatrix, learningRate, rho, epsilon); SetDataLocation(GPU); });
 }

 template <class ElemType>
@@ -217,7 +217,7 @@ public:

     ElemType RmsProp(Matrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);

-    void AdaDeltaUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& functionvalues, ElemType rho, ElemType epsilon);
+    void AdaDeltaUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& functionvalues, ElemType learningRatePerSample, ElemType rho, ElemType epsilon);

     void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve = 10000, bool growOnly = true); // by default we only reallocate if need to grow
     void Resize(const Matrix<ElemType>& other) // TODO: Should this carry over numNZElemToReserve for sparse matrices?
@@ -276,7 +276,7 @@ ElemType GPUSparseMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>&, ElemType, Elem
 }

 template<class ElemType>
-void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType rho, ElemType epsilon)
+void GPUSparseMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>&c, GPUMatrix<ElemType>&functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
 {
 }

@@ -1110,7 +1110,7 @@ ElemType GPUMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& gradients, ElemType R
 }

 template <class ElemType>
-void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType rho, ElemType epsilon)
+void GPUMatrix<ElemType>::AdaDelta(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon)
 {
 }

@@ -107,8 +107,8 @@ BOOST_FIXTURE_TEST_CASE(AdaDeltaSparse, MatrixLearnerFixture)
     // run learner
     RunOnDevices([this]()
     {
-        matSG.AdaDeltaUpdate(matG, matM, 0.95f, 1e-8f);
-        matSGsparse.AdaDeltaUpdate(matGsparseBSC, matMsparse, 0.95f, 1e-8f);
+        matSG.AdaDeltaUpdate(matG, matM, 0.5f, 0.95f, 1e-8f);
+        matSGsparse.AdaDeltaUpdate(matGsparseBSC, matMsparse, 0.5f, 0.95f, 1e-8f);

         BOOST_CHECK(matSG.IsEqualTo(matSGsparse, c_epsilonFloatE4));
         BOOST_CHECK(matM.IsEqualTo(matMsparse, c_epsilonFloatE4));
@@ -558,17 +558,18 @@ def nesterov(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
                            additional_options)

 @typemap
-def adadelta(parameters, rho=0.95, epsilon=1e-8,
+def adadelta(parameters, lr=learning_rate_schedule(1, UnitType.sample), rho=0.95, epsilon=1e-8,
              l1_regularization_weight=0.0, l2_regularization_weight=0.0,
              gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
              gradient_clipping_with_truncation=True):
-    '''adadelta(parameters, rho, epsilon, l1_regularization_weight=0, l2_regularization_weight=0, gaussian_noise_injection_std_dev=0, gradient_clipping_threshold_per_sample=np.inf, gradient_clipping_with_truncation=True)
+    '''adadelta(parameters, lr, rho, epsilon, l1_regularization_weight=0, l2_regularization_weight=0, gaussian_noise_injection_std_dev=0, gradient_clipping_threshold_per_sample=np.inf, gradient_clipping_with_truncation=True)
     Creates an AdaDelta learner instance to learn the parameters. See [1] for
     more information.

     Args:
         parameters (list of parameters): list of network parameters to tune.
          These can be obtained by the root operator's ``parameters``.
+        lr (output of :func:`learning_rate_schedule`): learning rate schedule.
         rho (float): exponential smooth factor for each minibatch.
         epsilon (float): epsilon for sqrt.
         l1_regularization_weight (float, optional): the L1 regularization weight per sample,
@@ -600,7 +601,7 @@ def adadelta(parameters, rho=0.95, epsilon=1e-8,
     additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
     additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation

-    return cntk_py.ada_delta_learner(parameters, rho, epsilon,
+    return cntk_py.ada_delta_learner(parameters, lr, rho, epsilon,
                                      additional_options)


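A minimal usage sketch of the updated Python signature, mirroring the simple_mnist change at the top of this diff. The model objects (z, ce, pe) and the progress_writers list are placeholders assumed to exist, and the import path (cntk.learner here) may differ between CNTK versions:

    from cntk import Trainer
    from cntk.learner import adadelta, learning_rate_schedule, UnitType

    # A schedule of 1 per sample keeps the pre-change behaviour, since the
    # AdaDelta step is now scaled by the learning rate.
    lr = learning_rate_schedule(1, UnitType.sample)
    learner = adadelta(z.parameters, lr, rho=0.95, epsilon=1e-8)
    trainer = Trainer(z, (ce, pe), learner, progress_writers)

Calling adadelta(z.parameters) without an explicit schedule falls back to the lr=learning_rate_schedule(1, UnitType.sample) default introduced above, so existing scripts keep their previous behaviour.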
@@ -103,7 +103,7 @@ def test_learner_init():
     lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)
     rmsprop(res.parameters, lr_per_sample, gamma, inc, dec, max, min, True)

-    adadelta(res.parameters)
+    adadelta(res.parameters, lr_per_sample)

 def test_learner_update():
     i = input(shape=(1,), needs_gradient=True, name='a')