Merge pull request #75 from you-n-g/binding
profileing and tuning python apis.
This commit is contained in:
Коммит
73424d5f3f
|
@ -46,6 +46,10 @@ class TableHandler(object):
|
|||
raise NotImplementedError("You must implement the add method.")
|
||||
|
||||
|
||||
# types
|
||||
C_FLOAT_P = POINTER(c_float)
|
||||
|
||||
|
||||
class ArrayTableHandler(TableHandler):
|
||||
def __init__(self, size):
|
||||
self._handler = c_void_p()
|
||||
|
@ -56,17 +60,19 @@ class ArrayTableHandler(TableHandler):
|
|||
'''
|
||||
Data type of return value is numpy.ndarray
|
||||
'''
|
||||
c_data = (c_float * self._size)()
|
||||
mv_lib.MV_GetArrayTable(self._handler, c_data, self._size)
|
||||
return np.array(list(c_data), dtype=np.dtype("float32"))
|
||||
data = np.zeros((self._size, ), dtype=np.dtype("float32"))
|
||||
mv_lib.MV_GetArrayTable(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
|
||||
return data
|
||||
|
||||
def add(self, data):
|
||||
'''
|
||||
Data type of `data` is numpy.ndarray
|
||||
'''
|
||||
data = np.array(data)
|
||||
if not isinstance(data, np.ndarray):
|
||||
data = np.array(data)
|
||||
assert(data.size == self._size)
|
||||
mv_lib.MV_AddArrayTable(self._handler, (c_float * self._size)(*data.reshape((-1,))), self._size)
|
||||
data = data.astype(np.float32)
|
||||
mv_lib.MV_AddArrayTable(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
|
||||
|
||||
|
||||
class MatrixTableHandler(TableHandler):
|
||||
|
@ -84,19 +90,17 @@ class MatrixTableHandler(TableHandler):
|
|||
to the row_ids
|
||||
'''
|
||||
if row_ids is None:
|
||||
float_array_type = c_float * (self._num_row * self._num_col)
|
||||
c_data = float_array_type()
|
||||
mv_lib.MV_GetMatrixTableAll(self._handler, c_data, self._size)
|
||||
return np.array(list(c_data), dtype=np.dtype("float32")).reshape((self._num_row, self._num_col))
|
||||
data = np.zeros((self._num_row, self._num_col), dtype=np.dtype("float32"))
|
||||
mv_lib.MV_GetMatrixTableAll(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
|
||||
return data
|
||||
else:
|
||||
row_ids_n = len(row_ids)
|
||||
int_array_type = c_int * row_ids_n
|
||||
float_array_type = c_float * (row_ids_n * self._num_col)
|
||||
c_data = float_array_type()
|
||||
mv_lib.MV_GetMatrixTableByRows(self._handler, c_data,
|
||||
data = np.zeros((row_ids_n, self._num_col), dtype=np.dtype("float32"))
|
||||
mv_lib.MV_GetMatrixTableByRows(self._handler, data.ctypes.data_as(C_FLOAT_P),
|
||||
row_ids_n * self._num_col,
|
||||
int_array_type(*row_ids), row_ids_n)
|
||||
return np.array(list(c_data), dtype=np.dtype("float32")).reshape((row_ids_n, self._num_col))
|
||||
return data
|
||||
|
||||
def add(self, data=None, row_ids=None):
|
||||
'''
|
||||
|
@ -106,19 +110,17 @@ class MatrixTableHandler(TableHandler):
|
|||
Otherwise we will add the data according to the row_ids
|
||||
'''
|
||||
assert(data is not None)
|
||||
data = np.array(data)
|
||||
if not isinstance(data, np.ndarray):
|
||||
data = np.array(data)
|
||||
data = data.astype(np.float32)
|
||||
|
||||
if row_ids is None:
|
||||
assert(data.size == self._size)
|
||||
float_array_type = c_float * (self._num_row * self._num_col)
|
||||
c_data = float_array_type(* data.reshape((-1, )))
|
||||
mv_lib.MV_AddMatrixTableAll(self._handler, c_data, self._size)
|
||||
mv_lib.MV_AddMatrixTableAll(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size)
|
||||
else:
|
||||
row_ids_n = len(row_ids)
|
||||
assert(data.size == row_ids_n * self._num_col)
|
||||
int_array_type = c_int * row_ids_n
|
||||
float_array_type = c_float * (self._num_col * row_ids_n)
|
||||
|
||||
c_data = float_array_type(* data.reshape((-1, )))
|
||||
mv_lib.MV_AddMatrixTableByRows(self._handler, c_data,
|
||||
mv_lib.MV_AddMatrixTableByRows(self._handler, data.ctypes.data_as(C_FLOAT_P),
|
||||
row_ids_n * self._num_col,
|
||||
int_array_type(*row_ids), row_ids_n)
|
||||
|
|
|
@ -50,12 +50,10 @@ class MVNetParamManager(object):
|
|||
1) calc all the delta of params in the network and add the delta to multiverso server
|
||||
2) get the latest value from the multiverso server
|
||||
'''
|
||||
latest_network_params = []
|
||||
for arr in lasagne.layers.get_all_param_values(self.network):
|
||||
latest_network_params.extend([i for i in np.nditer(arr)])
|
||||
latest_network_params = np.array(latest_network_params)
|
||||
cur_network_params = np.concatenate([
|
||||
arr.reshape(-1) for arr in lasagne.layers.get_all_param_values(self.network)])
|
||||
|
||||
params_delta = latest_network_params - self.all_param_list
|
||||
params_delta = cur_network_params - self.all_param_list
|
||||
self.tbh.add(params_delta)
|
||||
self.all_param_list = self.tbh.get()
|
||||
self._set_all_param_to_net()
|
||||
|
|
Загрузка…
Ссылка в новой задаче