diff --git a/binding/python/multiverso/api.py b/binding/python/multiverso/api.py index d6fb22e..9753b72 100644 --- a/binding/python/multiverso/api.py +++ b/binding/python/multiverso/api.py @@ -46,6 +46,10 @@ class TableHandler(object): raise NotImplementedError("You must implement the add method.") +# types +C_FLOAT_P = POINTER(c_float) + + class ArrayTableHandler(TableHandler): def __init__(self, size): self._handler = c_void_p() @@ -56,17 +60,19 @@ class ArrayTableHandler(TableHandler): ''' Data type of return value is numpy.ndarray ''' - c_data = (c_float * self._size)() - mv_lib.MV_GetArrayTable(self._handler, c_data, self._size) - return np.array(list(c_data), dtype=np.dtype("float32")) + data = np.zeros((self._size, ), dtype=np.dtype("float32")) + mv_lib.MV_GetArrayTable(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size) + return data def add(self, data): ''' Data type of `data` is numpy.ndarray ''' - data = np.array(data) + if not isinstance(data, np.ndarray): + data = np.array(data) assert(data.size == self._size) - mv_lib.MV_AddArrayTable(self._handler, (c_float * self._size)(*data.reshape((-1,))), self._size) + data = data.astype(np.float32) + mv_lib.MV_AddArrayTable(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size) class MatrixTableHandler(TableHandler): @@ -84,19 +90,17 @@ class MatrixTableHandler(TableHandler): to the row_ids ''' if row_ids is None: - float_array_type = c_float * (self._num_row * self._num_col) - c_data = float_array_type() - mv_lib.MV_GetMatrixTableAll(self._handler, c_data, self._size) - return np.array(list(c_data), dtype=np.dtype("float32")).reshape((self._num_row, self._num_col)) + data = np.zeros((self._num_row, self._num_col), dtype=np.dtype("float32")) + mv_lib.MV_GetMatrixTableAll(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size) + return data else: row_ids_n = len(row_ids) int_array_type = c_int * row_ids_n - float_array_type = c_float * (row_ids_n * self._num_col) - c_data = float_array_type() - mv_lib.MV_GetMatrixTableByRows(self._handler, c_data, + data = np.zeros((row_ids_n, self._num_col), dtype=np.dtype("float32")) + mv_lib.MV_GetMatrixTableByRows(self._handler, data.ctypes.data_as(C_FLOAT_P), row_ids_n * self._num_col, int_array_type(*row_ids), row_ids_n) - return np.array(list(c_data), dtype=np.dtype("float32")).reshape((row_ids_n, self._num_col)) + return data def add(self, data=None, row_ids=None): ''' @@ -106,19 +110,17 @@ class MatrixTableHandler(TableHandler): Otherwise we will add the data according to the row_ids ''' assert(data is not None) - data = np.array(data) + if not isinstance(data, np.ndarray): + data = np.array(data) + data = data.astype(np.float32) + if row_ids is None: assert(data.size == self._size) - float_array_type = c_float * (self._num_row * self._num_col) - c_data = float_array_type(* data.reshape((-1, ))) - mv_lib.MV_AddMatrixTableAll(self._handler, c_data, self._size) + mv_lib.MV_AddMatrixTableAll(self._handler, data.ctypes.data_as(C_FLOAT_P), self._size) else: row_ids_n = len(row_ids) assert(data.size == row_ids_n * self._num_col) int_array_type = c_int * row_ids_n - float_array_type = c_float * (self._num_col * row_ids_n) - - c_data = float_array_type(* data.reshape((-1, ))) - mv_lib.MV_AddMatrixTableByRows(self._handler, c_data, + mv_lib.MV_AddMatrixTableByRows(self._handler, data.ctypes.data_as(C_FLOAT_P), row_ids_n * self._num_col, int_array_type(*row_ids), row_ids_n) diff --git a/binding/python/multiverso/theano_ext/lasagne_ext/param_manager.py b/binding/python/multiverso/theano_ext/lasagne_ext/param_manager.py index 787cda2..7926e74 100644 --- a/binding/python/multiverso/theano_ext/lasagne_ext/param_manager.py +++ b/binding/python/multiverso/theano_ext/lasagne_ext/param_manager.py @@ -50,12 +50,10 @@ class MVNetParamManager(object): 1) calc all the delta of params in the network and add the delta to multiverso server 2) get the latest value from the multiverso server ''' - latest_network_params = [] - for arr in lasagne.layers.get_all_param_values(self.network): - latest_network_params.extend([i for i in np.nditer(arr)]) - latest_network_params = np.array(latest_network_params) + cur_network_params = np.concatenate([ + arr.reshape(-1) for arr in lasagne.layers.get_all_param_values(self.network)]) - params_delta = latest_network_params - self.all_param_list + params_delta = cur_network_params - self.all_param_list self.tbh.add(params_delta) self.all_param_list = self.tbh.get() self._set_all_param_to_net()