This commit is contained in:
Liang Guo 2014-06-02 16:00:50 -07:00
Родитель f5d4955df0
Коммит 215e3e0813
2 изменённых файлов: 12 добавлений и 48 удалений

Просмотреть файл

@ -53,43 +53,20 @@ class StreamingTaskMap(object):
min_key_hex = int("00", base=16)
max_key_hex = int("100", base=16)
kr = min_key_hex
if self.num_tasks >= self.shard_count:
span = (max_key_hex - min_key_hex)/self.num_tasks
kr_chunks.append('')
for i in xrange(self.num_tasks):
kr += span
#kr_chunks.append(hex(kr).split('0x')[1])
kr_chunks.append('%x' % kr)
kr_chunks[-1] = ''
self.keyrange_list = [[str(KeyRange((kr_chunks[i], kr_chunks[i+1],)))] for i in xrange(len(kr_chunks) - 1)]
else:
span = (max_key_hex - min_key_hex)/self.shard_count
kr_chunks.append('')
for i in xrange(self.shard_count):
kr += span
kr_chunks.append('%x' % kr)
kr_chunks[-1] = ''
kr_list = [[str(KeyRange((kr_chunks[i], kr_chunks[i+1],)))] for i in xrange(len(kr_chunks) - 1)]
item_per_task = len(kr_list)/self.num_tasks
j = 0
item = []
for i in xrange(len(kr_list)):
if j < item_per_task:
item.extend(kr_list[i])
j = j+1
if j == item_per_task:
self.keyrange_list.append(item)
j = 0
item = []
return self.keyrange_list
span = (max_key_hex - min_key_hex)/self.num_tasks
kr_chunks.append('')
for i in xrange(self.num_tasks):
kr += span
#kr_chunks.append(hex(kr).split('0x')[1])
kr_chunks.append('%x' % kr)
kr_chunks[-1] = ''
self.keyrange_list = [str(KeyRange((kr_chunks[i], kr_chunks[i+1],))) for i in xrange(len(kr_chunks) - 1)]
# Compute the task map for a streaming query.
# shard_count is read from config, using it as a param for simplicity.
def create_streaming_task_map(num_tasks, shard_count):
def _is_power2(num):
return num != 0 and ((num & (num - 1)) == 0)
if not _is_power2(num_tasks) or not _is_power2(shard_count):
raise dbexceptions.ProgrammingError('tasks %d and shard_count %d should be power of 2'
if num_tasks % shard_count != 0:
raise dbexceptions.ProgrammingError('tasks %d should be multiple of shard_count %d'
% (num_tasks, shard_count))
return StreamingTaskMap(num_tasks, shard_count)

Просмотреть файл

@ -39,37 +39,25 @@ class TestKeyRange(unittest.TestCase):
def test_incorrect_tasks(self):
global_shard_count = 16
with self.assertRaises(dbexceptions.ProgrammingError):
stm = keyrange.create_streaming_task_map(9, global_shard_count)
stm = keyrange.create_streaming_task_map(4, global_shard_count)
def test_keyranges_for_tasks(self):
for global_shard_count in (16,32,64):
num_tasks = global_shard_count
stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count)
self.assertEqual(len(stm.keyrange_list), num_tasks)
for i in xrange(num_tasks):
self.assertEqual(len(stm.keyrange_list[i]), 1)
num_tasks = global_shard_count*2
stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count)
self.assertEqual(len(stm.keyrange_list), num_tasks)
for i in xrange(num_tasks):
self.assertEqual(len(stm.keyrange_list[i]), 1)
num_tasks = global_shard_count*8
stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count)
self.assertEqual(len(stm.keyrange_list), num_tasks)
for i in xrange(num_tasks):
self.assertEqual(len(stm.keyrange_list[i]), 1)
num_tasks = global_shard_count/2
stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count)
self.assertEqual(len(stm.keyrange_list), num_tasks)
for i in xrange(num_tasks):
self.assertEqual(len(stm.keyrange_list[i]), 2)
# This tests that the where clause and bind_vars generated for each shard
# against a few sample values where keyspace_id is an int column.
def test_bind_values_for_int_keyspace(self):
stm = keyrange.create_streaming_task_map(16, 16)
for i, kr in enumerate(stm.keyrange_list):
kr = kr[0]
kr_parts = kr.split('-')
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(kr)
if len(bind_vars.keys()) == 1:
@ -101,7 +89,6 @@ class TestKeyRange(unittest.TestCase):
def test_bind_values_for_str_keyspace(self):
stm = keyrange.create_streaming_task_map(16, 16)
for i, kr in enumerate(stm.keyrange_list):
kr = kr[0]
kr_parts = kr.split('-')
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(kr, keyspace_col_type=keyrange_constants.KIT_BYTES)
if len(bind_vars.keys()) == 1:
@ -127,7 +114,7 @@ class TestKeyRange(unittest.TestCase):
def test_bind_values_for_unsharded_keyspace(self):
stm = keyrange.create_streaming_task_map(1, 1)
self.assertEqual(len(stm.keyrange_list), 1)
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(stm.keyrange_list[0][0])
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(stm.keyrange_list[0])
self.assertEqual(where_clause, "")
self.assertEqual(bind_vars, {})