From 2429f662c504cf182283b8687ab9aefc133859e3 Mon Sep 17 00:00:00 2001 From: huxuan Date: Tue, 24 May 2016 14:02:35 +0900 Subject: [PATCH] [python-binding] Update/Unittest for new matrix table get. --- binding/python/multiverso/api.py | 13 +++++-------- binding/python/multiverso/test.py | 11 +++++++++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/binding/python/multiverso/api.py b/binding/python/multiverso/api.py index 9170858..89a2411 100644 --- a/binding/python/multiverso/api.py +++ b/binding/python/multiverso/api.py @@ -99,14 +99,11 @@ class MatrixTableHandler(TableHandler): else: row_ids_n = len(row_ids) int_array_type = c_int * row_ids_n - float_array_array_type = c_float * self._num_col * row_ids_n - float_pointer_array_type = POINTER(c_float) * row_ids_n - - array_data = float_array_array_type() - c_data = float_pointer_array_type(*[row for row in array_data]) - mv_lib.MV_GetMatrixTableByRows(self._handler, int_array_type(*row_ids), - row_ids_n, self._num_col, c_data) - return np.array(self._construct_matrix(array_data)).reshape((row_ids_n, self._num_col)) + float_array_type = c_float * (row_ids_n * self._num_col) + c_data = float_array_type() + mv_lib.MV_GetMatrixTableByRows(self._handler, c_data, self._num_col, + int_array_type(*row_ids), row_ids_n) + return np.array(self._construct_matrix(c_data)).reshape((row_ids_n, self._num_col)) def add(self, data=None, row_ids=None): ''' diff --git a/binding/python/multiverso/test.py b/binding/python/multiverso/test.py index fec9adc..4d2b802 100644 --- a/binding/python/multiverso/test.py +++ b/binding/python/multiverso/test.py @@ -22,6 +22,7 @@ def TestMatrix(): num_row = 11 num_col = 10 size = num_col * num_row + workers_num = mv.workers_num() tbh = mv.MatrixTableHandler(num_row, num_col) mv.barrier() for count in xrange(1, 20): @@ -33,9 +34,15 @@ def TestMatrix(): mv.barrier() for i, row in enumerate(data): for j, actual in enumerate(row): - expected = (i * num_col + j) * count * mv.workers_num() + expected = (i * num_col + j) * count * workers_num if i in row_ids: - expected += (i * num_col + j) * count * mv.workers_num() + expected += (i * num_col + j) * count * workers_num + assert(expected == actual) + data = tbh.get(row_ids) + mv.barrier() + for i, row in enumerate(data): + for j, actual in enumerate(row): + expected = (row_ids[i] * num_col + j) * count * workers_num * 2 assert(expected == actual) mv.shutdown()