Updated manuals with new learning APIs

This commit is contained in:
yuqtang 2017-11-12 10:32:37 -08:00
Родитель 9e549e7258
Коммит b0df0a13e6
3 изменённых файлов: 48 добавлений и 22 удалений

Просмотреть файл

@ -369,7 +369,7 @@
" return {'next_seq_idx': self.next_seq_idx}\n", " return {'next_seq_idx': self.next_seq_idx}\n",
" \n", " \n",
" def restore_from_checkpoint(self, state):\n", " def restore_from_checkpoint(self, state):\n",
" self.next_seq_idx = state['next_seq_idx']\n" " self.next_seq_idx = state['next_seq_idx']"
] ]
}, },
{ {
@ -390,7 +390,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -419,7 +421,7 @@
"loss = C.cross_entropy_with_softmax(z, label)\n", "loss = C.cross_entropy_with_softmax(z, label)\n",
"errs = C.classification_error(z, label)\n", "errs = C.classification_error(z, label)\n",
"local_learner = C.sgd(z.parameters, \n", "local_learner = C.sgd(z.parameters, \n",
" C.learning_rate_schedule(0.5, unit = C.UnitType.sample))\n", " C.learning_parameter_schedule_per_sample(0.5))\n",
"dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n", "dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n",
"# and train\n", "# and train\n",
"trainer = C.Trainer(z, (loss, errs), \n", "trainer = C.Trainer(z, (loss, errs), \n",
@ -476,5 +478,5 @@
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 0
} }

Просмотреть файл

@ -67,7 +67,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -120,7 +122,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -149,7 +153,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -173,7 +179,7 @@
], ],
"source": [ "source": [
"learner = C.sgd(model.parameters,\n", "learner = C.sgd(model.parameters,\n",
" C.learning_rate_schedule(0.1, C.UnitType.minibatch))\n", " C.learning_parameter_schedule(0.1))\n",
"progress_writer = C.logging.ProgressPrinter(0)\n", "progress_writer = C.logging.ProgressPrinter(0)\n",
"\n", "\n",
"train_summary = loss.train((X_train_lr, Y_train_lr), parameter_learners=[learner],\n", "train_summary = loss.train((X_train_lr, Y_train_lr), parameter_learners=[learner],\n",
@ -190,7 +196,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 5,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"data": { "data": {
@ -271,7 +279,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -348,7 +358,7 @@
"z = model(x/255.0) #scale the input to 0-1 range\n", "z = model(x/255.0) #scale the input to 0-1 range\n",
"loss = C.cross_entropy_with_softmax(z, y)\n", "loss = C.cross_entropy_with_softmax(z, y)\n",
"learner = C.sgd(z.parameters,\n", "learner = C.sgd(z.parameters,\n",
" C.learning_rate_schedule(0.05, C.UnitType.minibatch))" " C.learning_parameter_schedule(0.05))"
] ]
}, },
{ {
@ -365,7 +375,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 9,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -423,7 +435,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -453,7 +467,7 @@
"z1 = model(x/255.0) #scale the input to 0-1 range\n", "z1 = model(x/255.0) #scale the input to 0-1 range\n",
"loss = C.cross_entropy_with_softmax(z1, y)\n", "loss = C.cross_entropy_with_softmax(z1, y)\n",
"learner = C.sgd(z1.parameters,\n", "learner = C.sgd(z1.parameters,\n",
" C.learning_rate_schedule(0.05, C.UnitType.minibatch))\n", " C.learning_parameter_schedule(0.05))\n",
"\n", "\n",
"num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size\n", "num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size\n",
"\n", "\n",
@ -613,7 +627,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 12,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -664,6 +680,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 13,
"metadata": { "metadata": {
"collapsed": false,
"scrolled": true "scrolled": true
}, },
"outputs": [ "outputs": [

Просмотреть файл

@ -67,7 +67,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 30,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -130,7 +132,7 @@
"\n", "\n",
"# Create a learner and a trainer and a progress writer to \n", "# Create a learner and a trainer and a progress writer to \n",
"# output current progress\n", "# output current progress\n",
"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n", "learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n", "trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
"\n", "\n",
"# Now let's create a minibatch source for our input file\n", "# Now let's create a minibatch source for our input file\n",
@ -186,7 +188,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 31,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -231,7 +235,7 @@
"z = model(features)\n", "z = model(features)\n",
"loss = cntk.squared_error(z, label)\n", "loss = cntk.squared_error(z, label)\n",
"\n", "\n",
"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n", "learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n", "trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
"\n", "\n",
"# Try to restore if the checkpoint exists\n", "# Try to restore if the checkpoint exists\n",
@ -281,7 +285,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 32,
"metadata": {}, "metadata": {
"collapsed": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -330,7 +336,7 @@
"\n", "\n",
"# Make sure the learner is distributed\n", "# Make sure the learner is distributed\n",
"distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n", "distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n",
" cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n", " cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1)))\n",
"trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n", "trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n",
"\n", "\n",
"if os.path.exists(checkpoint):\n", "if os.path.exists(checkpoint):\n",
@ -385,6 +391,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 33,
"metadata": { "metadata": {
"collapsed": false,
"scrolled": true "scrolled": true
}, },
"outputs": [ "outputs": [
@ -441,7 +448,7 @@
"\n", "\n",
"criterion = criterion_factory(features, label)\n", "criterion = criterion_factory(features, label)\n",
"learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n", "learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n",
" cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n", " cntk.learning_parameter_schedule_per_sample(0.1)))\n",
"progress_writer = cntk.logging.ProgressPrinter(freq=10)\n", "progress_writer = cntk.logging.ProgressPrinter(freq=10)\n",
"checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n", "checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n",
"test_config = cntk.TestConfig(test_mb_source)\n", "test_config = cntk.TestConfig(test_mb_source)\n",