Updated manuals with new learning APIs
This commit is contained in:
Родитель
9e549e7258
Коммит
b0df0a13e6
|
@ -369,7 +369,7 @@
|
||||||
" return {'next_seq_idx': self.next_seq_idx}\n",
|
" return {'next_seq_idx': self.next_seq_idx}\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def restore_from_checkpoint(self, state):\n",
|
" def restore_from_checkpoint(self, state):\n",
|
||||||
" self.next_seq_idx = state['next_seq_idx']\n"
|
" self.next_seq_idx = state['next_seq_idx']"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -390,7 +390,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -419,7 +421,7 @@
|
||||||
"loss = C.cross_entropy_with_softmax(z, label)\n",
|
"loss = C.cross_entropy_with_softmax(z, label)\n",
|
||||||
"errs = C.classification_error(z, label)\n",
|
"errs = C.classification_error(z, label)\n",
|
||||||
"local_learner = C.sgd(z.parameters, \n",
|
"local_learner = C.sgd(z.parameters, \n",
|
||||||
" C.learning_rate_schedule(0.5, unit = C.UnitType.sample))\n",
|
" C.learning_parameter_schedule_per_sample(0.5))\n",
|
||||||
"dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n",
|
"dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n",
|
||||||
"# and train\n",
|
"# and train\n",
|
||||||
"trainer = C.Trainer(z, (loss, errs), \n",
|
"trainer = C.Trainer(z, (loss, errs), \n",
|
||||||
|
@ -476,5 +478,5 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -120,7 +122,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -149,7 +153,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -173,7 +179,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"learner = C.sgd(model.parameters,\n",
|
"learner = C.sgd(model.parameters,\n",
|
||||||
" C.learning_rate_schedule(0.1, C.UnitType.minibatch))\n",
|
" C.learning_parameter_schedule(0.1))\n",
|
||||||
"progress_writer = C.logging.ProgressPrinter(0)\n",
|
"progress_writer = C.logging.ProgressPrinter(0)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"train_summary = loss.train((X_train_lr, Y_train_lr), parameter_learners=[learner],\n",
|
"train_summary = loss.train((X_train_lr, Y_train_lr), parameter_learners=[learner],\n",
|
||||||
|
@ -190,7 +196,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
|
@ -271,7 +279,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -348,7 +358,7 @@
|
||||||
"z = model(x/255.0) #scale the input to 0-1 range\n",
|
"z = model(x/255.0) #scale the input to 0-1 range\n",
|
||||||
"loss = C.cross_entropy_with_softmax(z, y)\n",
|
"loss = C.cross_entropy_with_softmax(z, y)\n",
|
||||||
"learner = C.sgd(z.parameters,\n",
|
"learner = C.sgd(z.parameters,\n",
|
||||||
" C.learning_rate_schedule(0.05, C.UnitType.minibatch))"
|
" C.learning_parameter_schedule(0.05))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -365,7 +375,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -423,7 +435,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -453,7 +467,7 @@
|
||||||
"z1 = model(x/255.0) #scale the input to 0-1 range\n",
|
"z1 = model(x/255.0) #scale the input to 0-1 range\n",
|
||||||
"loss = C.cross_entropy_with_softmax(z1, y)\n",
|
"loss = C.cross_entropy_with_softmax(z1, y)\n",
|
||||||
"learner = C.sgd(z1.parameters,\n",
|
"learner = C.sgd(z1.parameters,\n",
|
||||||
" C.learning_rate_schedule(0.05, C.UnitType.minibatch))\n",
|
" C.learning_parameter_schedule(0.05))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size\n",
|
"num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -613,7 +627,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -664,6 +680,7 @@
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 13,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
"scrolled": true
|
"scrolled": true
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
|
|
@ -67,7 +67,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 30,
|
"execution_count": 30,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -130,7 +132,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# Create a learner and a trainer and a progress writer to \n",
|
"# Create a learner and a trainer and a progress writer to \n",
|
||||||
"# output current progress\n",
|
"# output current progress\n",
|
||||||
"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n",
|
"learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
|
||||||
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
|
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Now let's create a minibatch source for our input file\n",
|
"# Now let's create a minibatch source for our input file\n",
|
||||||
|
@ -186,7 +188,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 31,
|
"execution_count": 31,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -231,7 +235,7 @@
|
||||||
"z = model(features)\n",
|
"z = model(features)\n",
|
||||||
"loss = cntk.squared_error(z, label)\n",
|
"loss = cntk.squared_error(z, label)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n",
|
"learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
|
||||||
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
|
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Try to restore if the checkpoint exists\n",
|
"# Try to restore if the checkpoint exists\n",
|
||||||
|
@ -281,7 +285,9 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 32,
|
"execution_count": 32,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
@ -330,7 +336,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# Make sure the learner is distributed\n",
|
"# Make sure the learner is distributed\n",
|
||||||
"distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n",
|
"distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n",
|
||||||
" cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n",
|
" cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1)))\n",
|
||||||
"trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n",
|
"trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"if os.path.exists(checkpoint):\n",
|
"if os.path.exists(checkpoint):\n",
|
||||||
|
@ -385,6 +391,7 @@
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 33,
|
"execution_count": 33,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
"scrolled": true
|
"scrolled": true
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -441,7 +448,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"criterion = criterion_factory(features, label)\n",
|
"criterion = criterion_factory(features, label)\n",
|
||||||
"learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n",
|
"learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n",
|
||||||
" cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n",
|
" cntk.learning_parameter_schedule_per_sample(0.1)))\n",
|
||||||
"progress_writer = cntk.logging.ProgressPrinter(freq=10)\n",
|
"progress_writer = cntk.logging.ProgressPrinter(freq=10)\n",
|
||||||
"checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n",
|
"checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n",
|
||||||
"test_config = cntk.TestConfig(test_mb_source)\n",
|
"test_config = cntk.TestConfig(test_mb_source)\n",
|
||||||
|
|
Загрузка…
Ссылка в новой задаче