From 215e3e0813511ca8385d82e2e1e7a77774863b5c Mon Sep 17 00:00:00 2001 From: Liang Guo Date: Mon, 2 Jun 2014 16:00:50 -0700 Subject: [PATCH] Update checkpoing api --- py/vtdb/keyrange.py | 43 ++++++++++--------------------------------- test/keyrange_test.py | 17 ++--------------- 2 files changed, 12 insertions(+), 48 deletions(-) diff --git a/py/vtdb/keyrange.py b/py/vtdb/keyrange.py index d407343576..fc2c24d8e9 100644 --- a/py/vtdb/keyrange.py +++ b/py/vtdb/keyrange.py @@ -53,43 +53,20 @@ class StreamingTaskMap(object): min_key_hex = int("00", base=16) max_key_hex = int("100", base=16) kr = min_key_hex - if self.num_tasks >= self.shard_count: - span = (max_key_hex - min_key_hex)/self.num_tasks - kr_chunks.append('') - for i in xrange(self.num_tasks): - kr += span - #kr_chunks.append(hex(kr).split('0x')[1]) - kr_chunks.append('%x' % kr) - kr_chunks[-1] = '' - self.keyrange_list = [[str(KeyRange((kr_chunks[i], kr_chunks[i+1],)))] for i in xrange(len(kr_chunks) - 1)] - else: - span = (max_key_hex - min_key_hex)/self.shard_count - kr_chunks.append('') - for i in xrange(self.shard_count): - kr += span - kr_chunks.append('%x' % kr) - kr_chunks[-1] = '' - kr_list = [[str(KeyRange((kr_chunks[i], kr_chunks[i+1],)))] for i in xrange(len(kr_chunks) - 1)] - item_per_task = len(kr_list)/self.num_tasks - j = 0 - item = [] - for i in xrange(len(kr_list)): - if j < item_per_task: - item.extend(kr_list[i]) - j = j+1 - if j == item_per_task: - self.keyrange_list.append(item) - j = 0 - item = [] - return self.keyrange_list + span = (max_key_hex - min_key_hex)/self.num_tasks + kr_chunks.append('') + for i in xrange(self.num_tasks): + kr += span + #kr_chunks.append(hex(kr).split('0x')[1]) + kr_chunks.append('%x' % kr) + kr_chunks[-1] = '' + self.keyrange_list = [str(KeyRange((kr_chunks[i], kr_chunks[i+1],))) for i in xrange(len(kr_chunks) - 1)] # Compute the task map for a streaming query. # shard_count is read from config, using it as a param for simplicity. def create_streaming_task_map(num_tasks, shard_count): - def _is_power2(num): - return num != 0 and ((num & (num - 1)) == 0) - if not _is_power2(num_tasks) or not _is_power2(shard_count): - raise dbexceptions.ProgrammingError('tasks %d and shard_count %d should be power of 2' + if num_tasks % shard_count != 0: + raise dbexceptions.ProgrammingError('tasks %d should be multiple of shard_count %d' % (num_tasks, shard_count)) return StreamingTaskMap(num_tasks, shard_count) diff --git a/test/keyrange_test.py b/test/keyrange_test.py index 94522825eb..e5abe6fcf8 100755 --- a/test/keyrange_test.py +++ b/test/keyrange_test.py @@ -39,37 +39,25 @@ class TestKeyRange(unittest.TestCase): def test_incorrect_tasks(self): global_shard_count = 16 with self.assertRaises(dbexceptions.ProgrammingError): - stm = keyrange.create_streaming_task_map(9, global_shard_count) + stm = keyrange.create_streaming_task_map(4, global_shard_count) def test_keyranges_for_tasks(self): for global_shard_count in (16,32,64): num_tasks = global_shard_count stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count) self.assertEqual(len(stm.keyrange_list), num_tasks) - for i in xrange(num_tasks): - self.assertEqual(len(stm.keyrange_list[i]), 1) num_tasks = global_shard_count*2 stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count) self.assertEqual(len(stm.keyrange_list), num_tasks) - for i in xrange(num_tasks): - self.assertEqual(len(stm.keyrange_list[i]), 1) num_tasks = global_shard_count*8 stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count) self.assertEqual(len(stm.keyrange_list), num_tasks) - for i in xrange(num_tasks): - self.assertEqual(len(stm.keyrange_list[i]), 1) - num_tasks = global_shard_count/2 - stm = keyrange.create_streaming_task_map(num_tasks, global_shard_count) - self.assertEqual(len(stm.keyrange_list), num_tasks) - for i in xrange(num_tasks): - self.assertEqual(len(stm.keyrange_list[i]), 2) # This tests that the where clause and bind_vars generated for each shard # against a few sample values where keyspace_id is an int column. def test_bind_values_for_int_keyspace(self): stm = keyrange.create_streaming_task_map(16, 16) for i, kr in enumerate(stm.keyrange_list): - kr = kr[0] kr_parts = kr.split('-') where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(kr) if len(bind_vars.keys()) == 1: @@ -101,7 +89,6 @@ class TestKeyRange(unittest.TestCase): def test_bind_values_for_str_keyspace(self): stm = keyrange.create_streaming_task_map(16, 16) for i, kr in enumerate(stm.keyrange_list): - kr = kr[0] kr_parts = kr.split('-') where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(kr, keyspace_col_type=keyrange_constants.KIT_BYTES) if len(bind_vars.keys()) == 1: @@ -127,7 +114,7 @@ class TestKeyRange(unittest.TestCase): def test_bind_values_for_unsharded_keyspace(self): stm = keyrange.create_streaming_task_map(1, 1) self.assertEqual(len(stm.keyrange_list), 1) - where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(stm.keyrange_list[0][0]) + where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(stm.keyrange_list[0]) self.assertEqual(where_clause, "") self.assertEqual(bind_vars, {})