Merge pull request #124 from you-n-g/master

Add keras extension
This commit is contained in:
you-n-g 2016-08-16 08:50:06 +08:00 committed by GitHub
Parents 27ee8f5b73 9ba1a6ae8d
Commit 1929cce634
7 changed files: 180 additions and 71 deletions

View file

@@ -54,7 +54,7 @@ Similar strategies are already implemented in the constructors in `theano_ext.sh
## About the master worker
Some things should only be done in a specific worker, such as validation and outputting the results. You can use the `mv.is_master_worker()` API to mark worker 0 as the master worker that completes these tasks.
For example, if you want to make sure only one process outputs the validation results, you can write code similar to the snippet below.
```
```python
import multiverso as mv
# train your model
if mv.is_master_worker():
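    # (Illustrative continuation, not part of the original snippet: only the
    # master worker reports the validation results; `compute_validation_error`
    # is a hypothetical helper you would implement yourself.)
    print("validation error: %f" % compute_validation_error())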
@@ -72,7 +72,7 @@ First, similarly, add `mv.init()`, `mv.shutdown()` and `mv.barrier()` mentioned
In Theano, parameters are usually stored in shared variables.
For example, a shared variable can be created like this in a Theano script.
```
```python
self.W = theano.shared(
    value=numpy.zeros(
        (n_in, n_out),
@@ -84,7 +84,7 @@ self.W = theano.shared(
```
If you want to use multiverso, you can modify them like this.
```
```python
from multiverso.theano_ext import sharedvar
W = sharedvar.mv_shared(
    value=numpy.zeros(
@@ -99,12 +99,14 @@ W = sharedvar.mv_shared(
# train the model
# When you are ready to add the delta of the variable to parameter server and sync the latest value, you can run this function
# When you are ready to add the delta of the variable to the parameter
# server and sync the latest value, you can run this function
W.mv_sync()
# If you want to sync all variables created by `sharedvar.mv_shared`, you can use this function.
# It will add the gradients (delta value) to the server and update the latest value from the server.
# If you want to sync all variables created by `sharedvar.mv_shared`,
# you can use this function. It will add the gradients (delta value)
# to the server and update the latest value from the server.
sharedvar.sync_all_mv_shared_vars()
```
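For context, here is a minimal sketch of how a synced shared variable fits into a training loop; `train_model` and `n_train_batches` are illustrative placeholders rather than part of the original example.
```python
import multiverso as mv
from multiverso.theano_ext import sharedvar

mv.init()
# build the model with sharedvar.mv_shared parameters and compile a
# Theano function `train_model` (illustrative) before the loop
for minibatch_index in range(n_train_batches):  # illustrative loop bound
    train_model(minibatch_index)
    # add the local deltas to the server and pull the latest values
    sharedvar.sync_all_mv_shared_vars()
mv.barrier()
mv.shutdown()
```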
@@ -122,22 +124,50 @@ Lasagne provides many functions to build models in theano. Multiverso python bin
You can write code like this to manage your parameters.
A typical usage of the parameter manager is shown below.
```
```python
from multiverso.theano_ext.lasagne_ext import param_manager
network = build_model() # build_model is a function you implement to build your model
# The MVNetParamManager will initialize the parameters and sync them with
# The LasagneParamManager will initialize the parameters and sync them with
# parameter server
mvnpm = param_manager.MVNetParamManager(network)
lpm = param_manager.LasagneParamManager(network)
# Train the model
# When you are ready to add the delta of the variable in this model to the parameter server and get the latest value, you can run this function
mvnpm.sync_all_param()
# When you are ready to add the deltas of the variables in this model to the parameter
# server and get the latest values, you can run this function
lpm.sync_all_param()
```
Detailed api documents can be found in docstring of [param_manager.py](https://github.com/Microsoft/multiverso/blob/master/binding/python/multiverso/theano_ext/lasagne_ext/param_manager.py)
Detailed api documents can be found in docstring of [param_manager.py](https://github.com/Microsoft/multiverso/blob/master/binding/python/multiverso/theano_ext/param_manager.py)
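As a rough supplementary sketch (the `train_fn`, `iterate_minibatches`, data variables and loop bounds are illustrative placeholders, not from the original), the manager is typically synced once per mini-batch; the CIFAR-10 example later in this change follows the same pattern.
```python
import multiverso as mv
from multiverso.theano_ext.lasagne_ext import param_manager

lpm = param_manager.LasagneParamManager(network)  # network built as above
for epoch in range(num_epochs):                   # illustrative loop bounds
    for inputs, targets in iterate_minibatches(X_train, y_train, batch_size):
        train_fn(inputs, targets)                 # illustrative training function
        # add the local deltas to the server and fetch the latest values
        lpm.sync_all_param()
    # keep the workers in step before validation or output
    mv.barrier()
```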
# How to use multiverso in Keras
First, add `mv.init()`, `mv.shutdown()` and `mv.barrier()` mentioned above in your codebase.
Keras provides many functions to build models. The Multiverso python binding provides a callback to make managing and synchronizing the parameters in Keras easier.
This callback synchronizes the parameters after every mini-batch.
A typical usage of the callback is shown below.
```python
from multiverso.theano_ext.keras_ext.callbacks import MVCallback
model = Sequential()
# build and compile your model here
# Train the model
model.fit(X_train, Y_train,
          batch_size=batch_size,
          nb_epoch=nb_epoch,
          validation_data=(X_test, Y_test),
          shuffle=True,
          callbacks=[MVCallback(model)])  # The only difference is that you add callbacks here
```
The only difference from a normal Keras program is the extra callback, which syncs the parameters after every mini-batch.
Detailed api documents can be found in docstring of [param_manager.py](https://github.com/Microsoft/multiverso/blob/master/binding/python/multiverso/theano_ext/param_manager.py) and [callbacks.py](https://github.com/Microsoft/multiverso/blob/master/binding/python/multiverso/theano_ext/keras_ext/callbacks.py)
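If you also want only one worker to save the trained model, you can combine this with `mv.is_master_worker()`. A minimal sketch (the file name is illustrative; `model` is the compiled Keras model from above):
```python
import multiverso as mv

if mv.is_master_worker():
    # only worker 0 writes the weights; the file name is illustrative
    model.save_weights("keras_model_weights.h5", overwrite=True)
mv.barrier()
mv.shutdown()
```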
# Run your multiverso program with 4 processes
Here is an example of running logistic regression with multiple processes.
@@ -155,7 +185,7 @@ Second, run the program with multiverso in multiple processes.
Here is an example to make different processes use different GPUs.
In this example, the i-th worker will use the i-th GPU. You need to add code like this before `import theano`.
```
```python
import multiverso as mv
mv.init()
worker_id = mv.worker_id()
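# (Illustrative continuation, not part of the original snippet.) One common
# approach is to set THEANO_FLAGS before importing theano, so that worker i
# uses GPU i; the exact device name (gpu<N> vs cuda<N>) depends on your
# Theano backend and version.
import os
os.environ["THEANO_FLAGS"] = "floatX=float32,device=gpu%d" % worker_id
import theano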

View file

@@ -268,9 +268,9 @@ def main(batch_size=128, lr=0.1, sync=False, n=5, num_epochs=82, model=None):
    network = build_cnn(input_var, n)
    print("number of parameters in model: %d" % lasagne.layers.count_params(network, trainable=True))
    # MULTIVERSO: MVNetParamManager is a parameter manager which can
    # MULTIVERSO: LasagneParamManager is a parameter manager which can
    # synchronize parameters of Lasagne with multiverso.
    mvnpm = param_manager.MVNetParamManager(network)
    lpm = param_manager.LasagneParamManager(network)
    if model is None:
        # Create a loss expression for training, i.e., a scalar objective we want
@@ -328,10 +328,10 @@ def main(batch_size=128, lr=0.1, sync=False, n=5, num_epochs=82, model=None):
            inputs, targets = batch
            train_err += train_fn(inputs, targets)
            # MULTIVERSO: when you want to commit all the deltas of
            # parameters managed by MVNetParamManager and update the latest
            # parameters managed by LasagneParamManager and update the latest
            # parameters from parameter server, you can call this function to
            # synchronize the values
            mvnpm.sync_all_param()
            lpm.sync_all_param()
        # And a full pass over the validation data:
        # MULTIVERSO: all the workers will synchronize at the place you call barrier
@@ -368,7 +368,7 @@ def main(batch_size=128, lr=0.1, sync=False, n=5, num_epochs=82, model=None):
    mv.barrier()
    if mv.is_master_worker():
        # MULTIVERSO: update the parameters before saving the model
        mvnpm.sync_all_param()
        lpm.sync_all_param()
        # dump the network weights to a file:
        np.savez('cifar10_deep_residual_model.npz', *lasagne.layers.get_all_param_values(network))
    else:

View file

View file

@@ -0,0 +1,27 @@
#!/usr/bin/env python
# coding:utf8
from keras.callbacks import Callback
from param_manager import KerasParamManager
class MVCallback(Callback):
    '''
    Please use MVCallback as a callback of the Keras model.fit function.
    For example:
    ```
    model.fit(X_train, Y_train,
              batch_size=batch_size,
              nb_epoch=nb_epoch,
              validation_data=(X_test, Y_test),
              shuffle=True,
              callbacks=[MVCallback(model)])
    ```
    '''
    def __init__(self, model):
        super(MVCallback, self).__init__()
        self.kpm = KerasParamManager(model)

    def on_batch_end(self, batch, logs={}):
        '''Sync all parameters at the end of every batch.'''
        self.kpm.sync_all_param()

View file

@@ -0,0 +1,16 @@
#!/usr/bin/env python
from ..param_manager import MVModelParamManager
class KerasParamManager(MVModelParamManager):
    '''
    KerasParamManager is a manager that makes managing and synchronizing the
    variables in Keras easier
    '''
    def get_all_param_values(self):
        return self.model.get_weights()

    def set_all_param_values(self, params):
        self.model.set_weights(params)
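# Illustrative usage (not part of this file): the manager can also be driven
# directly, without the MVCallback wrapper, e.g.
#     kpm = KerasParamManager(model)
#     kpm.sync_all_param()  # add deltas to the server and fetch the latest values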

View file

@@ -2,63 +2,17 @@
# coding:utf8
import lasagne
import numpy as np
import multiverso as mv
from ..param_manager import MVModelParamManager
class MVNetParamManager(object):
class LasagneParamManager(MVModelParamManager):
    '''
    MVNetParamManager is manager to make managing and synchronizing the
    LasagneParamManager is manager to make managing and synchronizing the
    variables in lasagne more easily
    '''
    def __init__(self, network):
        ''' The constructor of MVNetParamManager
        The constructor will associate the parameter with multiverso array
        table. The initial value of ArrayTableHandler will be same as the
        parameters of network. If different parameters are used in different
        processes, the average of them will be used as the initial value
        '''
        self.shapes = []
        self.dtypes = []
        self.sizes = []
        self.all_param_list = []
        self.network = network
        for arr in lasagne.layers.get_all_param_values(self.network):
            self.shapes.append(arr.shape)
            # TODO: Now only float32 is supported in multiverso. So I store all
            # the parameters in a float32 array. This place need modification
            # after other types are supported
            assert(np.dtype("float32") == arr.dtype)
            self.dtypes.append(arr.dtype)
            self.sizes.append(arr.size)
            self.all_param_list.extend([i for i in np.nditer(arr)])
        self.all_param_list = np.array(self.all_param_list)
    def get_all_param_values(self):
        return lasagne.layers.get_all_param_values(self.model)
        self.tbh = mv.ArrayTableHandler(len(self.all_param_list), init_value=self.all_param_list)
        mv.barrier() # add barrier to make sure the initial values have token effect
        self.all_param_list = self.tbh.get()
        self._set_all_param_to_net()
    def _set_all_param_to_net(self):
        n = 0
        params = []
        for i, size in enumerate(self.sizes):
            params.append(self.all_param_list[n:n + size].reshape(self.shapes[i]))
            n += size
        lasagne.layers.set_all_param_values(self.network, params)
    def sync_all_param(self):
        '''sync all parameters with multiverso server
        This function will
        1) calc all the delta of params in the network and add the delta to multiverso server
        2) get the latest value from the multiverso server
        '''
        cur_network_params = np.concatenate([
            arr.reshape(-1) for arr in lasagne.layers.get_all_param_values(self.network)])
        params_delta = cur_network_params - self.all_param_list
        self.tbh.add(params_delta)
        self.all_param_list = self.tbh.get()
        self._set_all_param_to_net()
    def set_all_param_values(self, params):
        lasagne.layers.set_all_param_values(self.model, params)

View file

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# coding:utf8
import lasagne
import numpy as np
import multiverso as mv
class MVModelParamManager(object):
    '''
    MVModelParamManager is a manager that makes managing and synchronizing the
    variables of a model easier
    '''
    def __init__(self, model):
        ''' The constructor of MVModelParamManager

        The constructor will associate the parameters with a multiverso array
        table. The initial value of the ArrayTableHandler will be the same as
        the parameters of the model. If different parameters are used in
        different processes, the average of them will be used as the initial value
        '''
        self.shapes = []
        self.dtypes = []
        self.sizes = []
        self.all_param_list = []
        self.model = model
        for arr in self.get_all_param_values():
            self.shapes.append(arr.shape)
            # TODO: Now only float32 is supported in multiverso. So I store all
            # the parameters in a float32 array. This place needs modification
            # after other types are supported
            assert(np.dtype("float32") == arr.dtype)
            self.dtypes.append(arr.dtype)
            self.sizes.append(arr.size)
            self.all_param_list.extend([i for i in np.nditer(arr)])
        self.all_param_list = np.array(self.all_param_list)

        self.tbh = mv.ArrayTableHandler(len(self.all_param_list), init_value=self.all_param_list)
        mv.barrier()  # add a barrier to make sure the initial values have taken effect
        self.all_param_list = self.tbh.get()
        self._set_all_param_to_model()

    def get_all_param_values(self):
        '''Get all param values of the specific model

        Gets the parameters of the model. It should return a list of Numpy
        arrays with shapes and types matching the output of
        `set_all_param_values()`.
        '''
        raise NotImplementedError()

    def set_all_param_values(self, params):
        '''Set all param values of the specific model

        Sets the parameters of the model. The `params` argument should be a
        list of Numpy arrays with shapes and types matching the output of
        `get_all_param_values()`.
        '''
        raise NotImplementedError()

    def _set_all_param_to_model(self):
        n = 0
        params = []
        for i, size in enumerate(self.sizes):
            params.append(self.all_param_list[n:n + size].reshape(self.shapes[i]))
            n += size
        self.set_all_param_values(params)

    def sync_all_param(self):
        '''Sync all parameters with the multiverso server

        This function will
        1) calculate the deltas of all params in the model and add them to the multiverso server
        2) get the latest values from the multiverso server
        '''
        cur_model_params = np.concatenate([
            arr.reshape(-1) for arr in self.get_all_param_values()])

        params_delta = cur_model_params - self.all_param_list
        self.tbh.add(params_delta)
        self.all_param_list = self.tbh.get()
        self._set_all_param_to_model()
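# Illustrative sketch (not part of this file): a binding for another framework
# only needs to subclass MVModelParamManager and override the two accessors,
# e.g. (the weight helpers below are hypothetical):
#
# class MyFrameworkParamManager(MVModelParamManager):
#     def get_all_param_values(self):
#         return my_framework_get_weights(self.model)
#
#     def set_all_param_values(self, params):
#         my_framework_set_weights(self.model, params)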