Updated manuals with new learning APIs
Parent: 9e549e7258
Commit: b0df0a13e6
@@ -369,7 +369,7 @@
 " return {'next_seq_idx': self.next_seq_idx}\n",
 " \n",
 " def restore_from_checkpoint(self, state):\n",
-" self.next_seq_idx = state['next_seq_idx']\n"
+" self.next_seq_idx = state['next_seq_idx']"
 ]
 },
 {
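For context, the two methods touched in this hunk are the checkpointing hooks of a cntk.io.UserMinibatchSource subclass. A minimal sketch, assuming a source that tracks its read position in next_seq_idx as the cell above does (the class name and __init__ are illustrative, not from the manual):

    import cntk as C

    class MySource(C.io.UserMinibatchSource):
        # stream_infos() and next_minibatch() omitted; only the checkpoint
        # hooks shown in the hunk are sketched here
        def __init__(self):
            super(MySource, self).__init__()
            self.next_seq_idx = 0

        def get_checkpoint_state(self):
            # return a picklable dict describing the current read cursor
            return {'next_seq_idx': self.next_seq_idx}

        def restore_from_checkpoint(self, state):
            # resume reading where the saved cursor left off
            self.next_seq_idx = state['next_seq_idx']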
@@ -390,7 +390,9 @@
 {
 "cell_type": "code",
 "execution_count": 4,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -419,7 +421,7 @@
 "loss = C.cross_entropy_with_softmax(z, label)\n",
 "errs = C.classification_error(z, label)\n",
 "local_learner = C.sgd(z.parameters, \n",
-" C.learning_rate_schedule(0.5, unit = C.UnitType.sample))\n",
+" C.learning_parameter_schedule_per_sample(0.5))\n",
 "dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n",
 "# and train\n",
 "trainer = C.Trainer(z, (loss, errs), \n",
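Outside the notebook JSON, the new-style construction this hunk migrates to can be sketched roughly as follows; the tiny model and input shapes are illustrative placeholders, not taken from the manual:

    import cntk as C

    x = C.input_variable(784)
    label = C.input_variable(10)
    z = C.layers.Dense(10)(x)
    loss = C.cross_entropy_with_softmax(z, label)
    errs = C.classification_error(z, label)

    # learning_parameter_schedule_per_sample(0.5) replaces the older
    # learning_rate_schedule(0.5, unit=C.UnitType.sample)
    local_learner = C.sgd(z.parameters,
                          C.learning_parameter_schedule_per_sample(0.5))
    dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)
    trainer = C.Trainer(z, (loss, errs), [dist_learner])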
@@ -476,5 +478,5 @@
 }
 },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 0
 }
@@ -67,7 +67,9 @@
 {
 "cell_type": "code",
 "execution_count": 2,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -120,7 +122,9 @@
 {
 "cell_type": "code",
 "execution_count": 3,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -149,7 +153,9 @@
 {
 "cell_type": "code",
 "execution_count": 4,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -173,7 +179,7 @@
 ],
 "source": [
 "learner = C.sgd(model.parameters,\n",
-" C.learning_rate_schedule(0.1, C.UnitType.minibatch))\n",
+" C.learning_parameter_schedule(0.1))\n",
 "progress_writer = C.logging.ProgressPrinter(0)\n",
 "\n",
 "train_summary = loss.train((X_train_lr, Y_train_lr), parameter_learners=[learner],\n",
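The same migration in the Function.train style used here can be sketched end to end as follows; the random toy data and two-unit model are illustrative only:

    import numpy as np
    import cntk as C

    X = np.random.randn(128, 2).astype(np.float32)
    Y = np.eye(2, dtype=np.float32)[np.random.randint(0, 2, 128)]

    x = C.input_variable(2)
    y = C.input_variable(2)
    model = C.layers.Dense(2)(x)
    loss = C.cross_entropy_with_softmax(model, y)

    # learning_parameter_schedule(0.1) replaces
    # learning_rate_schedule(0.1, C.UnitType.minibatch); by default the
    # rate is interpreted per minibatch
    learner = C.sgd(model.parameters, C.learning_parameter_schedule(0.1))
    progress_writer = C.logging.ProgressPrinter(0)

    train_summary = loss.train((X, Y), parameter_learners=[learner],
                               callbacks=[progress_writer])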
@@ -190,7 +196,9 @@
 {
 "cell_type": "code",
 "execution_count": 5,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "data": {
@@ -271,7 +279,9 @@
 {
 "cell_type": "code",
 "execution_count": 6,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -348,7 +358,7 @@
 "z = model(x/255.0) #scale the input to 0-1 range\n",
 "loss = C.cross_entropy_with_softmax(z, y)\n",
 "learner = C.sgd(z.parameters,\n",
-" C.learning_rate_schedule(0.05, C.UnitType.minibatch))"
+" C.learning_parameter_schedule(0.05))"
 ]
 },
 {
@@ -365,7 +375,9 @@
 {
 "cell_type": "code",
 "execution_count": 9,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -423,7 +435,9 @@
 {
 "cell_type": "code",
 "execution_count": 10,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -453,7 +467,7 @@
 "z1 = model(x/255.0) #scale the input to 0-1 range\n",
 "loss = C.cross_entropy_with_softmax(z1, y)\n",
 "learner = C.sgd(z1.parameters,\n",
-" C.learning_rate_schedule(0.05, C.UnitType.minibatch))\n",
+" C.learning_parameter_schedule(0.05))\n",
 "\n",
 "num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size\n",
 "\n",
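Where a notebook needs the rate stated against an explicit reference minibatch size rather than the defaults above, the new schedule functions also take a minibatch_size keyword; a hedged sketch of the equivalences, assuming that keyword behaves as its name suggests:

    import cntk as C

    per_minibatch = C.learning_parameter_schedule(0.05)                    # rate for a whole minibatch
    per_sample    = C.learning_parameter_schedule_per_sample(0.05)         # rate for a single sample
    per_sample_2  = C.learning_parameter_schedule(0.05, minibatch_size=1)  # assumed equivalent to per_sample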
@@ -613,7 +627,9 @@
 {
 "cell_type": "code",
 "execution_count": 12,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -664,6 +680,7 @@
 "cell_type": "code",
 "execution_count": 13,
 "metadata": {
+"collapsed": false,
 "scrolled": true
 },
 "outputs": [
@@ -67,7 +67,9 @@
 {
 "cell_type": "code",
 "execution_count": 30,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -130,7 +132,7 @@
 "\n",
 "# Create a learner and a trainer and a progress writer to \n",
 "# output current progress\n",
-"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n",
+"learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
 "trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
 "\n",
 "# Now let's create a minibatch source for our input file\n",
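The minibatch source mentioned in the last context line is typically built from a CTF file; a minimal sketch, assuming an illustrative file name and stream shapes rather than the manual's own:

    import cntk

    # assumption: 'data.ctf' holds records like "|features 3.1 0.2 |label 1"
    mb_source = cntk.io.MinibatchSource(
        cntk.io.CTFDeserializer('data.ctf', cntk.io.StreamDefs(
            features=cntk.io.StreamDef(field='features', shape=2),
            label=cntk.io.StreamDef(field='label', shape=1))),
        randomize=True)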
@@ -186,7 +188,9 @@
 {
 "cell_type": "code",
 "execution_count": 31,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -231,7 +235,7 @@
 "z = model(features)\n",
 "loss = cntk.squared_error(z, label)\n",
 "\n",
-"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n",
+"learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
 "trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
 "\n",
 "# Try to restore if the checkpoint exists\n",
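The restore step hinted at by the trailing comment follows the standard Trainer checkpoint calls; a self-contained sketch with a toy model and an illustrative checkpoint path:

    import os
    import numpy as np
    import cntk

    features = cntk.input_variable(2)
    label = cntk.input_variable(1)
    z = cntk.layers.Dense(1)(features)
    loss = cntk.squared_error(z, label)
    learner = cntk.sgd(z.parameters, cntk.learning_parameter_schedule_per_sample(0.1))
    trainer = cntk.train.Trainer(z, (loss, loss), learner)

    checkpoint = 'toy.ckpt'   # illustrative path
    if os.path.exists(checkpoint):
        # resume model weights and trainer state from a previous run
        trainer.restore_from_checkpoint(checkpoint)
    trainer.train_minibatch({features: np.zeros((4, 2), np.float32),
                             label: np.zeros((4, 1), np.float32)})
    trainer.save_checkpoint(checkpoint)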
@@ -281,7 +285,9 @@
 {
 "cell_type": "code",
 "execution_count": 32,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
 "outputs": [
 {
 "name": "stdout",
@@ -330,7 +336,7 @@
 "\n",
 "# Make sure the learner is distributed\n",
 "distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n",
-" cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n",
+" cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1)))\n",
 "trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n",
 "\n",
 "if os.path.exists(checkpoint):\n",
@@ -385,6 +391,7 @@
 "cell_type": "code",
 "execution_count": 33,
 "metadata": {
+"collapsed": false,
 "scrolled": true
 },
 "outputs": [
@@ -441,7 +448,7 @@
 "\n",
 "criterion = criterion_factory(features, label)\n",
 "learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n",
-" cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n",
+" cntk.learning_parameter_schedule_per_sample(0.1)))\n",
 "progress_writer = cntk.logging.ProgressPrinter(freq=10)\n",
 "checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n",
 "test_config = cntk.TestConfig(test_mb_source)\n",
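For completeness, the configs assembled in this hunk are handed to Function.train through its callbacks list; a sketch that writes a tiny CTF file so it can run stand-alone (file names, shapes, and frequencies are illustrative, not the manual's):

    import cntk

    ctf = 'tiny.ctf'   # illustrative on-the-fly data file
    with open(ctf, 'w') as f:
        for _ in range(64):
            f.write('|features 0.5 1.5 |label 1 0\n')

    def make_source():
        return cntk.io.MinibatchSource(
            cntk.io.CTFDeserializer(ctf, cntk.io.StreamDefs(
                features=cntk.io.StreamDef(field='features', shape=2),
                label=cntk.io.StreamDef(field='label', shape=2))),
            max_sweeps=1)

    features = cntk.input_variable(2)
    label = cntk.input_variable(2)
    model = cntk.layers.Dense(2)(features)
    loss = cntk.cross_entropy_with_softmax(model, label)
    metric = cntk.classification_error(model, label)
    criterion = cntk.combine([loss, metric])   # (loss, metric) pair for train and test

    learner = cntk.distributed.data_parallel_distributed_learner(
        cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1)))
    progress_writer = cntk.logging.ProgressPrinter(freq=10)
    checkpoint_config = cntk.CheckpointConfig(filename='tiny.ckpt', frequency=32)
    test_config = cntk.TestConfig(make_source())

    train_source = make_source()
    input_map = {features: train_source.streams.features,
                 label: train_source.streams.label}
    criterion.train(train_source, minibatch_size=16,
                    model_inputs_to_streams=input_map,
                    parameter_learners=[learner],
                    callbacks=[progress_writer, checkpoint_config, test_config])
    cntk.train.distributed.Communicator.finalize()   # wrap up distributed workers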