Merge pull request #75 from you-n-g/binding

profileing and tuning python apis.
This commit is contained in:
Xuan Hu (Sean) 2016-06-01 16:00:48 +08:00
Родитель 5c07ecd071 224bb199f3
Коммит 73424d5f3f
2 изменённых файлов: 26 добавлений и 26 удалений

Просмотреть файл

@ -46,6 +46,10 @@ class TableHandler(object):
raise NotImplementedError("You must implement the add method.")
# types
C_FLOAT_P = POINTER(c_float)
class ArrayTableHandler(TableHandler):
def __init__(self, size):
self._handler = c_void_p()
@ -56,17 +60,19 @@ class ArrayTableHandler(TableHandler):
'''
Data type of return value is numpy.ndarray
'''
c_data = (c_float * self._size)()
mv_lib.MV_GetArrayTable(self._handler, c_data, self._size)
return np.array(list(c_data), dtype=np.dtype("float32"))
data = np.zeros((self._size, ), dtype=np.dtype("float32"))
mv_lib.MV_GetArrayTable(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
return data
def add(self, data):
'''
Data type of `data` is numpy.ndarray
'''
data = np.array(data)
if not isinstance(data, np.ndarray):
data = np.array(data)
assert(data.size == self._size)
mv_lib.MV_AddArrayTable(self._handler, (c_float * self._size)(*data.reshape((-1,))), self._size)
data = data.astype(np.float32)
mv_lib.MV_AddArrayTable(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
class MatrixTableHandler(TableHandler):
@ -84,19 +90,17 @@ class MatrixTableHandler(TableHandler):
to the row_ids
'''
if row_ids is None:
float_array_type = c_float * (self._num_row * self._num_col)
c_data = float_array_type()
mv_lib.MV_GetMatrixTableAll(self._handler, c_data, self._size)
return np.array(list(c_data), dtype=np.dtype("float32")).reshape((self._num_row, self._num_col))
data = np.zeros((self._num_row, self._num_col), dtype=np.dtype("float32"))
mv_lib.MV_GetMatrixTableAll(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
return data
else:
row_ids_n = len(row_ids)
int_array_type = c_int * row_ids_n
float_array_type = c_float * (row_ids_n * self._num_col)
c_data = float_array_type()
mv_lib.MV_GetMatrixTableByRows(self._handler, c_data,
data = np.zeros((row_ids_n, self._num_col), dtype=np.dtype("float32"))
mv_lib.MV_GetMatrixTableByRows(self._handler, data.ctypes.data_as(C_FLOAT_P),
row_ids_n * self._num_col,
int_array_type(*row_ids), row_ids_n)
return np.array(list(c_data), dtype=np.dtype("float32")).reshape((row_ids_n, self._num_col))
return data
def add(self, data=None, row_ids=None):
'''
@ -106,19 +110,17 @@ class MatrixTableHandler(TableHandler):
Otherwise we will add the data according to the row_ids
'''
assert(data is not None)
data = np.array(data)
if not isinstance(data, np.ndarray):
data = np.array(data)
data = data.astype(np.float32)
if row_ids is None:
assert(data.size == self._size)
float_array_type = c_float * (self._num_row * self._num_col)
c_data = float_array_type(* data.reshape((-1, )))
mv_lib.MV_AddMatrixTableAll(self._handler, c_data, self._size)
mv_lib.MV_AddMatrixTableAll(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
else:
row_ids_n = len(row_ids)
assert(data.size == row_ids_n * self._num_col)
int_array_type = c_int * row_ids_n
float_array_type = c_float * (self._num_col * row_ids_n)
c_data = float_array_type(* data.reshape((-1, )))
mv_lib.MV_AddMatrixTableByRows(self._handler, c_data,
mv_lib.MV_AddMatrixTableByRows(self._handler, data.ctypes.data_as(C_FLOAT_P),
row_ids_n * self._num_col,
int_array_type(*row_ids), row_ids_n)

Просмотреть файл

@ -50,12 +50,10 @@ class MVNetParamManager(object):
1) calc all the delta of params in the network and add the delta to multiverso server
2) get the latest value from the multiverso server
'''
latest_network_params = []
for arr in lasagne.layers.get_all_param_values(self.network):
latest_network_params.extend([i for i in np.nditer(arr)])
latest_network_params = np.array(latest_network_params)
cur_network_params = np.concatenate([
arr.reshape(-1) for arr in lasagne.layers.get_all_param_values(self.network)])
params_delta = latest_network_params - self.all_param_list
params_delta = cur_network_params - self.all_param_list
self.tbh.add(params_delta)
self.all_param_list = self.tbh.get()
self._set_all_param_to_net()