Updated manuals with new learning APIs

This commit is contained in:
yuqtang 2017-11-12 10:32:37 -08:00
Родитель 9e549e7258
Коммит b0df0a13e6
3 изменённых файлов: 48 добавлений и 22 удалений

Просмотреть файл

@ -369,7 +369,7 @@
" return {'next_seq_idx': self.next_seq_idx}\n",
" \n",
" def restore_from_checkpoint(self, state):\n",
" self.next_seq_idx = state['next_seq_idx']\n"
" self.next_seq_idx = state['next_seq_idx']"
]
},
{
@ -390,7 +390,9 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -419,7 +421,7 @@
"loss = C.cross_entropy_with_softmax(z, label)\n",
"errs = C.classification_error(z, label)\n",
"local_learner = C.sgd(z.parameters, \n",
" C.learning_rate_schedule(0.5, unit = C.UnitType.sample))\n",
" C.learning_parameter_schedule_per_sample(0.5))\n",
"dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n",
"# and train\n",
"trainer = C.Trainer(z, (loss, errs), \n",
@ -476,5 +478,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 0
}

Просмотреть файл

@ -67,7 +67,9 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -120,7 +122,9 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -149,7 +153,9 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -173,7 +179,7 @@
],
"source": [
"learner = C.sgd(model.parameters,\n",
" C.learning_rate_schedule(0.1, C.UnitType.minibatch))\n",
" C.learning_parameter_schedule(0.1))\n",
"progress_writer = C.logging.ProgressPrinter(0)\n",
"\n",
"train_summary = loss.train((X_train_lr, Y_train_lr), parameter_learners=[learner],\n",
@ -190,7 +196,9 @@
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
@ -271,7 +279,9 @@
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -348,7 +358,7 @@
"z = model(x/255.0) #scale the input to 0-1 range\n",
"loss = C.cross_entropy_with_softmax(z, y)\n",
"learner = C.sgd(z.parameters,\n",
" C.learning_rate_schedule(0.05, C.UnitType.minibatch))"
" C.learning_parameter_schedule(0.05))"
]
},
{
@ -365,7 +375,9 @@
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -423,7 +435,9 @@
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -453,7 +467,7 @@
"z1 = model(x/255.0) #scale the input to 0-1 range\n",
"loss = C.cross_entropy_with_softmax(z1, y)\n",
"learner = C.sgd(z1.parameters,\n",
" C.learning_rate_schedule(0.05, C.UnitType.minibatch))\n",
" C.learning_parameter_schedule(0.05))\n",
"\n",
"num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size\n",
"\n",
@ -613,7 +627,9 @@
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -664,6 +680,7 @@
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [

Просмотреть файл

@ -67,7 +67,9 @@
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -130,7 +132,7 @@
"\n",
"# Create a learner and a trainer and a progress writer to \n",
"# output current progress\n",
"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n",
"learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
"\n",
"# Now let's create a minibatch source for our input file\n",
@ -186,7 +188,9 @@
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -231,7 +235,7 @@
"z = model(features)\n",
"loss = cntk.squared_error(z, label)\n",
"\n",
"learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))\n",
"learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n",
"trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n",
"\n",
"# Try to restore if the checkpoint exists\n",
@ -281,7 +285,9 @@
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
@ -330,7 +336,7 @@
"\n",
"# Make sure the learner is distributed\n",
"distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n",
" cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n",
" cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1)))\n",
"trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n",
"\n",
"if os.path.exists(checkpoint):\n",
@ -385,6 +391,7 @@
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
@ -441,7 +448,7 @@
"\n",
"criterion = criterion_factory(features, label)\n",
"learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n",
" cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))\n",
" cntk.learning_parameter_schedule_per_sample(0.1)))\n",
"progress_writer = cntk.logging.ProgressPrinter(freq=10)\n",
"checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n",
"test_config = cntk.TestConfig(test_mb_source)\n",