{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Manual: How to create user minibatch sources\n",
|
||
"\n",
|
||
"In order to make use of CNTK’s (distributed) training functionality, one has to provide input data as an instance of [MinibatchSource](https://cntk.ai/pythondocs/cntk.io.html#cntk.io.MinibatchSource). In CNTK, there are a variety of means to provide minibatch sources:\n",
|
||
"\n",
|
||
"- (**best**) convert data to the formats of built-in data readers - they support rich functionality of randomization/packing with high performance (see [Manual: How to feed data](https://github.com/Microsoft/CNTK/blob/master/Manual/Manual_How_to_feed_data.ipynb) and [cntk.io](https://cntk.ai/pythondocs/cntk.io.html))\n",
|
||
"- (**preferred**) if it is hard to convert the data and the data can fit in memory, please use [MinibatchSourceFromData](https://cntk.ai/pythondocs/cntk.io.html?highlight=minibatchsourcefromdata#cntk.io.MinibatchSourceFromData), \n",
|
||
"- if the data does not fit in memory and you want a fine grained control over how minibatch is created, then implementing the abstract [UserMinibatchSource](https://cntk.ai/pythondocs/cntk.io.html#cntk.io.UserMinibatchSource) interface is the option. \n",
|
||
"\n",
|
||
"This manual explains the last approach: How to create user minibatch source in Python.\n",
|
||
"\n",
|
||
"\n",
|
||
"## User minibatch sources\n",
|
||
"\n",
|
||
"A minibatch source is responsible for providing:\n",
|
||
"1. meta-information regarding the data, such as *storage format*, *data type*, *shape of elements*,\n",
|
||
"2. batches of data, and\n",
|
||
"3. auxiliary information for advanced features, such as checkpoint state of the current data access position so that interrupted learning processes can be restored from the data position where the processes were interrupted.\n",
|
||
"\n",
|
||
"Correspondingly, a minibatch source API needs to implement the following $4$ methods (see [UserMinibatchSource](https://cntk.ai/pythondocs/cntk.io.html?highlight=userminibatch#cntk.io.UserMinibatchSource) for details):\n",
|
||
"1. **stream_infos()**: Returns a list of StreamInformation instances. Each piece of stream information contains the meta information regarding a stream of the data: e.g. storage format, data type, shape of elements (see [StreamInformation](https://cntk.ai/pythondocs/cntk.io.html#cntk.io.StreamInformation) for details) \n",
|
||
"2. **next_minibatch(num_samples, number_of_workers, worker_rank, device=None)**: Returns next minibatch of data of the specified nature as specified by given parameters:\n",
|
||
" * num_samples: the number of samples that are being requested \n",
|
||
" * num_of_workders: the number of workers in a distributed training session; if it is not in a distributed training setting, this number is always 1\n",
|
||
" * worker_rank: the number which identifies the specific worker who requests this minibatch in the distributed training setting; if this is not a distributed training session, the worker rank is always 0 (the first worker)\n",
|
||
" * device: a device descriptor specifying which device the minibatch data should be copied to, e.g. cntk.device.cpu() or cntk.device.gpu(device_id) (see [DeviceDescriptor](https://cntk.ai/pythondocs/cntk.device.html#cntk.device.DeviceDescriptor) for details)\n",
|
||
"3. **get_checkpoint_state()**: Returns a dictionary which describe the current state of the minibatch source\n",
|
||
"4. **restore_from_checkpoint(state)**: Sets the state of the minibatch source according to a checkint state object. This allows a minibatch source restoring data feeding from position where the checkpoint was saved. Note that *state* is the dictionary returned by get_checkpoint_state(). \n",
|
||
"\n",
|
||
"\n",
|
||
"Now let's go through th implementation of these $4$ methods step by step.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"## User minibatch source step by step"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"In the following example, we will detail the steps on how to implement the [UserMinibatchSource](https://cntk.ai/pythondocs/cntk.io.html#cntk.io.UserMinibatchSource) interface. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"First, let's import the necessary packages:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import cntk as C\n",
|
||
"from cntk.io import UserMinibatchSource, StreamInformation, MinibatchData"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Secondly, let's assume that we have a data set in the following tab seperated text format:\n",
|
||
"* The first column is the sequence ID: e.g. 0 is the ID for sequecne 0; and 1 is the ID for sequence 1.\n",
|
||
"* The second column starts with symbol \"|\". It is the feature named 'x' which is a sparse representation for the words in our training data.\n",
|
||
"* The third column again starts with symbol \"|\". It is our lable named 'y' which is the one-hot representation of label.\n",
|
||
"\n",
|
||
"In the fllowing, our toy data set contains 4 sequences:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"sample_data = r'''0 |x 560:1 |y 1 0 0 0 0\n",
|
||
"0 |x 0:1\n",
|
||
"0 |x 0:1\n",
|
||
"1 |x 560:1 |y 0 1 0 0 0\n",
|
||
"1 |x 0:1\n",
|
||
"1 |x 0:1\n",
|
||
"1 |x 424:1\n",
|
||
"2 |x 160:1 |y 0 0 1 0 0\n",
|
||
"2 |x 5:1\n",
|
||
"2 |x 6:1\n",
|
||
"3 |x 460:1 |y 0 0 0 1 0\n",
|
||
"3 |x 3:1\n",
|
||
"3 |x 3:1\n",
|
||
"3 |x 425:1\n",
|
||
"'''"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Inherit *UserMinibatchSource* to create your user minibath class: \n",
|
||
"\n",
|
||
"To implement our example user minibatch source, we first prepare the data access and its meta information: \n",
|
||
"1. Parse the text formatted data into an intermediate representation so that we can access the data by their sequence indices: \n",
|
||
"```\n",
|
||
" features = self.data[seq_idx]['features']\n",
|
||
" labels = self.data[seq_idx]['labels']\n",
|
||
"```\n",
|
||
" This is done by create a private method *_prepare_data()* in the example below. We ommit the implementation detail of text format parsing here as the detail is irrelevant to the understanding of the UserMinibatchSource interface. However, the parsing mechanims should be able to keep track of where the current data access point is so that the data feeding process can be restored at any point. In the example, we are tracking the sequence index. \n",
|
||
"2. Define the meta information of the data: e.g.\n",
|
||
"```\n",
|
||
" self.fsi = StreamInformation(\"features\", 0, 'sparse', np.float32, (self.f_dim,))\n",
|
||
" self.lsi = StreamInformation(\"labels\", 1, 'dense', np.float32, (self.l_dim,))\n",
|
||
"```\n",
|
||
" The self.fsi and self.lsi define the meta information (see [StreamInformation](https://cntk.ai/pythondocs/cntk.io.html#cntk.io.StreamInformation) for definition ) regarding the features and labels respectively. For example, StreamInformation(\"features\", 0, 'sparse', np.float32, (self.f_dim,)) specifies that 1) the \"feature\" data stream is indentified by ID $0$ (it is required that every data stream is identified by a unique ID), 2) it is sparse, 3) its data type is np.float32, and 4) its dimension is (self.f_dim, ).\n",
|
||
"3. Set the initial states of the data source. For example, set the next sequence index to the beginning:\n",
|
||
"```\n",
|
||
" self.next_seq_idx = 0\n",
|
||
"```\n",
|
||
"4. Finally, create your minibatch class based on **UserMinibatchSource** and put the above data access preparation steps in its constructor: \n",
|
||
"\n",
|
||
"```python\n",
|
||
"class MyMultiWorkerDataSource(UserMinibatchSource):\n",
|
||
" def __init__(self, f_dim, l_dim):\n",
|
||
" self.f_dim, self.l_dim = f_dim, l_dim\n",
|
||
" self._prepare_data()\n",
|
||
" #setting the state\n",
|
||
" self.fsi = StreamInformation(\"features\", 0, 'sparse', np.float32, (self.f_dim,))\n",
|
||
" self.lsi = StreamInformation(\"labels\", 1, 'dense', np.float32, (self.l_dim,))\n",
|
||
" self.sequences = sorted(self.data)\n",
|
||
" self.next_seq_idx = 0 \n",
|
||
" super(MyMultiWorkerDataSource, self).__init__()\n",
|
||
"```\n",
|
||
"Do not forget to call the super class' constructor: **super(MyMultiWorkerDataSource, self).__init__()**."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Override *stream_infos()* method to provide meta-informatoin of data: \n",
|
||
"\n",
|
||
"After the preparation is done by the constructor, we can implement *stream_infos()* simply by returning the list of stream information instances: \n",
|
||
"```python\n",
|
||
" def stream_infos(self):\n",
|
||
" return [self.fsi, self.lsi]\n",
|
||
"```\n",
|
||
" With this method implemented, the underlying minibatch source framework will able to refer to the meta information by names \"featuers\" and \"labels\"."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Override *next_minibatch()* method to provide data\n",
|
||
"Let us first review the function signature of the next_minibatch method:\n",
|
||
"```python\n",
|
||
" def next_minibatch(self, num_samples, number_of_workers, worker_rank, device)\n",
|
||
"```\n",
|
||
"This method is invoked by the outer CNTK learning loops with four parameters: \n",
|
||
"* the nubmer of samples needed, \n",
|
||
"* number of workers, \n",
|
||
"* worker rank (i.e. worker ID), and \n",
|
||
"* the device on which the data should be copied to. \n",
|
||
"\n",
|
||
"In other words, it is the user minibatch source's responsibility to understand these parameters and provide minibatch data accordingly. The minibatch source need to ensure that \n",
|
||
"* the returned minibatch contains the specified number of samples or less, \n",
|
||
"* the returned minbiatch contains only the data that are supposed to be assigned to the specified worker (identified by the worker_rank) - it is the user minibanch's responsisbility to ensure that the data load of these workers are balanced in certain manner, and \n",
|
||
"* the data are ready in the specified device (e.g. CPU or GPU). \n",
|
||
"\n",
|
||
"To make the underlying requirement stand out, in the example below we implemented a private function *_prepare_nextbatch()* to encapsulate details:\n",
|
||
"```python\n",
|
||
"def _prepare_nextbatch(self, num_samples, number_of_workers, worker_rank):\n",
|
||
" # details....\n",
|
||
" return features, f_sample_count, labels, l_sample_count, sweep_end\n",
|
||
"```\n",
|
||
"This function ensure that *features* and *labels* contains the *num_samples* of samples or less. The sample counts are also returned as *f_sample_count* and *l_sample_count* respectively. Note that different data streams might contain different number of samples. In addition, *sweep_end* tells whether this minibatch is at the end of a sweep of the whole data set. \n",
|
||
"\n",
|
||
"To define user minibatch source that can be used with distributed learners, e.g. BlockMomentum. We will need to use number_of_workers to cut the data into slices and then return the slices depending on which worker_rank requested the next minibatch. In this private function, we implement a naive logic to distribute the data to the specific worker by skipping sequences if its sequence index modulo the number of workers does not equal to the worker rank:\n",
|
||
"```python\n",
|
||
" if (seq_id % number_of_workers) != worker_rank:\n",
|
||
" continue\n",
|
||
"```\n",
|
||
"This is only for demonstration purpose. In practice, the distribution of data to workers should be based on a more efficient mechanism: e.g. based on how costly the specific worker can access the specific subset of data and the randomization mechanism.\n",
|
||
"\n",
|
||
"After the data is prepared, we need to convert them into the values that CNTK operators can operate on efficiently. This is done by create various types of cntk.Value instances:\n",
|
||
"```python\n",
|
||
" feature_data = C.Value.one_hot(batch=features, num_classes=self.f_dim, device = device)\n",
|
||
" label_data = C.Value(batch=np.asarray(labels, dtype=np.float32), device = device)\n",
|
||
"```\n",
|
||
"In this example, the feature data are of a special type of sparse data which are created through the [cntk.Value.one_hot](https://cntk.ai/pythondocs/cntk.core.html?highlight=one_hot#cntk.core.Value.one_hot) function --- an element within a sequence is a one-hot vector. The label data are of a type of dense data which are created through the [cntk.Value](https://cntk.ai/pythondocs/cntk.core.html?highlight=value#cntk.core.Value) constructor. Note that in these CNTK value constructors, we explicitly specify on which device these values should be constructed. Reall that the *device* parameter is provided by the outher learning loops. \n",
|
||
"\n",
|
||
"Finally, we need to create [MinibatchData](https://cntk.ai/pythondocs/cntk.io.html?highlight=minibatchdata#cntk.io.MinibatchData) instances and return them in a dictionary with the corresponding [StreamInformation](https://cntk.ai/pythondocs/cntk.io.html#cntk.io.StreamInformation) instances as keys:\n",
|
||
"```python\n",
|
||
" res = {\n",
|
||
" self.fsi: MinibatchData(feature_data, num_seq, feature_sample_count, sweep_end),\n",
|
||
" self.lsi: MinibatchData(label_data, num_seq, label_sample_count, sweep_end)}\n",
|
||
" return res\n",
|
||
"```\n",
|
||
"The constructor of *MinibatchData* takes 1) the data that are already in the form [cntk.Value](https://cntk.ai/pythondocs/cntk.core.html?highlight=value#cntk.core.Value): i.e. feature_data and label_data here, 2) the number of sequences in the minibatch, 3) the number of samples, and 4) whether it is at the end of a sweep of the whole data set. \n",
|
||
"\n",
|
||
"All together, we've implemented our *next_minibatch()* method to provide minibatches of data of specified properties for the outer learning loops to consume."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Override *get_checkpoint_state()* and *restore_from_checkpoint()* methods to provide checkpoint state and restore from it\n",
|
||
"\n",
|
||
"Firstly, we need to define the state of our user minibatch so that the data feeding process can be restored from the exact point where it was stopped. In our simple example, we just need to know to next sequence index to restore the data feeding process by the following *get* and *restore* checkpoints methods:\n",
|
||
"```python\n",
|
||
" def get_checkpoint_state(self):\n",
|
||
" return {'next_seq_idx': self.next_seq_idx}\n",
|
||
"```\n",
|
||
"```python\n",
|
||
" def restore_from_checkpoint(self, state):\n",
|
||
" self.next_seq_idx = state['next_seq_idx']\n",
|
||
"```\n",
|
||
"It is easy to see that a checkpoint state is a dictionary from string keys to the corresponding state variable value objects. In this example, it is the next sequence index."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### The complete user minibatch example \n",
|
||
"\n",
|
||
"All together we have our complete user minibatch implementation as follows: "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"class MyMultiWorkerDataSource(UserMinibatchSource):\n",
|
||
" def __init__(self, f_dim, l_dim):\n",
|
||
" self.f_dim, self.l_dim = f_dim, l_dim\n",
|
||
" self._prepare_data()\n",
|
||
" #setting the state\n",
|
||
" self.fsi = StreamInformation(\"features\", 0, 'sparse', np.float32, (self.f_dim,))\n",
|
||
" self.lsi = StreamInformation(\"labels\", 1, 'dense', np.float32, (self.l_dim,))\n",
|
||
" self.sequences = sorted(self.data)\n",
|
||
" self.next_seq_idx = 0 \n",
|
||
" super(MyMultiWorkerDataSource, self).__init__()\n",
|
||
"\n",
|
||
" def _prepare_data(self):\n",
|
||
" \"\"\"\n",
|
||
" Parse the text and load the data into self.data. \n",
|
||
" self.data is of the following structure:\n",
|
||
" sequence id -> \"features\" -> list of features\n",
|
||
" and\n",
|
||
" sequence id -> \"labels\" -> label\n",
|
||
" \"\"\"\n",
|
||
" self.data = {}\n",
|
||
" for line in sample_data.split('\\n'):\n",
|
||
" line = line.strip()\n",
|
||
" if not line:\n",
|
||
" continue\n",
|
||
" seq_id, data = line.split('|', 1)\n",
|
||
" data = data.split(\"|\")\n",
|
||
" seq_id = int(seq_id.strip())\n",
|
||
"\n",
|
||
" if seq_id not in self.data:\n",
|
||
" self.data[seq_id] = {'features': []}\n",
|
||
"\n",
|
||
" # Processing features - expecting one per line.\n",
|
||
" features = data[0].split(\" \")\n",
|
||
" vocab_idx = int(features[1].split(\":\")[0])\n",
|
||
" self.data[seq_id]['features'].append(vocab_idx)\n",
|
||
"\n",
|
||
" # Process label, if exists\n",
|
||
" if len(data) == 2:\n",
|
||
" labels = np.asarray([data[1].split(\" \")[1:]], dtype=np.float32)\n",
|
||
" self.data[seq_id]['labels'] = labels\n",
|
||
" \n",
|
||
" def _prepare_nextbatch(self, num_samples, number_of_workers, worker_rank):\n",
|
||
" features = []\n",
|
||
" labels = []\n",
|
||
" sweep_end = False\n",
|
||
" f_sample_count = l_sample_count = 0\n",
|
||
"\n",
|
||
" while max(f_sample_count, l_sample_count) < num_samples:\n",
|
||
" if self.next_seq_idx == len(self.sequences):\n",
|
||
" sweep_end = True\n",
|
||
" self.next_seq_idx = 0\n",
|
||
"\n",
|
||
" seq_id = self.sequences[self.sequences[self.next_seq_idx]]\n",
|
||
" #Based on the worker rank, determines whether to add this \n",
|
||
" #data in the batch: If the sequences doesn't belong to this \n",
|
||
" #worker, skip it. In practice, this should be based on more \n",
|
||
" #efficient mechanism: e.g. based on the location of the worker\n",
|
||
" #and the data location\n",
|
||
" if (seq_id % number_of_workers) != worker_rank:\n",
|
||
" continue\n",
|
||
" feature_data = self.data[seq_id]['features']\n",
|
||
" label_data = self.data[seq_id]['labels']\n",
|
||
" if (features or labels) and \\\n",
|
||
" max(f_sample_count+len(feature_data), \\\n",
|
||
" l_sample_count+len(label_data)) > num_samples:\n",
|
||
" break\n",
|
||
" f_sample_count += len(feature_data)\n",
|
||
" features.append(feature_data)\n",
|
||
"\n",
|
||
" l_sample_count += len(label_data)\n",
|
||
" labels.append(label_data)\n",
|
||
"\n",
|
||
" self.next_seq_idx += 1\n",
|
||
" return features, f_sample_count, labels, l_sample_count, sweep_end\n",
|
||
" \n",
|
||
" def stream_infos(self):\n",
|
||
" \"\"\"\n",
|
||
" Override the stream_infos method of the base UserMinibatchSource class\n",
|
||
" to provide stream meta information.\n",
|
||
" \"\"\"\n",
|
||
" return [self.fsi, self.lsi]\n",
|
||
"\n",
|
||
" def next_minibatch(self, num_samples, number_of_workers, worker_rank, device):\n",
|
||
" \"\"\"\n",
|
||
" Override the next_minibatch method of the base UserMinibatchSource class \n",
|
||
" to provide minibatch data.\n",
|
||
" \"\"\"\n",
|
||
" features, feature_sample_count, \\\n",
|
||
" labels, label_sample_count, sweep_end = self._prepare_nextbatch(num_samples,\n",
|
||
" number_of_workers,\n",
|
||
" worker_rank)\n",
|
||
" feature_data = C.Value.one_hot(batch=features, num_classes=self.f_dim, device = device)\n",
|
||
" label_data = C.Value(batch=np.asarray(labels, dtype=np.float32), device = device)\n",
|
||
" num_seq = len(features)\n",
|
||
" res = {\n",
|
||
" self.fsi: MinibatchData(feature_data, num_seq, feature_sample_count, sweep_end),\n",
|
||
" self.lsi: MinibatchData(label_data, num_seq, label_sample_count, sweep_end)\n",
|
||
" }\n",
|
||
" return res\n",
|
||
"\n",
|
||
" def get_checkpoint_state(self):\n",
|
||
" return {'next_seq_idx': self.next_seq_idx}\n",
|
||
" \n",
|
||
" def restore_from_checkpoint(self, state):\n",
|
||
" self.next_seq_idx = state['next_seq_idx']\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Note that in this example, for simplicity we load the whole data set into the memory. In practice, the minibatch source should depend on the data source state (e.g. the mapping between the requesting next batch data and its logical/physical location in the data storage) to load (or pre-load) the data at the point (or right before) they are requested.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Using the user minibatch data source in training sessions with distributed learners\n",
|
||
"The implemented minitbatch source can then be used wherever a MinibatchSource instance is accepted. For example, "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Finished Epoch[1 of 10]: [Training] loss = 1.564692 * 20, metric = 90.00% * 20 0.600s ( 33.3 samples/s);\n",
|
||
"Finished Epoch[2 of 10]: [Training] loss = 1.035505 * 20, metric = 5.00% * 20 0.045s (444.4 samples/s);\n",
|
||
"Finished Epoch[3 of 10]: [Training] loss = 0.364257 * 20, metric = 0.00% * 20 0.043s (465.1 samples/s);\n",
|
||
"Finished Epoch[4 of 10]: [Training] loss = 0.133813 * 20, metric = 0.00% * 20 0.052s (384.6 samples/s);\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"input_dim = 1000\n",
|
||
"num_output_classes = 5\n",
|
||
"\n",
|
||
"# instantiating the user minibatch source\n",
|
||
"mbs = MyMultiWorkerDataSource( input_dim, num_output_classes)\n",
|
||
"feature = C.sequence.input_variable(shape=(input_dim,))\n",
|
||
"label = C.input_variable(shape=(num_output_classes,))\n",
|
||
"\n",
|
||
"# setting up the model\n",
|
||
"rnn = C.layers.Recurrence(C.layers.LSTM(20), go_backwards=False)(feature)\n",
|
||
"end = C.sequence.last(rnn)\n",
|
||
"z = C.layers.Dense(num_output_classes)(end)\n",
|
||
"loss = C.cross_entropy_with_softmax(z, label)\n",
|
||
"errs = C.classification_error(z, label)\n",
|
||
"local_learner = C.sgd(z.parameters, \n",
|
||
" C.learning_rate_schedule(0.5, unit = C.UnitType.sample))\n",
|
||
"dist_learner = C.distributed.data_parallel_distributed_learner(local_learner)\n",
|
||
"# and train\n",
|
||
"trainer = C.Trainer(z, (loss, errs), \n",
|
||
" [dist_learner], \n",
|
||
" [C.logging.ProgressPrinter(tag='Training', num_epochs=10)])\n",
|
||
"input_map = {\n",
|
||
" feature: mbs.fsi,\n",
|
||
" label: mbs.lsi\n",
|
||
"}\n",
|
||
"session = C.training_session(\n",
|
||
" trainer = trainer, \n",
|
||
" mb_source = mbs,\n",
|
||
" model_inputs_to_streams = input_map,\n",
|
||
" mb_size = 7, \n",
|
||
" max_samples = 80,\n",
|
||
" progress_frequency = 20\n",
|
||
")\n",
|
||
"session.train()\n",
|
||
"#finalize the distributed learning\n",
|
||
"C.distributed.Communicator.finalize()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## User minibatch sources in restricted scenarios\n",
|
||
"\n",
|
||
"In certain simplified scenarios, we might not want to implement a minibatch source with full functionality. \n",
|
||
"\n",
|
||
"* If parallel data learning is not reqired, we can omit the logic of distributing data to workers. Set number_of_workers = 1 and worker_rank = 0 when overriding the *next_minibatch()* method.\n",
|
||
"\n",
|
||
"* If checkpoint restoration is not require, we can omit implementing the two checkpoint related methods: *get_checkpoint_state()* and *restore_from_checkpoint()*.\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.5.3"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|