vitess-gh/test/utils.py

1338 строки
42 KiB
Python

#!/usr/bin/env python
# Copyright 2019 The Vitess Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common import for all tests."""
import base64
import contextlib
import json
import logging
import optparse
import os
import shlex
import shutil
import signal
import socket
import subprocess
import sys
import time
import unittest
import urllib2
from vtdb import prefer_vtroot_imports # pylint: disable=unused-import
from vtdb import vtgate_client
import environment
from mysql_flavor import mysql_flavor
from mysql_flavor import set_mysql_flavor
import MySQLdb
from protocols_flavor import protocols_flavor
from topo_flavor.server import set_topo_server_flavor
from vtctl import vtctl_client
from vtdb import keyrange_constants
from vtgate_gateway_flavor.gateway import set_vtgate_gateway_flavor
from vtgate_gateway_flavor.gateway import vtgate_gateway_flavor
from vtproto import topodata_pb2
options = None
devnull = open('/dev/null', 'w')
try:
hostname = socket.getaddrinfo(
socket.getfqdn(), None, 0, 0, 0, socket.AI_CANONNAME)[0][3]
except socket.gaierror:
# Fallback to 'localhost' if getfqdn() returns this value for "::1" and
# getaddrinfo() cannot resolve it and throws an exception.
# This error scenario was observed on mberlin@'s corp Macbook in 2018.
if socket.getfqdn() == '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa': # pylint: disable=line-too-long
hostname = 'localhost'
else:
raise
class TestError(Exception):
pass
class Break(Exception):
pass
environment.setup()
class LoggingStream(object):
def __init__(self):
self.line = ''
def write(self, value):
if value == '\n':
# we already printed it
self.line = ''
return
self.line += value
logging.info('===== ' + self.line)
if value.endswith('\n'):
self.line = ''
def writeln(self, value):
self.write(value)
self.line = ''
def flush(self):
pass
def add_options(parser):
environment.add_options(parser)
parser.add_option('-d', '--debug', action='store_true',
help='utils.pause() statements will wait for user input')
parser.add_option('-k', '--keep-logs', action='store_true',
help='Do not delete log files on teardown.')
parser.add_option(
'-q', '--quiet', action='store_const', const=0, dest='verbose', default=1)
parser.add_option(
'-v', '--verbose', action='store_const', const=2, dest='verbose',
default=1)
parser.add_option('--skip-build', action='store_true',
help='Do not build the go binaries when running the test.')
parser.add_option(
'--skip-teardown', action='store_true',
help='Leave the global processes running after the test is done.')
parser.add_option('--mysql-flavor')
parser.add_option('--protocols-flavor', default='grpc')
parser.add_option('--topo-server-flavor', default='zk2')
parser.add_option('--vtgate-gateway-flavor', default='discoverygateway')
def set_options(opts):
global options
options = opts
set_mysql_flavor(options.mysql_flavor)
environment.setup_protocol_flavor(options.protocols_flavor)
set_topo_server_flavor(options.topo_server_flavor)
set_vtgate_gateway_flavor(options.vtgate_gateway_flavor)
environment.skip_build = options.skip_build
# main executes the test classes contained in the passed module, or
# __main__ if empty.
def main(mod=None, test_options=None):
"""The replacement main method, which parses args and runs tests.
Args:
mod: module that contains the test methods.
test_options: a function which adds OptionParser options that are specific
to a test file.
"""
if mod is None:
mod = sys.modules['__main__']
global options
parser = optparse.OptionParser(usage='usage: %prog [options] [test_names]')
add_options(parser)
if test_options:
test_options(parser)
(options, args) = parser.parse_args()
environment.set_log_level(options.verbose)
logging.basicConfig(
format='-- %(asctime)s %(module)s:%(lineno)d %(levelname)s %(message)s')
set_options(options)
run_tests(mod, args)
def run_tests(mod, args):
try:
suite = unittest.TestSuite()
if not args:
# this will run the setup and teardown
suite.addTests(unittest.TestLoader().loadTestsFromModule(mod))
else:
if args[0] == 'teardown':
mod.tearDownModule()
elif args[0] == 'setup':
mod.setUpModule()
else:
for arg in args:
# this will run the setup and teardown
suite.addTests(unittest.TestLoader().loadTestsFromName(arg, mod))
if suite.countTestCases() > 0:
logger = LoggingStream()
result = unittest.TextTestRunner(
stream=logger, verbosity=options.verbose, failfast=True).run(suite)
if not result.wasSuccessful():
sys.exit(-1)
except KeyboardInterrupt:
logging.warning('======== Tests interrupted, cleaning up ========')
mod.tearDownModule()
# If you interrupt a test, you probably want to stop evaluating the rest.
sys.exit(1)
finally:
if options.keep_logs:
logging.warning('Leaving temporary files behind (--keep-logs), please '
'clean up before next run: ' + os.environ['VTDATAROOT'])
def remove_tmp_files():
if options.keep_logs:
return
try:
shutil.rmtree(environment.tmproot)
except OSError as e:
logging.debug('remove_tmp_files: %s', str(e))
def pause(prompt):
if options.debug:
raw_input(prompt)
# sub-process management
pid_map = {}
already_killed = []
def _add_proc(proc):
pid_map[proc.pid] = proc
with open(environment.tmproot+'/test-pids', 'a') as f:
print >> f, proc.pid, os.path.basename(proc.args[0])
def required_teardown():
"""Required cleanup steps that can't be skipped with --skip-teardown."""
# We can't skip closing of gRPC connections, because the Python interpreter
# won't let us die if any connections are left open.
global vtctld_connection, vtctld
if vtctld_connection:
vtctld_connection.close()
vtctld_connection = None
vtctld = None
def kill_sub_processes():
for proc in pid_map.values():
if proc.pid and proc.returncode is None:
proc.kill()
if not os.path.exists(environment.tmproot+'/test-pids'):
return
with open(environment.tmproot+'/test-pids') as f:
for line in f:
try:
parts = line.strip().split()
pid = int(parts[0])
proc = pid_map.get(pid)
if not proc or (proc and proc.pid and proc.returncode is None):
if pid not in already_killed:
os.kill(pid, signal.SIGTERM)
except OSError as e:
logging.debug('kill_sub_processes: %s', str(e))
def kill_sub_process(proc, soft=False):
if proc is None:
return
pid = proc.pid
if soft:
proc.terminate()
else:
proc.kill()
if pid and pid in pid_map:
del pid_map[pid]
already_killed.append(pid)
# run in foreground, possibly capturing output
def run(cmd, trap_output=False, raise_on_error=True, **kargs):
if isinstance(cmd, str):
args = shlex.split(cmd)
else:
args = cmd
if trap_output:
kargs['stdout'] = subprocess.PIPE
kargs['stderr'] = subprocess.PIPE
logging.debug(
'run: %s %s', str(cmd),
', '.join('%s=%s' % x for x in kargs.iteritems()))
proc = subprocess.Popen(args, **kargs)
proc.args = args
stdout, stderr = proc.communicate()
if proc.returncode:
if raise_on_error:
pause('cmd fail: %s, pausing...' % (args))
raise TestError('cmd fail:', args, proc.returncode, stdout, stderr)
else:
logging.debug('cmd fail: %s %d %s %s',
str(args), proc.returncode, stdout, stderr)
return stdout, stderr
# run sub-process, expects failure
def run_fail(cmd, **kargs):
if isinstance(cmd, str):
args = shlex.split(cmd)
else:
args = cmd
kargs['stdout'] = subprocess.PIPE
kargs['stderr'] = subprocess.PIPE
if options.verbose == 2:
logging.debug(
'run: (expect fail) %s %s', cmd,
', '.join('%s=%s' % x for x in kargs.iteritems()))
proc = subprocess.Popen(args, **kargs)
proc.args = args
stdout, stderr = proc.communicate()
if proc.returncode == 0 or options.verbose == 2:
logging.info('stdout:\n%sstderr:\n%s', stdout, stderr)
if proc.returncode == 0:
raise TestError('expected fail:', args, stdout, stderr)
return stdout, stderr
# run a daemon - kill when this script exits
def run_bg(cmd, **kargs):
if options.verbose == 2:
logging.debug(
'run: %s %s', cmd, ', '.join('%s=%s' % x for x in kargs.iteritems()))
if 'extra_env' in kargs:
kargs['env'] = os.environ.copy()
if kargs['extra_env']:
kargs['env'].update(kargs['extra_env'])
del kargs['extra_env']
if isinstance(cmd, str):
args = shlex.split(cmd)
else:
args = cmd
proc = subprocess.Popen(args=args, **kargs)
proc.args = args
_add_proc(proc)
return proc
def wait_procs(proc_list, raise_on_error=True):
for proc in proc_list:
pid = proc.pid
if pid:
already_killed.append(pid)
for proc in proc_list:
proc.wait()
for proc in proc_list:
if proc.returncode:
if options.verbose >= 1 and proc.returncode not in (-9,):
sys.stderr.write('proc failed: %s %s\n' % (proc.returncode, proc.args))
if raise_on_error:
raise subprocess.CalledProcessError(proc.returncode,
' '.join(proc.args))
def validate_topology(ping_tablets=False):
if ping_tablets:
run_vtctl(['Validate', '-ping-tablets'])
else:
run_vtctl(['Validate'])
# wait_step is a helper for looping until a condition is true.
# use as follow:
# timeout = 10
# while True:
# <step>
# if <done>:
# break
# timeout = utils.wait_step('description of condition', timeout)
def wait_step(msg, timeout, sleep_time=0.1):
timeout -= sleep_time
if timeout <= 0:
raise TestError('timeout waiting for condition "%s"' % msg)
logging.debug('Sleeping for %f seconds waiting for condition "%s"',
sleep_time, msg)
time.sleep(sleep_time)
return timeout
# vars helpers
def get_vars(port):
"""Returns the dict for vars from a vtxxx process. None if not available."""
try:
url = 'http://localhost:%d/debug/vars' % int(port)
f = urllib2.urlopen(url)
data = f.read()
f.close()
except urllib2.URLError:
return None
try:
return json.loads(data)
except ValueError:
print data
raise
def wait_for_vars(name, port, var=None, key=None, value=None, timeout=10.0):
"""Waits for the vars of a process, and optional values.
Args:
name: nickname for the process.
port: process port to look at.
var: if specified, waits for var in vars.
key: if specified, waits for vars[var][key]==value.
value: if key if specified, waits for vars[var][key]==value.
timeout: how long to wait.
"""
text = 'waiting for http://localhost:%d/debug/vars of %s' % (port, name)
if var:
text += ' value %s' % var
if key:
text += ' key %s:%s' % (key, value)
while True:
display_text = text
v = get_vars(port)
if v:
if var is None:
break
if var in v:
if key is None:
break
if key in v[var]:
if v[var][key] == value:
break
else:
display_text += ' (current value:%s)' % v[var][key]
else:
display_text += ' (no current value)'
else:
display_text += ' (%s not in vars)' % var
else:
display_text += ' (no vars yet)'
timeout = wait_step(display_text, timeout)
def poll_for_vars(
name, port, condition_msg, timeout=60.0, condition_fn=None,
require_vars=False):
"""Polls for debug variables to exist or match specific conditions.
This function polls in a tight loop, with no sleeps. This is useful for
variables that are expected to be short-lived (e.g., a 'Done' state
immediately before a process exits).
Args:
name: the name of the process that we're trying to poll vars from.
port: the port number that we should poll for variables.
condition_msg: string describing the conditions that we're polling for,
used for error messaging.
timeout: number of seconds that we should attempt to poll for.
condition_fn: a function that takes the debug vars dict as input, and
returns a truthy value if it matches the success conditions.
require_vars: True iff we expect the vars to always exist. If
True, and the vars don't exist, we'll raise a TestError. This
can be used to differentiate between a timeout waiting for a
particular condition vs if the process that you're polling has
already exited.
Raises:
TestError: if the conditions aren't met within the given timeout, or
if vars are required and don't exist.
Returns:
dict of debug variables
"""
start_time = time.time()
while True:
if (time.time() - start_time) >= timeout:
raise TestError(
'Timed out polling for vars from %s; condition "%s" not met' %
(name, condition_msg))
v = get_vars(port)
if v is None:
if require_vars:
raise TestError(
'Expected vars to exist on %s, but they do not; '
'process probably exited earlier than expected.' % (name,))
continue
if condition_fn is None:
return v
elif condition_fn(v):
return v
def apply_vschema(vschema):
for k, v in vschema.iteritems():
fname = os.path.join(environment.tmproot, 'vschema.json')
with open(fname, 'w') as f:
f.write(v)
run_vtctl(['ApplyVSchema', '-vschema_file', fname, k])
def wait_for_tablet_type(tablet_alias, expected_type, timeout=10):
"""Waits for a given tablet's SlaveType to become the expected value.
Args:
tablet_alias: Alias of the tablet.
expected_type: Type of the tablet e.g. "replica".
timeout: Timeout in seconds.
Raises:
TestError: SlaveType did not become expected_type within timeout seconds.
"""
type_as_int = topodata_pb2.TabletType.Value(expected_type.upper())
while True:
if run_vtctl_json(['GetTablet', tablet_alias])['type'] == type_as_int:
logging.debug('tablet %s went to expected type: %s',
tablet_alias, expected_type)
break
timeout = wait_step(
"%s's SlaveType to be %s" % (tablet_alias, expected_type),
timeout)
def wait_for_replication_pos(tablet_a, tablet_b, timeout=60.0):
"""Waits for tablet B to catch up to the replication position of tablet A.
Args:
tablet_a: tablet Object for tablet A.
tablet_b: tablet Object for tablet B.
timeout: Timeout in seconds.
Raises:
TestError: replication position did not catch up within timeout seconds.
"""
replication_pos_a = mysql_flavor().master_position(tablet_a)
while True:
replication_pos_b = mysql_flavor().master_position(tablet_b)
if mysql_flavor().position_at_least(replication_pos_b, replication_pos_a):
break
timeout = wait_step(
"%s's replication position to catch up %s's; "
'currently at: %s, waiting to catch up to: %s' % (
tablet_b.tablet_alias, tablet_a.tablet_alias, replication_pos_b,
replication_pos_a),
timeout, sleep_time=0.1)
# Save the first running instance of vtgate. It is saved when 'start'
# is called, and cleared when kill is called.
vtgate = None
class VtGate(object):
"""VtGate object represents a vtgate process."""
def __init__(self, port=None, mysql_server=False):
"""Creates the Vtgate instance and reserve the ports if necessary."""
self.port = port or environment.reserve_ports(1)
if protocols_flavor().vtgate_protocol() == 'grpc':
self.grpc_port = environment.reserve_ports(1)
self.proc = None
self.mysql_port = None
if mysql_server:
self.mysql_port = environment.reserve_ports(1)
def start(self, cell='test_nj', retry_count=2,
topo_impl=None, cache_ttl='1s',
extra_args=None, tablets=None,
tablet_types_to_wait='MASTER,REPLICA',
l2vtgates=None,
cells_to_watch=None):
"""Start vtgate. Saves it into the global vtgate variable if not set yet."""
args = environment.binary_args('vtgate') + [
'-port', str(self.port),
'-cell', cell,
'-retry-count', str(retry_count),
'-log_dir', environment.vtlogroot,
'-srv_topo_cache_ttl', cache_ttl,
'-srv_topo_cache_refresh', cache_ttl,
'-tablet_protocol', protocols_flavor().tabletconn_protocol(),
'-stderrthreshold', get_log_level(),
'-normalize_queries',
'-gateway_implementation', vtgate_gateway_flavor().flavor(),
]
if cells_to_watch:
args.extend(vtgate_gateway_flavor().flags(cell=cells_to_watch, tablets=tablets))
else:
args.extend(vtgate_gateway_flavor().flags(cell=cell, tablets=tablets))
if l2vtgates:
args.extend(['-l2vtgate_addrs', ','.join(l2vtgates)])
if tablet_types_to_wait:
args.extend(['-tablet_types_to_wait', tablet_types_to_wait])
if protocols_flavor().vtgate_protocol() == 'grpc':
args.extend(['-grpc_port', str(self.grpc_port)])
args.extend(['-grpc_max_message_size',
str(environment.grpc_max_message_size)])
if protocols_flavor().service_map():
args.extend(['-service_map', ','.join(protocols_flavor().service_map())])
if topo_impl:
args.extend(['-topo_implementation', topo_impl])
else:
args.extend(environment.topo_server().flags())
if extra_args:
args.extend(extra_args)
if self.mysql_port:
args.extend(['-mysql_server_port', str(self.mysql_port)])
self.proc = run_bg(args)
# We use a longer timeout here, as we may be waiting for the initial
# state of a few tablets.
wait_for_vars('vtgate', self.port, timeout=20.0)
global vtgate
if not vtgate:
vtgate = self
def kill(self):
"""Terminates the vtgate process, and waits for it to exit.
If this process is the one saved in the global vtgate variable,
clears it.
Note if the test is using just one global vtgate process, and
starting it with the test, and killing it at the end of the test,
there is no need to call this kill() method,
utils.kill_sub_processes() will do a good enough job.
"""
if self.proc is None:
return
kill_sub_process(self.proc, soft=True)
self.proc.wait()
self.proc = None
global vtgate
if vtgate == self:
vtgate = None
def addr(self):
"""Returns the address of the vtgate process, for web access."""
return 'localhost:%d' % self.port
def rpc_endpoint(self, python=False):
"""Returns the protocol and endpoint to use for RPCs."""
if python:
protocol = protocols_flavor().vtgate_python_protocol()
else:
protocol = protocols_flavor().vtgate_protocol()
if protocol == 'grpc':
return protocol, 'localhost:%d' % self.grpc_port
return protocol, self.addr()
def get_status(self):
"""Returns the status page for this process."""
return get_status(self.port)
def get_vars(self):
"""Returns the vars for this process."""
return get_vars(self.port)
def get_vschema(self):
"""Returns the used vschema for this process."""
return urllib2.urlopen('http://localhost:%d/debug/vschema' %
self.port).read()
@contextlib.contextmanager
def create_connection(self):
"""Connects to vtgate and allows to create a cursor to execute queries.
This method is preferred over the two other methods ("vtclient", "execute")
to execute a query in tests.
Yields:
A vtgate connection object.
Example:
with self.vtgate.create_connection() as conn:
c = conn.cursor(keyspace=KEYSPACE, shards=[SHARD], tablet_type='master',
writable=self.writable)
c.execute('SELECT * FROM buffer WHERE id = :id', {'id': 1})
"""
protocol, endpoint = self.rpc_endpoint(python=True)
# Use a very long timeout to account for slow tests.
conn = vtgate_client.connect(protocol, endpoint, 600.0)
yield conn
conn.close()
@contextlib.contextmanager
def write_transaction(self, **kwargs):
"""Begins a write transaction and commits automatically.
Note that each transaction contextmanager will create a new connection.
Args:
**kwargs: vtgate cursor args. See vtgate_cursor.VTGateCursor.
Yields:
A writable vtgate cursor.
Example:
with utils.vtgate.write_transaction(keyspace=KEYSPACE, shards=[SHARD],
tablet_type='master') as tx:
tx.execute('INSERT INTO table1 (id, msg) VALUES (:id, :msg)',
{'id': 1, 'msg': 'msg1'})
"""
with self.create_connection() as conn:
cursor = conn.cursor(writable=True, **kwargs)
cursor.begin()
yield cursor
cursor.commit()
def vtclient(self, sql, keyspace=None, tablet_type='master',
bindvars=None, streaming=False,
verbose=False, raise_on_error=True, json_output=False):
"""Uses the vtclient binary to send a query to vtgate."""
protocol, addr = self.rpc_endpoint()
args = environment.binary_args('vtclient') + [
'-server', addr,
'-vtgate_protocol', protocol]
if json_output:
args.append('-json')
if bindvars:
args.extend(['-bind_variables', json.dumps(bindvars)])
if streaming:
args.append('-streaming')
if keyspace:
args.extend(['-target', '%s@%s' % (keyspace, tablet_type)])
else:
args.extend(['-target', '@'+tablet_type])
if verbose:
args.append('-alsologtostderr')
args.append(sql)
out, err = run(args, raise_on_error=raise_on_error, trap_output=True)
if json_output:
return json.loads(out), err
return out, err
def execute(self, sql, tablet_type='master', bindvars=None,
execute_options=None):
"""Uses 'vtctl VtGateExecute' to execute a command.
Args:
sql: the command to execute.
tablet_type: the tablet_type to use.
bindvars: a dict of bind variables.
execute_options: proto-encoded ExecuteOptions object.
Returns:
the result of running vtctl command.
"""
_, addr = self.rpc_endpoint()
args = ['VtGateExecute', '-json',
'-server', addr,
'-target', '@'+tablet_type]
if bindvars:
args.extend(['-bind_variables', json.dumps(bindvars)])
if execute_options:
args.extend(['-options', execute_options])
args.append(sql)
return run_vtctl_json(args)
def execute_shards(self, sql, keyspace, shards, tablet_type='master',
bindvars=None):
"""Uses 'vtctl VtGateExecuteShards' to execute a command."""
_, addr = self.rpc_endpoint()
args = ['VtGateExecuteShards', '-json',
'-server', addr,
'-keyspace', keyspace,
'-shards', shards,
'-tablet_type', tablet_type]
if bindvars:
args.extend(['-bind_variables', json.dumps(bindvars)])
args.append(sql)
return run_vtctl_json(args)
def split_query(self, sql, keyspace, split_count, bindvars=None):
"""Uses 'vtctl VtGateSplitQuery' to cut a query up in chunks."""
_, addr = self.rpc_endpoint()
args = ['VtGateSplitQuery',
'-server', addr,
'-keyspace', keyspace,
'-split_count', str(split_count)]
if bindvars:
args.extend(['-bind_variables', json.dumps(bindvars)])
args.append(sql)
return run_vtctl_json(args)
def wait_for_endpoints(self, name, count, timeout=20.0, var=None):
"""waits until vtgate gets endpoints.
Args:
name: name of the endpoint, in the form: 'keyspace.shard.type'.
count: how many endpoints to wait for.
timeout: how long to wait.
var: name of the variable to use. if None, defaults to the gateway's.
"""
wait_for_vars('vtgate', self.port,
var=var or vtgate_gateway_flavor().connection_count_vars(),
key=name, value=count, timeout=timeout)
def verify_no_endpoint(self, name):
"""verifies the vtgate doesn't have any enpoint of the given name.
Args:
name: name of the endpoint, in the form: 'keyspace.shard.type'.
"""
def condition(v):
return (v.get(vtgate_gateway_flavor().connection_count_vars())
.get(name, None)) is None
poll_for_vars('l2vtgate', self.port,
'no endpoint named ' + name,
timeout=5.0,
condition_fn=condition)
# vtctl helpers
# The modes are not all equivalent, and we don't really thrive for it.
# If a client needs to rely on vtctl's command line behavior, make
# sure to use mode=utils.VTCTL_VTCTL
VTCTL_AUTO = 0
VTCTL_VTCTL = 1
VTCTL_VTCTLCLIENT = 2
VTCTL_RPC = 3
def run_vtctl(clargs, auto_log=False, expect_fail=False,
mode=VTCTL_AUTO, action_timeout=10, **kwargs):
if mode == VTCTL_AUTO:
if not expect_fail and vtctld:
mode = VTCTL_RPC
else:
mode = VTCTL_VTCTL
if mode == VTCTL_VTCTL:
return run_vtctl_vtctl(clargs, auto_log=auto_log,
expect_fail=expect_fail, **kwargs)
elif mode == VTCTL_VTCTLCLIENT:
result = vtctld.vtctl_client(clargs)
return result, ''
elif mode == VTCTL_RPC:
if auto_log:
logging.debug('vtctl: %s', ' '.join(clargs))
result = vtctl_client.execute_vtctl_command(vtctld_connection, clargs,
info_to_debug=True,
action_timeout=action_timeout)
return result, ''
raise Exception('Unknown mode: %s', mode)
def run_vtctl_vtctl(clargs, auto_log=False, expect_fail=False,
**kwargs):
args = environment.binary_args('vtctl') + [
'-log_dir', environment.vtlogroot,
'-enable_queries',
]
args.extend(environment.topo_server().flags())
args.extend(['-tablet_manager_protocol',
protocols_flavor().tablet_manager_protocol()])
args.extend(['-tablet_protocol', protocols_flavor().tabletconn_protocol()])
args.extend(['-throttler_client_protocol',
protocols_flavor().throttler_client_protocol()])
args.extend(['-vtgate_protocol', protocols_flavor().vtgate_protocol()])
# TODO(b/26388813): Remove the next two lines once vtctl WaitForDrain is
# integrated in the vtctl MigrateServed* commands.
args.extend(['--wait_for_drain_sleep_rdonly', '0s'])
args.extend(['--wait_for_drain_sleep_replica', '0s'])
if auto_log:
args.append('--stderrthreshold=%s' % get_log_level())
if isinstance(clargs, str):
cmd = ' '.join(args) + ' ' + clargs
else:
cmd = args + clargs
if expect_fail:
return run_fail(cmd, **kwargs)
return run(cmd, **kwargs)
# run_vtctl_json runs the provided vtctl command and returns the result
# parsed as json
def run_vtctl_json(clargs, auto_log=True):
stdout, _ = run_vtctl(clargs, trap_output=True, auto_log=auto_log)
return json.loads(stdout)
def get_log_level():
if options.verbose == 2:
return '0'
elif options.verbose == 1:
return '1'
else:
return '2'
# vtworker helpers
def run_vtworker(clargs, auto_log=False, expect_fail=False, **kwargs):
"""Runs a vtworker process, returning the stdout and stderr."""
cmd, _, _ = _get_vtworker_cmd(clargs, auto_log)
if expect_fail:
return run_fail(cmd, **kwargs)
return run(cmd, **kwargs)
def run_vtworker_bg(clargs, auto_log=False, **kwargs):
"""Starts a background vtworker process."""
cmd, port, rpc_port = _get_vtworker_cmd(clargs, auto_log)
proc = run_bg(cmd, **kwargs), port, rpc_port
wait_for_vars('vtworker', port)
return proc
def _get_vtworker_cmd(clargs, auto_log=False):
"""Assembles the command that is needed to run a vtworker.
Args:
clargs: Command line arguments passed to vtworker.
auto_log: If true, set --stderrthreshold according to the test log level.
Returns:
cmd - list of cmd arguments, can be passed to any `run`-like functions
port - int with the port number that the vtworker is running with
rpc_port - int with the port number of the RPC interface
"""
port = environment.reserve_ports(1)
rpc_port = port
args = environment.binary_args('vtworker') + [
'-log_dir', environment.vtlogroot,
'-port', str(port),
'-executefetch_retry_time', '1s',
'-tablet_manager_protocol',
protocols_flavor().tablet_manager_protocol(),
'-tablet_protocol', protocols_flavor().tabletconn_protocol(),
]
args.extend(environment.topo_server().flags())
if protocols_flavor().service_map():
args.extend(['-service_map',
','.join(protocols_flavor().service_map())])
if protocols_flavor().vtworker_client_protocol() == 'grpc':
rpc_port = environment.reserve_ports(1)
args.extend(['-grpc_port', str(rpc_port)])
if auto_log:
args.append('--stderrthreshold=%s' % get_log_level())
cmd = args + clargs
return cmd, port, rpc_port
# vtworker client helpers
def run_vtworker_client_bg(args, rpc_port):
"""Runs vtworkerclient to execute a command on a remote vtworker.
Args:
args: Full vtworker command.
rpc_port: Port number.
Returns:
proc: process returned by subprocess.Popen
"""
return run_bg(
environment.binary_args('vtworkerclient') + [
'-log_dir', environment.vtlogroot,
'-vtworker_client_protocol',
protocols_flavor().vtworker_client_protocol(),
'-server', 'localhost:%d' % rpc_port,
'-stderrthreshold', get_log_level(),
] + args)
def run_automation_server(auto_log=False):
"""Starts a background automation_server process.
Args:
auto_log: True to log.
Returns:
rpc_port - int with the port number of the RPC interface
"""
rpc_port = environment.reserve_ports(1)
args = environment.binary_args('automation_server') + [
'-log_dir', environment.vtlogroot,
'-port', str(rpc_port),
'-vtctl_client_protocol',
protocols_flavor().vtctl_client_protocol(),
'-vtworker_client_protocol',
protocols_flavor().vtworker_client_protocol(),
]
if auto_log:
args.append('--stderrthreshold=%s' % get_log_level())
return run_bg(args), rpc_port
# mysql helpers
def mysql_query(uid, dbname, query):
conn = MySQLdb.Connect(
user='vt_dba',
unix_socket='%s/vt_%010d/mysql.sock' % (environment.vtdataroot, uid),
db=dbname)
cursor = conn.cursor()
cursor.execute(query)
try:
return cursor.fetchall()
finally:
conn.close()
def mysql_write_query(uid, dbname, query):
conn = MySQLdb.Connect(
user='vt_dba',
unix_socket='%s/vt_%010d/mysql.sock' % (environment.vtdataroot, uid),
db=dbname)
cursor = conn.cursor()
conn.begin()
cursor.execute(query)
conn.commit()
try:
return cursor.fetchall()
finally:
conn.close()
def check_db_var(uid, name, value):
conn = MySQLdb.Connect(
user='vt_dba',
unix_socket='%s/vt_%010d/mysql.sock' % (environment.vtdataroot, uid))
cursor = conn.cursor()
cursor.execute("show variables like '%s'" % name)
row = cursor.fetchone()
if row != (name, value):
raise TestError('variable not set correctly', name, row)
conn.close()
def check_db_read_only(uid):
return check_db_var(uid, 'read_only', 'ON')
def check_db_read_write(uid):
return check_db_var(uid, 'read_only', 'OFF')
def wait_db_read_only(uid):
for _ in xrange(3):
try:
check_db_read_only(uid)
return
except TestError as e:
logging.warning('wait_db_read_only: %s', str(e))
time.sleep(1.0)
raise e
def check_srv_keyspace(cell, keyspace, expected, keyspace_id_type='uint64',
sharding_column_name='keyspace_id'):
ks = run_vtctl_json(['GetSrvKeyspace', cell, keyspace])
result = ''
pmap = {}
for partition in ks['partitions']:
tablet_type = topodata_pb2.TabletType.Name(partition['served_type']).lower()
if tablet_type == 'batch':
tablet_type = 'rdonly'
r = 'Partitions(%s):' % tablet_type
for shard in partition['shard_references']:
s = ''
e = ''
if 'key_range' in shard and shard['key_range']:
if 'start' in shard['key_range']:
s = shard['key_range']['start']
s = base64.b64decode(s).encode('hex') if s else ''
if 'end' in shard['key_range']:
e = shard['key_range']['end']
e = base64.b64decode(e).encode('hex') if e else ''
r += ' %s-%s' % (s, e)
pmap[tablet_type] = r + '\n'
for tablet_type in sorted(pmap):
result += pmap[tablet_type]
logging.debug('Cell %s keyspace %s has data:\n%s', cell, keyspace, result)
if expected != result:
raise Exception(
'Mismatch in srv keyspace for cell %s keyspace %s, expected:\n%'
's\ngot:\n%s' % (
cell, keyspace, expected, result))
if sharding_column_name != ks.get('sharding_column_name'):
raise Exception('Got wrong sharding_column_name in SrvKeyspace: %s' %
str(ks))
if keyspace_id_type != keyrange_constants.PROTO3_KIT_TO_STRING[
ks.get('sharding_column_type')]:
raise Exception('Got wrong sharding_column_type in SrvKeyspace: %s' %
str(ks))
def check_shard_query_service(
testcase, cell, keyspace, shard_name, tablet_type, expected_state):
"""Checks DisableQueryService in the shard record's TabletControlMap."""
# We assume that query service should be enabled unless
# DisableQueryService is explicitly True
query_service_enabled = True
ks = run_vtctl_json(['GetSrvKeyspace', cell, keyspace])
for partition in ks['partitions']:
tablet_type = topodata_pb2.TabletType.Name(partition['served_type'])
if tablet_type != tablet_type:
continue
for shard in partition['shard_tablet_controls']:
if shard['name'] == shard_name:
if shard['query_service_disabled']:
query_service_enabled = False
testcase.assertEqual(
query_service_enabled,
expected_state,
'shard %s does not have the correct query service state: '
'got %s but expected %s' %
(shard_name, query_service_enabled, expected_state)
)
def check_shard_query_services(
testcase, cell, keyspace, shard_names, tablet_type, expected_state):
for shard_name in shard_names:
check_shard_query_service(
testcase, cell, keyspace, shard_name, tablet_type, expected_state)
def check_tablet_query_service(
testcase, tablet, serving, tablet_control_disabled):
"""Check that the query service is enabled or disabled on the tablet."""
tablet_vars = get_vars(tablet.port)
if serving:
expected_state = 'SERVING'
else:
expected_state = 'NOT_SERVING'
testcase.assertEqual(
tablet_vars['TabletStateName'], expected_state,
'tablet %s (%s/%s, %s) is not in the right serving state: got %s'
' expected %s' % (tablet.tablet_alias, tablet.keyspace, tablet.shard,
tablet.tablet_type,
tablet_vars['TabletStateName'], expected_state))
status = tablet.get_status()
tc_dqs = 'Query Service disabled: TabletControl.DisableQueryService set'
if tablet_control_disabled:
testcase.assertIn(tc_dqs, status)
else:
testcase.assertNotIn(tc_dqs, status)
if tablet.tablet_type == 'rdonly':
# Run RunHealthCheck to be sure the tablet doesn't change its serving state.
run_vtctl(['RunHealthCheck', tablet.tablet_alias],
auto_log=True)
tablet_vars = get_vars(tablet.port)
testcase.assertEqual(
tablet_vars['TabletStateName'], expected_state,
'tablet %s is not in the right serving state after health check: '
'got %s expected %s' %
(tablet.tablet_alias, tablet_vars['TabletStateName'], expected_state))
def check_tablet_query_services(
testcase, tablets, serving, tablet_control_disabled):
for tablet in tablets:
check_tablet_query_service(
testcase, tablet, serving, tablet_control_disabled)
def get_status(port):
return urllib2.urlopen(
'http://localhost:%d%s' % (port, environment.status_url)).read()
def curl(url, request=None, data=None, background=False, retry_timeout=0,
**kwargs):
args = [environment.curl_bin, '--silent', '--no-buffer', '--location']
if not background:
args.append('--show-error')
if request:
args.extend(['--request', request])
if data:
args.extend(['--data', data])
args.append(url)
if background:
return run_bg(args, **kwargs)
if retry_timeout > 0:
while True:
try:
return run(args, trap_output=True, **kwargs)
except TestError as e:
retry_timeout = wait_step(
'cmd: %s, error: %s' % (str(args), str(e)), retry_timeout)
return run(args, trap_output=True, **kwargs)
class VtctldError(Exception):
pass
# save the first running instance, and an RPC connection to it,
# so we can use it to run remote vtctl commands
vtctld = None
vtctld_connection = None
class Vtctld(object):
def __init__(self):
self.port = environment.reserve_ports(1)
self.schema_change_dir = os.path.join(
environment.tmproot, 'schema_change_test')
if protocols_flavor().vtctl_client_protocol() == 'grpc':
self.grpc_port = environment.reserve_ports(1)
def start(self, enable_schema_change_dir=False, extra_flags=None):
# Note the vtctld2 web dir is set to 'dist', which is populated
# when a toplevel 'make build_web' is run. This is meant to test
# the development version of the UI. The real checked-in app is in
# app/.
args = environment.binary_args('vtctld') + [
'-enable_queries',
'-cell', 'test_nj',
'--log_dir', environment.vtlogroot,
'--port', str(self.port),
'-tablet_manager_protocol',
protocols_flavor().tablet_manager_protocol(),
'-tablet_protocol', protocols_flavor().tabletconn_protocol(),
'-throttler_client_protocol',
protocols_flavor().throttler_client_protocol(),
'-vtgate_protocol', protocols_flavor().vtgate_protocol(),
'-workflow_manager_init',
'-workflow_manager_use_election',
'-schema_swap_delay_between_errors', '1s',
'-wait_for_drain_sleep_rdonly', '1s',
'-wait_for_drain_sleep_replica', '1s',
] + environment.topo_server().flags()
if extra_flags:
args += extra_flags
# TODO(b/26388813): Remove the next two lines once vtctl WaitForDrain is
# integrated in the vtctl MigrateServed* commands.
args.extend(['--wait_for_drain_sleep_rdonly', '0s'])
args.extend(['--wait_for_drain_sleep_replica', '0s'])
if enable_schema_change_dir:
args += [
'--schema_change_dir', self.schema_change_dir,
'--schema_change_controller', 'local',
'--schema_change_check_interval', '1',
]
if protocols_flavor().service_map():
args.extend(['-service_map', ','.join(protocols_flavor().service_map())])
if protocols_flavor().vtctl_client_protocol() == 'grpc':
args.extend(['-grpc_port', str(self.grpc_port)])
stdout_fd = open(os.path.join(environment.tmproot, 'vtctld.stdout'), 'w')
stderr_fd = open(os.path.join(environment.tmproot, 'vtctld.stderr'), 'w')
self.proc = run_bg(args, stdout=stdout_fd, stderr=stderr_fd)
# wait for the process to listen to RPC
timeout = 30
while True:
v = get_vars(self.port)
if v:
break
if self.proc.poll() is not None:
raise TestError('vtctld died while starting')
timeout = wait_step('waiting for vtctld to start', timeout,
sleep_time=0.2)
# save the running instance so vtctl commands can be remote executed now
global vtctld, vtctld_connection
if not vtctld:
vtctld = self
protocol, endpoint = self.rpc_endpoint(python=True)
vtctld_connection = vtctl_client.connect(protocol, endpoint, 30)
return self.proc
def rpc_endpoint(self, python=False):
"""RPC endpoint to vtctld.
The RPC endpoint may differ from the webinterface URL e.g. because gRPC
requires a dedicated port.
Args:
python: boolean, True iff this is for access with Python (as opposed to
Go).
Returns:
protocol - string e.g. 'grpc'
endpoint - string e.g. 'localhost:15001'
"""
if python:
protocol = protocols_flavor().vtctl_python_client_protocol()
else:
protocol = protocols_flavor().vtctl_client_protocol()
rpc_port = self.port
if protocol == 'grpc':
rpc_port = self.grpc_port
return (protocol, '%s:%d' % (socket.getfqdn(), rpc_port))
def process_args(self):
return ['-vtctld_addr', 'http://localhost:%d/' % self.port]
def vtctl_client(self, args):
if options.verbose == 2:
log_level = 'INFO'
elif options.verbose == 1:
log_level = 'WARNING'
else:
log_level = 'ERROR'
protocol, endpoint = self.rpc_endpoint()
out, _ = run(
environment.binary_args('vtctlclient') +
['-vtctl_client_protocol', protocol,
'-server', endpoint,
'-stderrthreshold', log_level] + args,
trap_output=True)
return out
def uint64_to_hex(integer):
"""Returns the hex representation of an int treated as a 64-bit unsigned int.
The result is padded by zeros if necessary to fill a 16 character string.
Useful for converting keyspace ids integers.
Example:
uint64_to_hex(1) == "0000000000000001"
uint64_to_hex(0xDEADBEAF) == "00000000DEADBEEF"
uint64_to_hex(0xDEADBEAFDEADBEAFDEADBEAF) raises an out of range exception.
Args:
integer: the value to print.
Returns:
String with the hex representation.
Raises:
ValueError: if the integer is out of range.
"""
if integer > (1<<64)-1 or integer < 0:
raise ValueError('Integer out of range: %d' % integer)
return '%016X' % integer