Checker: reading passwords from optional dr-credentials file.

LGTM Ric
This commit is contained in:
Alain Jobart 2013-04-29 10:28:44 -07:00
Родитель b6d17628fc
Коммит 3899c4fe7c
1 изменённых файлов: 20 добавлений и 7 удалений

Просмотреть файл

@ -5,6 +5,7 @@ import datetime
import difflib
import heapq
import itertools
import json
import logging
import optparse
import os
@ -33,7 +34,7 @@ def merge_sorted(seqs, key=None):
else:
return (i[1] for i in heapq.merge(*(((key(item), item) for item in seq) for seq in seqs)))
def parse_database_url(url):
def parse_database_url(url, password_map):
if not url.startswith('mysql'):
url = 'mysql://' + url
url = 'http' + url[len('mysql'):]
@ -43,8 +44,10 @@ def parse_database_url(url):
'db': parsed.path[1:]}
if parsed.username:
params['user'] = parsed.username
if parsed.password is not None:
if parsed.password:
params['passwd'] = parsed.password
elif parsed.username and parsed.username in password_map:
params['passwd'] = password_map[parsed.username][0]
if parsed.port:
params['port'] = parsed.port
params.update(dict(urlparse.parse_qsl(parsed.query)))
@ -302,13 +305,20 @@ class Checker(object):
source_column_map=None, source_table_name=None, source_force_index_pk=True,
destination_force_index_pk=True,
keyrange={}, batch_count=0, blocks=1, ratio=1.0, block_size=16384,
logging_level=logging.INFO, stats_interval=1, temp_directory=None):
logging_level=logging.INFO, stats_interval=1, temp_directory=None, password_map_file=None):
self.table_name = table
if source_table_name is None:
self.source_table_name = self.table_name
else:
self.source_table_name = source_table_name
self.table_data = self.get_table_data(table, parse_database_url(destination_url))
if password_map_file:
with open(password_map_file, "r") as f:
password_map = json.load(f)
else:
password_map = {}
self.table_data = self.get_table_data(table, parse_database_url(destination_url, password_map))
self.primary_key = self.table_data['pk']
if source_column_map:
@ -400,8 +410,8 @@ class Checker(object):
'max_range_sql': sql_tuple_comparison(self.source_table_name, self.source_primary_key, column_name_prefix='max_')}
self.stats = Stats(interval=stats_interval, name=self.table_name)
self.destination = Datastore(parse_database_url(destination_url), stats=self.stats)
self.sources = MultiDatastore([parse_database_url(s) for s in sources_urls], 'all-sources', stats=self.stats)
self.destination = Datastore(parse_database_url(destination_url, password_map), stats=self.stats)
self.sources = MultiDatastore([parse_database_url(s, password_map) for s in sources_urls], 'all-sources', stats=self.stats)
logging.basicConfig(level=logging_level)
logging.debug("destination sql template: %s", clean(self.destination_sql))
@ -638,6 +648,8 @@ def main():
help="keyrange start (hexadecimal)")
parser.add_option('--end', type='string', dest='end', default='',
help="keyrange end (hexadecimal)")
parser.add_option('--password-map-file', type='string', default=None,
help="password map file")
(options, args) = parser.parse_args()
table, destination, sources = args[0], args[1], args[2:]
@ -656,7 +668,8 @@ def main():
keyrange=get_range(options.start, options.end),
stats_interval=options.stats, batch_count=options.batch_count,
block_size=options.block_size, ratio=options.ratio,
temp_directory=options.checkpoint_directory)
temp_directory=options.checkpoint_directory,
password_map_file=options.password_map_file)
checker.run()
if __name__ == '__main__':