The checker accepts --source-column-map and --source-table-name.

This commit is contained in:
Ric Szopa 2013-04-25 17:55:44 -07:00
Parent c7fda52899
Commit 979168c375
2 changed files: 92 additions and 15 deletions

View file

@ -299,15 +299,29 @@ class Stats(object):
class Checker(object):
def __init__(self, destination_url, sources_urls, table, directory='.',
source_column_map=None, source_table_name=None,
keyrange={}, batch_count=0, blocks=1, ratio=1.0, block_size=16384,
logging_level=logging.INFO, stats_interval=1, temp_directory=None):
self.table_name = table
self.table_data = self.get_table_data(table, parse_database_url(sources_urls[0]))
if source_table_name is None:
self.source_table_name = self.table_name
else:
self.source_table_name = source_table_name
self.table_data = self.get_table_data(table, parse_database_url(destination_url))
self.primary_key = self.table_data['pk']
if source_column_map:
self.source_column_map = source_column_map
else:
self.source_column_map = {}
columns = self.table_data['columns']
for k in self.primary_key:
columns.remove(k)
self.columns = self.primary_key + columns
self.source_columns = [self.source_column_map.get(c, c) for c in self.columns]
self.source_primary_key = [self.source_column_map.get(c, c) for c in self.primary_key]
self.pk_length = len(self.primary_key)
(self.batch_count, self.block_size,
self.ratio, self.blocks) = batch_count, block_size, ratio, blocks
@ -353,11 +367,11 @@ class Checker(object):
where %(keyspace_sql)s
(%(range_sql)s)
order by %(pk_columns)s limit %%(limit)s""" % {
'table_name': self.table_name,
'table_name': self.source_table_name,
'keyspace_sql': ' '.join(keyspace_sql_parts),
'columns': ', '.join(self.columns),
'pk_columns': ', '.join(self.primary_key),
'range_sql': sql_tuple_comparison(self.table_name, self.primary_key)}
'columns': ', '.join(self.source_columns),
'pk_columns': ', '.join(self.source_primary_key),
'range_sql': sql_tuple_comparison(self.source_table_name, self.source_primary_key)}
self.source_sql = """
select
@ -366,12 +380,12 @@ class Checker(object):
where %(keyspace_sql)s
((%(min_range_sql)s) and not (%(max_range_sql)s))
order by %(pk_columns)s""" % {
'table_name': self.table_name,
'table_name': self.source_table_name,
'keyspace_sql': ' '.join(keyspace_sql_parts),
'columns': ', '.join(self.columns),
'pk_columns': ', '.join(self.primary_key),
'min_range_sql': sql_tuple_comparison(self.table_name, self.primary_key),
'max_range_sql': sql_tuple_comparison(self.table_name, self.primary_key, column_name_prefix='max_')}
'columns': ', '.join(self.source_columns),
'pk_columns': ', '.join(self.source_primary_key),
'min_range_sql': sql_tuple_comparison(self.source_table_name, self.source_primary_key),
'max_range_sql': sql_tuple_comparison(self.source_table_name, self.source_primary_key, column_name_prefix='max_')}
self.stats = Stats(interval=stats_interval, name=self.table_name)
self.destination = Datastore(parse_database_url(destination_url), stats=self.stats)
@ -493,11 +507,11 @@ class Checker(object):
return
# query the sources -> sources_data
params = dict(start_pk)
params = dict((self.source_column_map.get(k, k), v) for k, v in start_pk.items())
if destination_data:
for k, v in end_pk.items():
params['max_' + k] = v
params['max_' + self.source_column_map.get(k, k)] = v
sources_data = self.sources.query(self.source_sql, params)
else:
params['limit'] = self.batch_size
@ -588,6 +602,15 @@ def main():
parser.add_option('-r', '--ratio', dest='ratio',
type='float', default=1.0,
help='Assumed block fill ratio.')
parser.add_option('--source-column-map',
dest='source_column_map', type='string',
help='column_in_destination:column_in_source,column_in_destination2:column_in_source2,...',
default='')
parser.add_option('--source-table-name',
dest='source_table_name', type='string',
help='name of the table in sources (if different than in destination)',
default=None)
parser.add_option('-b', '--blocks', dest='blocks',
type='float', default=3,
help='Try to send this many blocks in one commit.')
@ -600,9 +623,18 @@ def main():
help="keyrange end (hexadecimal)")
(options, args) = parser.parse_args()
table, destination, sources = args[0], args[1], args[2:]
source_column_map = {}
if options.source_column_map:
for pair in options.source_column_map.split(','):
k, v = pair.split(':')
source_column_map[k] = v
checker = Checker(destination, sources, table, options.checkpoint_directory,
source_column_map=source_column_map,
source_table_name=options.source_table_name,
keyrange=get_range(options.start, options.end),
stats_interval=options.stats, batch_count=options.batch_count,
block_size=options.block_size, ratio=options.ratio,

View file

@ -59,8 +59,8 @@ class MockChecker(checker.Checker):
class TestCheckersBase(unittest.TestCase):
keyrange = {"end": 900}
def make_checker(self, **kwargs):
default = {'keyrange': TestCheckers.keyrange,
def make_checker(self, destination_table_name="test", **kwargs):
default = {'keyrange': TestCheckersBase.keyrange,
'batch_count': 20,
'logging_level': logging.WARNING,
'directory': tempfile.mkdtemp()}
@ -68,7 +68,7 @@ class TestCheckersBase(unittest.TestCase):
source_addresses = ['vt_dba@localhost:%s/test_checkers%s?unix_socket=%s' % (s.mysql_port, i, s.mysql_connection_parameters('test_checkers')['unix_socket'])
for i, s in enumerate(source_tablets)]
destination_socket = destination_tablet.mysql_connection_parameters('test_checkers')['unix_socket']
return MockChecker('vt_dba@localhost/test_checkers?unix_socket=%s' % destination_socket, source_addresses, 'test', **default)
return MockChecker('vt_dba@localhost/test_checkers?unix_socket=%s' % destination_socket, source_addresses, destination_table_name, **default)
class TestSortedRowListDifference(unittest.TestCase):
@ -182,6 +182,51 @@ class TestDifferentEncoding(TestCheckersBase):
self.c._run()
self.assertTrue(self.c.mismatches)
class TestRlookup(TestCheckersBase):
    """Checker tests for a reverse-lookup layout: the destination table
    ("test_lookup") has a different name and different column names
    (pk1_lookup, msg_lookup) than the source table ("test"), exercising
    the source_table_name and source_column_map options."""

    def setUp(self):
        # Source schema carries extra columns (k2, k3, keyspace_id) that the
        # destination lookup table does not mirror.
        source_create_table = "create table test (pk1 bigint, k2 bigint, k3 bigint, keyspace_id bigint, msg varchar(64), primary key (pk1)) Engine=InnoDB"
        destination_create_table = "create table test_lookup (pk1_lookup bigint, msg_lookup varchar(64), primary key (pk1_lookup)) Engine=InnoDB"
        destination_tablet.create_db("test_checkers")
        destination_tablet.mquery("test_checkers", destination_create_table, True)
        for shard, tablet in enumerate(source_tablets):
            tablet.create_db("test_checkers%s" % shard)
            tablet.mquery("test_checkers%s" % shard, source_create_table, True)

        destination_queries = []
        source_queries = [[] for _ in source_tablets]
        # Rows 1..399 exist on both sides; source rows alternate between the
        # two source shards.
        for row in range(1, 400):
            destination_queries.append("insert into test_lookup (pk1_lookup, msg_lookup) values (%s, 'message %s')" % (row, row))
            source_queries[row % 2].append("insert into test (pk1, k2, k3, msg, keyspace_id) values (%s, %s, %s, 'message %s', %s)" % (row, row, row, row, row))
        # Rows 1100..1109 exist only in the first source shard.
        for row in range(1100, 1110):
            source_queries[0].append("insert into test (pk1, k2, k3, msg, keyspace_id) values (%s, %s, %s, 'message %s', %s)" % (row, row, row, row, row))

        destination_tablet.mquery("test_checkers", destination_queries, write=True)
        for shard, (tablet, queries) in enumerate(zip(source_tablets, source_queries)):
            tablet.mquery("test_checkers%s" % shard, queries, write=True)
        self.c = self.make_checker(destination_table_name="test_lookup", source_table_name="test", source_column_map={'pk1_lookup': 'pk1', 'msg_lookup': 'msg'})

    def tearDown(self):
        destination_tablet.mquery("test_checkers", "drop table test_lookup", True)
        for shard, tablet in enumerate(source_tablets):
            tablet.mquery("test_checkers%s" % shard, "drop table test", True)

    def test_ok(self):
        # Identical data on both sides -> the checker reports no mismatches.
        self.c._run()
        self.assertFalse(self.c.mismatches)

    def test_different_value(self):
        # A changed value in the destination must be flagged.
        destination_tablet.mquery("test_checkers", "update test_lookup set msg_lookup='something else' where pk1_lookup = 29", write=True)
        self.c._run()
        self.assertTrue(self.c.mismatches)

    def test_additional_value(self):
        # A row present only in the destination must be flagged.
        destination_tablet.mquery("test_checkers", "insert into test_lookup (pk1_lookup, msg_lookup) values (11000, 'something new')", write=True)
        self.c._run()
        self.assertTrue(self.c.mismatches)
def main():
parser = optparse.OptionParser(usage="usage: %prog [options] [test_names]")