[python-binding] Update/Unittest for new matrix table get.

This commit is contained in:
huxuan 2016-05-24 14:02:35 +09:00
Родитель 4bc8382171
Коммит 2429f662c5
2 изменённых файлов: 14 добавлений и 10 удалений

Просмотреть файл

@ -99,14 +99,11 @@ class MatrixTableHandler(TableHandler):
else:
row_ids_n = len(row_ids)
int_array_type = c_int * row_ids_n
float_array_array_type = c_float * self._num_col * row_ids_n
float_pointer_array_type = POINTER(c_float) * row_ids_n
array_data = float_array_array_type()
c_data = float_pointer_array_type(*[row for row in array_data])
mv_lib.MV_GetMatrixTableByRows(self._handler, int_array_type(*row_ids),
row_ids_n, self._num_col, c_data)
return np.array(self._construct_matrix(array_data)).reshape((row_ids_n, self._num_col))
float_array_type = c_float * (row_ids_n * self._num_col)
c_data = float_array_type()
mv_lib.MV_GetMatrixTableByRows(self._handler, c_data, self._num_col,
int_array_type(*row_ids), row_ids_n)
return np.array(self._construct_matrix(c_data)).reshape((row_ids_n, self._num_col))
def add(self, data=None, row_ids=None):
'''

Просмотреть файл

@ -22,6 +22,7 @@ def TestMatrix():
num_row = 11
num_col = 10
size = num_col * num_row
workers_num = mv.workers_num()
tbh = mv.MatrixTableHandler(num_row, num_col)
mv.barrier()
for count in xrange(1, 20):
@ -33,9 +34,15 @@ def TestMatrix():
mv.barrier()
for i, row in enumerate(data):
for j, actual in enumerate(row):
expected = (i * num_col + j) * count * mv.workers_num()
expected = (i * num_col + j) * count * workers_num
if i in row_ids:
expected += (i * num_col + j) * count * mv.workers_num()
expected += (i * num_col + j) * count * workers_num
assert(expected == actual)
data = tbh.get(row_ids)
mv.barrier()
for i, row in enumerate(data):
for j, actual in enumerate(row):
expected = (row_ids[i] * num_col + j) * count * workers_num * 2
assert(expected == actual)
mv.shutdown()