Fix some basic lint issues in vitess/py.

This commit is contained in:
Dean Yasuda 2015-08-21 18:51:18 -07:00
Родитель 2dfda59b5a
Коммит a78bbf14ab
40 изменённых файлов: 796 добавлений и 522 удалений

Просмотреть файл

@ -1,9 +1,10 @@
from distutils.core import setup, Extension
from distutils.core import Extension
from distutils.core import setup
cbson = Extension('cbson',
sources = ['cbson.c'])
sources=['cbson.c'])
setup(name = 'cbson',
version = '0.1',
description = 'Fast BSON decoding via C',
ext_modules = [cbson])
setup(name='cbson',
version='0.1',
description='Fast BSON decoding via C',
ext_modules=[cbson])

Просмотреть файл

@ -12,16 +12,20 @@ import sys
sys.path.append(os.path.split(os.path.abspath(__file__))[0])
import cbson
def test_load_empty():
"""
"""Doctest.
>>> cbson.loads('')
Traceback (most recent call last):
...
BSONError: empty buffer
"""
def test_load_binary():
r"""
r"""Doctest.
>>> s = cbson.dumps({'world': 'hello'})
>>> s
'\x16\x00\x00\x00\x05world\x00\x05\x00\x00\x00\x00hello\x00'
@ -29,8 +33,10 @@ def test_load_binary():
{'world': 'hello'}
"""
def test_load_string():
r"""
r"""Doctest.
>>> s = cbson.dumps({'world': u'hello \u00fc'})
>>> s
'\x19\x00\x00\x00\x02world\x00\t\x00\x00\x00hello \xc3\xbc\x00\x00'
@ -38,8 +44,10 @@ def test_load_string():
{'world': u'hello \xfc'}
"""
def test_load_int():
r"""
r"""Doctest.
>>> s = cbson.dumps({'int': 1334})
>>> s
'\x0e\x00\x00\x00\x10int\x006\x05\x00\x00\x00'
@ -47,9 +55,11 @@ def test_load_int():
{'int': 1334}
"""
# the library doesn't allow creation of these, so just check the unpacking
def test_unpack_uint64():
r"""
r"""Doctest.
>>> s = '\x12\x00\x00\x00\x3fint\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00'
>>> cbson.loads(s)
{'int': 1L}
@ -58,16 +68,20 @@ def test_unpack_uint64():
{'int': 10376293541461622784L}
"""
def test_bool():
"""
"""Doctest.
>>> cbson.loads(cbson.dumps({'yes': True, 'no': False}))['yes']
True
>>> cbson.loads(cbson.dumps({'yes': True, 'no': False}))['no']
False
"""
def test_none():
r"""
r"""Doctest.
>>> s = cbson.dumps({'none': None})
>>> s
'\x0b\x00\x00\x00\nnone\x00\x00'
@ -75,8 +89,10 @@ def test_none():
{'none': None}
"""
def test_array():
r"""
r"""Doctest.
>>> s = cbson.dumps({'lst': [1, 2, 3, 4, 'hello', None]})
>>> s
';\x00\x00\x00\x04lst\x001\x00\x00\x00\x100\x00\x01\x00\x00\x00\x101\x00\x02\x00\x00\x00\x102\x00\x03\x00\x00\x00\x103\x00\x04\x00\x00\x00\x054\x00\x05\x00\x00\x00\x00hello\n5\x00\x00\x00'
@ -84,15 +100,19 @@ def test_array():
{'lst': [1, 2, 3, 4, 'hello', None]}
"""
def test_dict():
"""
"""Doctest.
>>> d = {'a': None, 'b': 2, 'c': 'hello', 'd': 1.02, 'e': ['a', 'b', 'c']}
>>> cbson.loads(cbson.dumps(d)) == d
True
"""
def test_nested_dict():
r"""
r"""Doctest.
>>> s = cbson.dumps({'a': {'b': {'c': 0}}})
>>> s
'\x1c\x00\x00\x00\x03a\x00\x14\x00\x00\x00\x03b\x00\x0c\x00\x00\x00\x10c\x00\x00\x00\x00\x00\x00\x00\x00'
@ -100,8 +120,10 @@ def test_nested_dict():
{'a': {'b': {'c': 0}}}
"""
def test_dict_in_array():
r"""
r"""Doctest.
>>> s = cbson.dumps({'a': [{'b': 0}, {'c': 1}]})
>>> s
'+\x00\x00\x00\x04a\x00#\x00\x00\x00\x030\x00\x0c\x00\x00\x00\x10b\x00\x00\x00\x00\x00\x00\x031\x00\x0c\x00\x00\x00\x10c\x00\x01\x00\x00\x00\x00\x00\x00'
@ -109,32 +131,40 @@ def test_dict_in_array():
{'a': [{'b': 0}, {'c': 1}]}
"""
def test_eob1():
"""
"""Doctest.
>>> cbson.loads(BSON(BSON_TAG(1), NULL_BYTE))
Traceback (most recent call last):
...
BSONError: unexpected end of buffer: wanted 8 bytes at buffer[6] for double
"""
def test_eob2():
"""
"""Doctest.
>>> cbson.loads(BSON(BSON_TAG(2), NULL_BYTE))
Traceback (most recent call last):
...
BSONError: unexpected end of buffer: wanted 4 bytes at buffer[6] for string-length
"""
def test_eob_cstring():
"""
"""Doctest.
>>> cbson.loads(BSON(BSON_TAG(1)))
Traceback (most recent call last):
...
BSONError: unexpected end of buffer: non-terminated cstring at buffer[5] for element-name
"""
def test_decode_next():
"""
"""Doctest.
>>> s = cbson.dumps({'a': 1}) + cbson.dumps({'b': 2.0}) + cbson.dumps({'c': [None]})
>>> cbson.decode_next(s)
(12, {'a': 1})
@ -144,8 +174,10 @@ def test_decode_next():
(44, {'c': [None]})
"""
def test_decode_next_eob():
"""
"""Doctest.
>>> s_full = cbson.dumps({'a': 1}) + cbson.dumps({'b': 2.0})
>>> s = s_full[:-1]
>>> cbson.decode_next(s)
@ -164,8 +196,10 @@ def test_decode_next_eob():
BSONBufferTooShort: ('buffer too short: buffer[12:] does not contain 12 bytes for document', 2)
"""
def test_encode_recursive():
"""
"""Doctest.
>>> a = []
>>> a.append(a)
>>> cbson.dumps({"x": a})
@ -174,8 +208,10 @@ def test_encode_recursive():
ValueError: object too deeply nested to BSON encode
"""
def test_encode_invalid():
"""
"""Doctest.
>>> cbson.dumps(object())
Traceback (most recent call last):
...
@ -186,24 +222,26 @@ def test_encode_invalid():
TypeError: unsupported type for BSON encode
"""
def BSON(*l):
buf = ''.join(l)
return struct.pack('i', len(buf)+4) + buf
def BSON_TAG(type_id):
return struct.pack('b', type_id)
NULL_BYTE = '\x00'
if __name__ == "__main__":
if __name__ == '__main__':
cbson
import doctest
print "Running selftest:"
print 'Running selftest:'
status = doctest.testmod(sys.modules[__name__])
if status[0]:
print "*** %s tests of %d failed." % status
print '*** %s tests of %d failed.' % status
sys.exit(1)
else:
print "--- %s tests passed." % status[1]
print '--- %s tests passed.' % status[1]
sys.exit(0)

Просмотреть файл

@ -12,15 +12,18 @@ import cbson
# crash the decoder, before we fixed the issues.
# Only base64 encoded so they aren't a million characters long and/or full
# of escape sequences.
KNOWN_BAD = ["VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAB0A",
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAADTMwAEAAAAEDQABQAAAAU"
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAAAA",
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQ2gDAAAARFMAAAAA",
]
KNOWN_BAD = [
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAB0A",
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAADTMwAEAAAAEDQABQAAAAU"
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAAAA",
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQ2gDAAAARFMAAAAA",
]
class CbsonTest(unittest.TestCase):
def test_short_string_segfaults(self):
a = cbson.dumps({"A": [1, 2, 3, 4, 5, "6", u"7", {"C": u"DS"}]})
for i in range(len(a))[1:]:
@ -40,7 +43,7 @@ class CbsonTest(unittest.TestCase):
def test_random_segfaults(self):
a = cbson.dumps({"A": [1, 2, 3, 4, 5, "6", u"7", {"C": u"DS"}]})
sys.stdout.write("\nQ: %s\n" % (binascii.hexlify(a),))
for i in range(1000):
for _ in range(1000):
l = [c for c in a]
l[random.randint(4, len(a)-1)] = chr(random.randint(0, 255))
try:
@ -54,4 +57,3 @@ class CbsonTest(unittest.TestCase):
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -4,9 +4,10 @@
# Go-style RPC client using BSON as the codec.
import bson
import hmac
import struct
import bson
try:
# use optimized cbson which has slightly different API
import cbson
@ -25,11 +26,14 @@ len_struct = struct.Struct('<i')
unpack_length = len_struct.unpack_from
len_struct_size = len_struct.size
class BsonRpcClient(gorpc.GoRpcClient):
def __init__(self, addr, timeout, user=None, password=None,
keyfile=None, certfile=None):
if bool(user) != bool(password):
raise ValueError("You must provide either both or none of user and password.")
raise ValueError(
'You must provide either both or none of user and password.')
self.addr = addr
self.user = user
self.password = password
@ -37,7 +41,8 @@ class BsonRpcClient(gorpc.GoRpcClient):
uri = 'http://%s/_bson_rpc_/auth' % self.addr
else:
uri = 'http://%s/_bson_rpc_' % self.addr
gorpc.GoRpcClient.__init__(self, uri, timeout, keyfile=keyfile, certfile=certfile)
gorpc.GoRpcClient.__init__(
self, uri, timeout, keyfile=keyfile, certfile=certfile)
def dial(self):
gorpc.GoRpcClient.dial(self)
@ -49,10 +54,11 @@ class BsonRpcClient(gorpc.GoRpcClient):
raise
def authenticate(self):
challenge = self.call('AuthenticatorCRAMMD5.GetNewChallenge', "").reply['Challenge']
challenge = self.call(
'AuthenticatorCRAMMD5.GetNewChallenge', '').reply['Challenge']
# CRAM-MD5 authentication.
proof = self.user + " " + hmac.HMAC(self.password, challenge).hexdigest()
self.call('AuthenticatorCRAMMD5.Authenticate', {"Proof": proof})
proof = self.user + ' ' + hmac.HMAC(self.password, challenge).hexdigest()
self.call('AuthenticatorCRAMMD5.Authenticate', {'Proof': proof})
def encode_request(self, req):
try:
@ -81,7 +87,7 @@ class BsonRpcClient(gorpc.GoRpcClient):
# decode the payload length and see if we have enough
body_len = unpack_length(data, header_len)[0]
if data_len < header_len + body_len:
return None, header_len + body_len - data_len
return None, header_len + body_len - data_len
# we have enough data, decode it all
try:

Просмотреть файл

@ -9,16 +9,18 @@
import errno
import select
import ssl
import socket
import ssl
import time
import urlparse
_lastStreamResponseError = 'EOS'
class GoRpcError(Exception):
pass
class TimeoutError(GoRpcError):
pass
@ -40,8 +42,8 @@ def make_header(method, sequence_id):
class GoRpcRequest(object):
header = None # standard fields that route the request on the server side
body = None # the actual request object - usually a dictionary
header = None # standard fields that route the request on the server side
body = None # the actual request object - usually a dictionary
def __init__(self, header, args):
self.header = header
@ -58,7 +60,7 @@ class GoRpcResponse(object):
# 'Seq': sequence_id,
# 'Error': error_string}
header = None
reply = None # the decoded object - usually a dictionary
reply = None # the decoded object - usually a dictionary
@property
def error(self):
@ -70,9 +72,11 @@ class GoRpcResponse(object):
default_read_buffer_size = 8192
# A single socket wrapper to handle request/response conversation for this
# protocol. Internal, use GoRpcClient instead.
class _GoRpcConn(object):
def __init__(self, timeout):
self.conn = None
# NOTE(msolomon) since the deadlines are approximate in the code, set
@ -88,7 +92,8 @@ class _GoRpcConn(object):
conip = socket.gethostbyname(conhost)
except NameError:
conip = socket.getaddrinfo(conhost, None)[0][4][0]
self.conn = socket.create_connection((conip, int(conport)), self.socket_timeout)
self.conn = socket.create_connection(
(conip, int(conport)), self.socket_timeout)
if parts.scheme == 'https':
self.conn = ssl.wrap_socket(self.conn, keyfile=keyfile, certfile=certfile)
self.conn.sendall('CONNECT %s HTTP/1.0\n\n' % parts.path)
@ -101,7 +106,9 @@ class _GoRpcConn(object):
continue
raise
if not d:
raise GoRpcError('Unexpected EOF in handshake to %s:%s %s' % (str(conip), str(conport), parts.path))
raise GoRpcError(
'Unexpected EOF in handshake to %s:%s %s' %
(str(conip), str(conport), parts.path))
data += d
if '\n\n' in data:
return
@ -159,6 +166,7 @@ class _GoRpcConn(object):
class GoRpcClient(object):
def __init__(self, uri, timeout, certfile=None, keyfile=None):
self.uri = uri
self.timeout = timeout

Просмотреть файл

@ -4,7 +4,9 @@
from vtctl import vtctl_client, gorpc_vtctl_client
def init():
vtctl_client.register_conn_class('gorpc', gorpc_vtctl_client.GoRpcVtctlClient)
vtctl_client.register_conn_class(
'gorpc', gorpc_vtctl_client.GoRpcVtctlClient)
init()

Просмотреть файл

@ -7,9 +7,9 @@
import datetime
from urlparse import urlparse
import vtctl_client
from vtproto import vtctldata_pb2
from vtproto import vtctlservice_pb2
import vtctl_client
class GRPCVtctlClient(vtctl_client.VtctlClient):

Просмотреть файл

@ -5,13 +5,13 @@
# PEP 249 compliant db api for Vitess
apilevel = '2.0'
threadsafety = 0 # Threads may not share the module because multi_client is not thread safe
# Threads may not share the module because multi_client is not thread safe.
threadsafety = 0
paramstyle = 'named'
from vtdb.dbexceptions import *
from vtdb.times import Date, Time, Timestamp, DateFromTicks, TimeFromTicks, TimestampFromTicks
from vtdb.field_types import Binary
from vtdb.field_types import STRING, BINARY, NUMBER, DATETIME, ROWID
from vtdb.vtgatev3 import *
from vtdb.cursorv3 import *
from vtdb.dbexceptions import *
from vtdb.field_types import STRING, BINARY, NUMBER, DATETIME, ROWID
from vtdb.times import Date, Time, Timestamp, DateFromTicks, TimeFromTicks, TimestampFromTicks
from vtdb.vtgatev2 import *
from vtdb.vtgatev3 import *

Просмотреть файл

@ -4,6 +4,7 @@
from vtdb import dbexceptions
class BaseCursor(object):
arraysize = 1
lastrowid = None
@ -39,7 +40,8 @@ class BaseCursor(object):
self.connection.rollback()
return
self.results, self.rowcount, self.lastrowid, self.description = self.connection._execute(sql, bind_variables, **kargs)
self.results, self.rowcount, self.lastrowid, self.description = (
self.connection._execute(sql, bind_variables, **kargs))
self.index = 0
return self.rowcount
@ -97,13 +99,16 @@ class BaseCursor(object):
raise StopIteration
return val
# A simple cursor intended for attaching to a single tablet server.
class TabletCursor(BaseCursor):
def execute(self, sql, bind_variables=None):
return self._execute(sql, bind_variables)
class BatchCursor(BaseCursor):
def __init__(self, connection):
# rowset is [(results, rowcount, lastrowid, fields),]
self.rowsets = None
@ -125,12 +130,14 @@ class BatchCursor(BaseCursor):
# just used for batch items
class BatchQueryItem(object):
def __init__(self, sql, bind_variables, key, keys):
self.sql = sql
self.bind_variables = bind_variables
self.key = key
self.keys = keys
class StreamCursor(object):
arraysize = 1
conversions = None
@ -149,7 +156,8 @@ class StreamCursor(object):
# for instance, a key value for shard mapping
def execute(self, sql, bind_variables, **kargs):
self.description = None
x, y, z, self.description = self.connection._stream_execute(sql, bind_variables, **kargs)
x, y, z, self.description = self.connection._stream_execute(
sql, bind_variables, **kargs)
self.index = 0
return 0

Просмотреть файл

@ -49,10 +49,9 @@ class Cursor(object):
self.rollback()
return
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(
sql,
bind_variables,
self.tablet_type)
self.results, self.rowcount, self.lastrowid, self.description = (
self._conn._execute(
sql, bind_variables, self.tablet_type))
self.index = 0
return self.rowcount
@ -121,7 +120,7 @@ class StreamCursor(Cursor):
def execute(self, sql, bind_variables, **kargs):
self.description = None
x, y, z, self.description = self._conn._stream_execute(
_, _, _, self.description = self._conn._stream_execute(
sql,
bind_variables,
self.tablet_type)
@ -145,7 +144,7 @@ class StreamCursor(Cursor):
if self.fetchmany_done:
self.fetchmany_done = False
return result
for i in xrange(size):
for _ in xrange(size):
row = self.fetchone()
if row is None:
self.fetchmany_done = True

Просмотреть файл

@ -15,7 +15,6 @@ management.
# Use of this source code is governed by a BSD-style license that can
# be found in the LICENSE file.
import contextlib
import functools
from vtdb import dbexceptions
@ -24,13 +23,13 @@ from vtdb import vtdb_logger
from vtdb import vtgatev2
#TODO: verify that these values make sense.
# TODO: verify that these values make sense.
DEFAULT_CONNECTION_TIMEOUT = 5.0
__app_read_only_mode_method = lambda:False
__app_read_only_mode_method = lambda: False
__vtgate_connect_method = vtgatev2.connect
#TODO: perhaps make vtgate addrs also a registeration mechanism ?
#TODO: add mechansim to refresh vtgate addrs.
# TODO: perhaps make vtgate addrs also a registration mechanism?
# TODO: add mechanism to refresh vtgate addrs.
class DatabaseContext(object):
@ -44,18 +43,20 @@ class DatabaseContext(object):
Attributes:
lag_tolerant_mode: This directs all replica traffic to batch replicas.
This is done for applications that have a OLAP workload and also higher tolerance
for replication lag.
This is done for applications that have an OLAP workload and also higher
tolerance for replication lag.
vtgate_addrs: vtgate server endpoints
master_access_disabled: Disallow master access for application running in non-master
capable cells.
master_access_disabled: Disallow master access for application running
in non-master capable cells.
event_logger: Logs events and errors of note. Defaults to vtdb_logger.
transaction_stack_depth: This allows nesting of transactions and makes
commit rpc to VTGate when the outer-most commits.
vtgate_connection: Connection to VTGate.
"""
def __init__(self, vtgate_addrs=None, lag_tolerant_mode=False, master_access_disabled=False):
def __init__(
self, vtgate_addrs=None, lag_tolerant_mode=False,
master_access_disabled=False):
self.vtgate_addrs = vtgate_addrs
self.lag_tolerant_mode = lag_tolerant_mode
self.master_access_disabled = master_access_disabled
@ -83,14 +84,20 @@ class DatabaseContext(object):
Transactions and some of the consistency guarantees rely on vtgate
connections being sticky hence this class caches the connection.
Returns:
A vtgate_connection.
"""
if self.vtgate_connection is not None and not self.vtgate_connection.is_closed():
if (self.vtgate_connection is not None and
not self.vtgate_connection.is_closed()):
return self.vtgate_connection
#TODO: the connect method needs to be extended to include query n txn timeouts as well
#FIXME: what is the best way of passing other params ?
# TODO: the connect method needs to be extended to include query
# and txn timeouts as well
# FIXME: what is the best way of passing other params ?
connect_method = get_vtgate_connect_method()
self.vtgate_connection = connect_method(self.vtgate_addrs, self.connection_timeout)
self.vtgate_connection = connect_method(
self.vtgate_addrs, self.connection_timeout)
return self.vtgate_connection
def degrade_master_read_to_replica(self):
@ -174,19 +181,31 @@ class DBOperationBase(object):
dc: database context object.
writable: Indicates whether this is part of write transaction.
"""
def __init__(self, db_context):
self.dc = db_context
self.writable = False
def get_cursor(self, **cursor_kargs):
"""This returns the create_cursor method of DatabaseContext with
the writable attribute from the instance of DBOperationBase's
derived classes."""
return functools.partial(self.dc.create_cursor, self.writable, **cursor_kargs)
"""This returns the create_cursor method of DatabaseContext.
DatabaseContext has the writable attribute from the instance of
DBOperationBase's derived classes.
Args:
**cursor_kargs: Arguments to be passed to create_cursor.
Returns:
The create_cursor method with self.writable as first argument and
**cursor_kargs passed through.
"""
return functools.partial(
self.dc.create_cursor, self.writable, **cursor_kargs)
class ReadFromMaster(DBOperationBase):
"""Context Manager for reading from master."""
def __enter__(self):
self.dc.read_from_master_setup()
return self
@ -203,6 +222,7 @@ class ReadFromMaster(DBOperationBase):
class ReadFromReplica(DBOperationBase):
"""Context Manager for reading from lag-sensitive or lag-tolerant replica."""
def __enter__(self):
self.dc.read_from_replica_setup()
return self
@ -219,6 +239,7 @@ class ReadFromReplica(DBOperationBase):
class WriteTransaction(DBOperationBase):
"""Context Manager for write transactions."""
def __enter__(self):
self.writable = True
self.dc.write_transaction_setup()
@ -293,6 +314,7 @@ def register_create_vtgate_connection_method(connect_method):
global __vtgate_connect_method
__vtgate_connect_method = connect_method
def get_vtgate_connect_method():
"""Returns the vtgate connection creation method."""
global __vtgate_connect_method
@ -301,6 +323,7 @@ def get_vtgate_connect_method():
# The global object is for legacy application.
__database_context = None
def open_context(*pargs, **kargs):
"""Returns the existing global database context or creates a new one."""
global __database_context

Просмотреть файл

@ -1,26 +1,24 @@
"""Module containing the base class for database classes and decorator for db method.
"""Tthe base class for database classes and decorator for db method.
The base class DBObjectBase is the base class for all other database base classes.
It has methods for common database operations like select, insert, update and delete.
This module also contains the definition for ShardRouting which is used for determining
the routing of a query during cursor creation.
The module also has the db_class_method decorator and db_wrapper method which are
used for cursor creation and calling the database method.
The base class DBObjectBase is the base class for all other database
base classes. It has methods for common database operations like
select, insert, update and delete. This module also contains the
definition for ShardRouting which is used for determining the routing
of a query during cursor creation. The module also has the
db_class_method decorator and db_wrapper method which are used for
cursor creation and calling the database method.
"""
import functools
import struct
from vtdb import database_context
from vtdb import dbexceptions
from vtdb import db_validator
from vtdb import keyrange
from vtdb import keyrange_constants
from vtdb import shard_constants
from vtdb import dbexceptions
from vtdb import sql_builder
from vtdb import vtgate_cursor
class __EmptyBindVariables(frozenset):
pass
pass
EmptyBindVariables = __EmptyBindVariables()
@ -35,6 +33,7 @@ class ShardRouting(object):
entity_id_sharding_key_map: this map is used for in clause queries.
shard_name: this is used to route queries for custom sharded keyspaces.
"""
def __init__(self, keyspace):
# keyspace of the table.
self.keyspace = keyspace
@ -51,9 +50,9 @@ def _is_iterable_container(x):
return hasattr(x, '__iter__')
INSERT_KW = "insert"
UPDATE_KW = "update"
DELETE_KW = "delete"
INSERT_KW = 'insert'
UPDATE_KW = 'update'
DELETE_KW = 'delete'
def is_dml(sql):
@ -103,8 +102,7 @@ def create_cursor_from_old_cursor(old_cursor, table_class):
def create_stream_cursor_from_cursor(original_cursor):
"""
This method creates streaming cursor from a regular cursor.
"""Creates streaming cursor from a regular cursor.
Args:
original_cursor: Cursor of VTGateCursor type
@ -114,7 +112,7 @@ def create_stream_cursor_from_cursor(original_cursor):
"""
if not isinstance(original_cursor, vtgate_cursor.VTGateCursor):
raise dbexceptions.ProgrammingError(
"Original cursor should be of VTGateCursor type.")
'Original cursor should be of VTGateCursor type.')
stream_cursor = vtgate_cursor.StreamVTGateCursor(
original_cursor._conn, original_cursor.keyspace,
original_cursor.tablet_type,
@ -130,13 +128,14 @@ def create_batch_cursor_from_cursor(original_cursor, writable=False):
Args:
original_cursor: Cursor of VTGateCursor type
writable: Bool flag.
Returns:
Returns BatchVTGateCursor that has same attributes as original_cursor.
"""
if not isinstance(original_cursor, vtgate_cursor.VTGateCursor):
raise dbexceptions.ProgrammingError(
"Original cursor should be of VTGateCursor type.")
'Original cursor should be of VTGateCursor type.')
batch_cursor = vtgate_cursor.BatchVTGateCursor(
original_cursor._conn,
original_cursor.tablet_type,
@ -145,8 +144,7 @@ def create_batch_cursor_from_cursor(original_cursor, writable=False):
def db_wrapper(method, **decorator_kwargs):
"""Decorator that is used to create the appropriate cursor
for the table and call the database method with it.
"""Decorator to call the database method with the appropriate cursor.
Args:
method: Method to decorate.
@ -158,7 +156,7 @@ def db_wrapper(method, **decorator_kwargs):
@functools.wraps(method)
def _db_wrapper(*pargs, **kwargs):
table_class = pargs[0]
write_method = kwargs.get("write_method", False)
write_method = kwargs.get('write_method', False)
if not issubclass(table_class, DBObjectBase):
raise dbexceptions.ProgrammingError(
"table class '%s' is not inherited from DBObjectBase" % table_class)
@ -171,10 +169,10 @@ def db_wrapper(method, **decorator_kwargs):
if write_method:
if not cursor.is_writable():
raise dbexceptions.ProgrammingError(
"Executing dmls on a non-writable cursor is not allowed.")
'Executing dmls on a non-writable cursor is not allowed.')
if table_class.is_mysql_view:
raise dbexceptions.ProgrammingError(
"writes disabled on view", table_class)
raise dbexceptions.ProgrammingError(
'writes disabled on view', table_class)
if pargs[2:]:
return method(table_class, cursor, *pargs[2:], **kwargs)
@ -185,7 +183,7 @@ def db_wrapper(method, **decorator_kwargs):
def write_db_class_method(*pargs, **kwargs):
"""Used for DML methods. Calls db_class_method."""
kwargs["write_method"] = True
kwargs['write_method'] = True
return db_class_method(*pargs, **kwargs)
@ -220,11 +218,9 @@ class DBObjectBase(object):
is_mysql_view = False
utf8_columns = None
@classmethod
def create_shard_routing(class_, *pargs, **kwargs):
"""This method is used to create ShardRouting object which is
used for determining routing attributes for the vtgate cursor.
"""Create ShardRouting that determines vtgate cursor routing attributes.
Returns:
ShardRouting object.
@ -232,24 +228,25 @@ class DBObjectBase(object):
raise NotImplementedError
@classmethod
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kwargs):
"""This creates the VTGateCursor object which is used to make
all the rpc calls to VTGate.
def create_vtgate_cursor(
class_, vtgate_conn, tablet_type, is_dml, **cursor_kwargs):
"""Creates the VTGateCursor which is used to make RPCs to VTGate.
Args:
vtgate_conn: connection to vtgate.
tablet_type: tablet type to connect to.
is_dml: Makes the cursor writable, enforces appropriate constraints.
vtgate_conn: connection to vtgate.
tablet_type: tablet type to connect to.
is_dml: Makes the cursor writable, enforces appropriate constraints.
**cursor_kwargs: More args.
Returns:
VTGateCursor for the query.
VTGateCursor for the query.
"""
raise NotImplementedError
@classmethod
def _validate_column_value_pairs_for_write(class_, **column_values):
invalid_columns = db_validator.invalid_utf8_columns(class_.utf8_columns or [],
column_values)
invalid_columns = db_validator.invalid_utf8_columns(
class_.utf8_columns or [], column_values)
if invalid_columns:
exc = InvalidUtf8DbWrite(class_.table_name, invalid_columns)
raise exc
@ -296,23 +293,24 @@ class DBObjectBase(object):
def create_select_query(class_, where_column_value_pairs, columns_list=None,
order_by=None, group_by=None, limit=None):
if class_.columns_list is None:
raise dbexceptions.ProgrammingError("DB class should define columns_list")
raise dbexceptions.ProgrammingError('DB class should define columns_list')
if columns_list is None:
columns_list = class_.columns_list
query, bind_vars = sql_builder.select_by_columns_query(columns_list,
class_.table_name,
where_column_value_pairs,
order_by=order_by,
group_by=group_by,
limit=limit)
query, bind_vars = sql_builder.select_by_columns_query(
columns_list,
class_.table_name,
where_column_value_pairs,
order_by=order_by,
group_by=group_by,
limit=limit)
return query, bind_vars
@write_db_class_method
def insert(class_, cursor, **bind_vars):
if class_.columns_list is None:
raise dbexceptions.ProgrammingError("DB class should define columns_list")
raise dbexceptions.ProgrammingError('DB class should define columns_list')
query, bind_vars = class_.create_insert_query(**bind_vars)
cursor.execute(query, bind_vars)
@ -330,19 +328,21 @@ class DBObjectBase(object):
@write_db_class_method
def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None):
if not where_column_value_pairs:
raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")
raise dbexceptions.ProgrammingError(
'deleting the whole table is not allowed')
query, bind_vars = class_.create_delete_query(where_column_value_pairs,
limit=limit)
cursor.execute(query, bind_vars)
if cursor.rowcount == 0:
raise dbexceptions.DatabaseError("DB Row not found")
raise dbexceptions.DatabaseError('DB Row not found')
return cursor.rowcount
@db_class_method
def select_by_columns_streaming(class_, cursor, where_column_value_pairs,
columns_list=None, order_by=None, group_by=None,
limit=None, fetch_size=100):
def select_by_columns_streaming(
class_, cursor, where_column_value_pairs,
columns_list=None, order_by=None, group_by=None,
limit=None, fetch_size=100):
query, bind_vars = class_.create_select_query(where_column_value_pairs,
columns_list=columns_list,
order_by=order_by,
@ -382,7 +382,7 @@ class DBObjectBase(object):
@db_class_method
def get_min(class_, cursor):
if class_.id_column_name is None:
raise dbexceptions.ProgrammingError("id_column_name not set.")
raise dbexceptions.ProgrammingError('id_column_name not set.')
query, bind_vars = sql_builder.build_aggregate_query(
class_.table_name, class_.id_column_name, is_asc=True)
@ -392,7 +392,7 @@ class DBObjectBase(object):
@db_class_method
def get_max(class_, cursor):
if class_.id_column_name is None:
raise dbexceptions.ProgrammingError("id_column_name not set.")
raise dbexceptions.ProgrammingError('id_column_name not set.')
query, bind_vars = sql_builder.build_aggregate_query(
class_.table_name, class_.id_column_name, is_asc=False)

Просмотреть файл

@ -13,9 +13,10 @@ from vtdb import vtgate_cursor
class DBObjectCustomSharded(db_object.DBObjectBase):
"""Base class for custom-sharded db classes.
This class is intended to support a custom sharding scheme, where the user
controls the routing of their queries by passing in the shard_name
explicitly.This provides helper methods for common database access operations.
This class is intended to support a custom sharding scheme, where
the user controls the routing of their queries by passing in the
shard_name explicitly. This provides helper methods for common
database access operations.
"""
keyspace = None
sharding = shard_constants.CUSTOM_SHARDED
@ -25,19 +26,21 @@ class DBObjectCustomSharded(db_object.DBObjectBase):
@classmethod
def create_shard_routing(class_, *pargs, **kwargs):
routing = db_object.ShardRouting(keyspace)
routing.shard_name = kargs.get('shard_name')
routing = db_object.ShardRouting(class_.keyspace)
routing.shard_name = kwargs.get('shard_name')
if routing.shard_name is None:
dbexceptions.InternalError("For custom sharding, shard_name cannot be None.")
dbexceptions.InternalError(
'For custom sharding, shard_name cannot be None.')
if (_is_iterable_container(routing.shard_name)
and is_dml):
raise dbexceptions.InternalError(
"Writes are not allowed on multiple shards.")
'Writes are not allowed on multiple shards.')
return routing
@classmethod
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
def create_vtgate_cursor(
class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
# FIXME:extend VTGateCursor's api to accept shard_names
# and allow queries based on that.
routing = class_.create_shard_routing(**cursor_kargs)

Просмотреть файл

@ -5,26 +5,17 @@ relevant methods. LookupDBObject inherits from DBObjectUnsharded and
extends the functionality for getting, creating, updating and deleting
the lookup relationship.
"""
import functools
import struct
from vtdb import db_object
from vtdb import db_object_unsharded
from vtdb import dbexceptions
from vtdb import keyrange
from vtdb import keyrange_constants
from vtdb import shard_constants
from vtdb import vtgate_cursor
class LookupDBObject(db_object_unsharded.DBObjectUnsharded):
"""This is an example implementation of lookup class where it is stored
in unsharded db.
"""
"""A example implementation of a lookup class stored in an unsharded db."""
@classmethod
def get(class_, cursor, entity_id_column, entity_id):
where_column_value_pairs = [(entity_id_column, entity_id),]
rows = class_.select_by_columns(cursor, where_column_value_pairs)
rows = class_.select_by_columns(cursor, where_column_value_pairs)
return [row.__dict__ for row in rows]
@classmethod
@ -35,7 +26,7 @@ class LookupDBObject(db_object_unsharded.DBObjectUnsharded):
def update(class_, cursor, sharding_key_column_name, sharding_key,
entity_id_column, new_entity_id):
where_column_value_pairs = [(sharding_key_column_name, sharding_key),]
update_column_value_pairs = [(entity_id_column,new_entity_id),]
update_column_value_pairs = [(entity_id_column, new_entity_id),]
return class_.update_columns(cursor, where_column_value_pairs,
update_column_value_pairs)

Просмотреть файл

@ -1,12 +1,14 @@
"""Module containing base classes for range-sharded database objects.
There are two base classes for tables that live in range-sharded keyspace -
1. DBObjectRangeSharded - This should be used for tables that only reference lookup entities
but don't create or manage them. Please see examples in test/clientlib_tests/db_class_sharded.py.
2. DBObjectEntityRangeSharded - This inherits from DBObjectRangeSharded and is used for tables
and also create new lookup relationships.
This module also contains helper methods for cursor creation for accessing lookup tables
and methods for dml and select for the above mentioned base classes.
1. DBObjectRangeSharded - This should be used for tables that only reference
lookup entities but don't create or manage them. Please see examples in
test/clientlib_tests/db_class_sharded.py.
2. DBObjectEntityRangeSharded - This inherits from DBObjectRangeSharded and
is used for tables and also create new lookup relationships.
This module also contains helper methods for cursor creation for
accessing lookup tables and methods for dml and select for the above
mentioned base classes.
"""
import functools
import struct
@ -14,7 +16,6 @@ import struct
from vtdb import db_object
from vtdb import dbexceptions
from vtdb import keyrange
from vtdb import keyrange_constants
from vtdb import shard_constants
from vtdb import sql_builder
from vtdb import vtgate_cursor
@ -29,7 +30,7 @@ pack_keyspace_id = struct.Struct('!Q').pack
# This unpacks the keyspace_id so that it can be used
# in bind variables.
def unpack_keyspace_id(kid):
return struct.Struct('!Q').unpack(kid)[0]
return struct.Struct('!Q').unpack(kid)[0]
class DBObjectRangeSharded(db_object.DBObjectBase):
@ -51,7 +52,7 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
# List of columns on the database table. This is used in query construction.
columns_list = None
#FIXME: is this needed ?
# FIXME: is this needed ?
id_column_name = None
# sharding_key_column_name defines column name for sharding key for this
@ -70,7 +71,7 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
column_lookup_name_map = None
@classmethod
def create_shard_routing(class_, *pargs, **kargs):
def create_shard_routing(class_, *pargs, **kargs):
"""This creates the ShardRouting object based on the kargs.
This prunes the routing kargs so as not to interfere with the
actual database method.
@ -79,9 +80,10 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
*pargs: Positional arguments
**kargs: Routing key-value params. These are used to determine routing.
There are two mutually exclusive mechanisms to indicate routing.
1. entity_id_map {"entity_id_column": entity_id_value} where entity_id_column
could be the sharding key or a lookup based entity column of this table. This
helps determine the keyspace_ids for the cursor.
1. entity_id_map {"entity_id_column": entity_id_value} where
entity_id_column could be the sharding key or a lookup based entity
column of this table. This helps determine the keyspace_ids for the
cursor.
2. keyrange - This helps determine the keyrange for the cursor.
Returns:
@ -91,10 +93,10 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
routing = db_object.ShardRouting(class_.keyspace)
entity_id_map = None
entity_id_map = kargs.get("entity_id_map", None)
entity_id_map = kargs.get('entity_id_map', None)
if entity_id_map is None:
kr = None
key_range = kargs.get("keyrange", None)
key_range = kargs.get('keyrange', None)
if isinstance(key_range, keyrange.KeyRange):
kr = key_range
else:
@ -106,36 +108,42 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
# entity_id_map is not None
if len(entity_id_map) != 1:
dbexceptions.ProgrammingError("Invalid entity_id_map '%s'" % entity_id_map)
dbexceptions.ProgrammingError(
"Invalid entity_id_map '%s'" % entity_id_map)
entity_id_col = entity_id_map.keys()[0]
entity_id = entity_id_map[entity_id_col]
#TODO: the current api means that if a table doesn't have the sharding key column name
# then it cannot pass in sharding key for routing purposes. Will this cause
# extra load on lookup db/cache ? This is cleaner from a design perspective.
#TODO: the current api means that if a table doesn't have the
# sharding key column name then it cannot pass in sharding key for
# routing purposes. Will this cause extra load on lookup db/cache?
# This is cleaner from a design perspective.
if entity_id_col == class_.sharding_key_column_name:
# Routing using sharding key.
routing.sharding_key = entity_id
if not class_.is_sharding_key_valid(routing.sharding_key):
raise dbexceptions.InternalError("Invalid sharding_key %s" % routing.sharding_key)
raise dbexceptions.InternalError(
'Invalid sharding_key %s' % routing.sharding_key)
else:
# Routing using lookup based entity.
routing.entity_column_name = entity_id_col
routing.entity_id_sharding_key_map = class_.lookup_sharding_key_from_entity_id(
lookup_cursor_method, entity_id_col, entity_id)
routing.entity_id_sharding_key_map = (
class_.lookup_sharding_key_from_entity_id(
lookup_cursor_method, entity_id_col, entity_id))
return routing
@classmethod
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
def create_vtgate_cursor(
class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
cursor_method = functools.partial(db_object.create_cursor_from_params,
vtgate_conn, tablet_type, False)
routing = class_.create_shard_routing(cursor_method, **cursor_kargs)
if is_dml:
if routing.sharding_key is None or db_object._is_iterable_container(routing.sharding_key):
if (routing.sharding_key is None or
db_object._is_iterable_container(routing.sharding_key)):
dbexceptions.InternalError(
"Writes require unique sharding_key")
'Writes require unique sharding_key')
keyspace_ids = None
keyranges = None
@ -151,7 +159,8 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
elif routing.entity_id_sharding_key_map is not None:
keyspace_ids = []
for sharding_key in routing.entity_id_sharding_key_map.values():
keyspace_ids.append(pack_keyspace_id(class_.sharding_key_to_keyspace_id(sharding_key)))
keyspace_ids.append(
pack_keyspace_id(class_.sharding_key_to_keyspace_id(sharding_key)))
elif routing.keyrange:
keyranges = [routing.keyrange,]
@ -174,11 +183,14 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
return class_.column_lookup_name_map.get(column_name, column_name)
@classmethod
def lookup_sharding_key_from_entity_id(class_, cursor_method, entity_id_column, entity_id):
def lookup_sharding_key_from_entity_id(
class_, cursor_method, entity_id_column, entity_id):
"""This method is used to map any entity id to sharding key.
Args:
entity_id_column: Non-sharding key indexes that can be used for query routing.
cursor_method: Cursor method.
entity_id_column: Non-sharding key indexes that can be used for query
routing.
entity_id: entity id value.
Returns:
@ -189,20 +201,22 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
rows = lookup_class.get(cursor_method, entity_lookup_column, entity_id)
entity_id_sharding_key_map = {}
if len(rows) == 0:
#return entity_id_sharding_key_map
raise dbexceptions.DatabaseError("LookupRow not found")
if not rows:
# return entity_id_sharding_key_map
raise dbexceptions.DatabaseError('LookupRow not found')
if class_.sharding_key_column_name is not None:
sk_lookup_column = class_.get_lookup_column_name(class_.sharding_key_column_name)
sk_lookup_column = class_.get_lookup_column_name(
class_.sharding_key_column_name)
else:
# This is needed since the table may not have a sharding key column name
# but the lookup map will have it.
lookup_column_names = rows[0].keys()
if len(lookup_column_names) != 2:
raise dbexceptions.ProgrammingError(
"lookup table has more than two columns.")
sk_lookup_column = list(set(lookup_column_names) - set(list(entity_lookup_column)))[0]
'lookup table has more than two columns.')
sk_lookup_column = list(
set(lookup_column_names) - set(list(entity_lookup_column)))[0]
for row in rows:
en_id = row[entity_lookup_column]
sk = row[sk_lookup_column]
@ -212,8 +226,8 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
@db_object.db_class_method
def select_by_ids(class_, cursor, where_column_value_pairs,
columns_list=None, order_by=None, group_by=None,
limit=None, **kwargs):
columns_list=None, order_by=None, group_by=None,
limit=None, **kwargs):
"""This method is used to perform in-clause queries.
Such queries can cause vtgate to scatter over multiple shards.
@ -238,17 +252,20 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
entity_col_name = class_.sharding_key_column_name
if db_object._is_iterable_container(cursor.routing.sharding_key):
for sk in list(cursor.routing.sharding_key):
entity_id_keyspace_id_map[sk] = pack_keyspace_id(class_.sharding_key_to_keyspace_id(sk))
entity_id_keyspace_id_map[sk] = pack_keyspace_id(
class_.sharding_key_to_keyspace_id(sk))
else:
sk = cursor.routing.sharding_key
entity_id_keyspace_id_map[sk] = pack_keyspace_id(class_.sharding_key_to_keyspace_id(sk))
entity_id_keyspace_id_map[sk] = pack_keyspace_id(
class_.sharding_key_to_keyspace_id(sk))
elif cursor.routing.entity_id_sharding_key_map is not None:
# If the in-clause is based on entity column
entity_col_name = cursor.routing.entity_column_name
for en_id, sk in cursor.routing.entity_id_sharding_key_map.iteritems():
entity_id_keyspace_id_map[en_id] = pack_keyspace_id(class_.sharding_key_to_keyspace_id(sk))
entity_id_keyspace_id_map[en_id] = pack_keyspace_id(
class_.sharding_key_to_keyspace_id(sk))
else:
dbexceptions.ProgrammingError("Invalid routing method used.")
dbexceptions.ProgrammingError('Invalid routing method used.')
# cursor.routing.entity_column_name is set while creating shard routing.
rowcount = cursor.execute_entity_ids(query, bind_vars,
@ -284,7 +301,7 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
@db_object.db_class_method
def insert(class_, cursor, **bind_vars):
if class_.columns_list is None:
raise dbexceptions.ProgrammingError("DB class should define columns_list")
raise dbexceptions.ProgrammingError('DB class should define columns_list')
keyspace_id = bind_vars.get('keyspace_id', None)
if keyspace_id is None:
@ -331,35 +348,36 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None):
if not where_column_value_pairs:
raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")
raise dbexceptions.ProgrammingError(
'deleting the whole table is not allowed')
where_column_value_pairs = class_._add_keyspace_id(
unpack_keyspace_id(cursor.keyspace_ids[0]), where_column_value_pairs)
query, bind_vars = sql_builder.delete_by_columns_query(class_.table_name,
where_column_value_pairs,
limit=limit)
query, bind_vars = sql_builder.delete_by_columns_query(
class_.table_name, where_column_value_pairs, limit=limit)
cursor.execute(query, bind_vars)
if cursor.rowcount == 0:
raise dbexceptions.DatabaseError("DB Row not found")
raise dbexceptions.DatabaseError('DB Row not found')
return cursor.rowcount
class DBObjectEntityRangeSharded(DBObjectRangeSharded):
"""Base class for sharded tables that also needs to create and manage lookup
entities.
"""Base class for sharded tables that create and manage lookup entities.
This provides default implementation of routing helper methods, cursor
creation and common database access operations.
"""
@classmethod
def get_insert_id_from_lookup(class_, cursor_method, entity_id_col, **bind_vars):
def get_insert_id_from_lookup(
class_, cursor_method, entity_id_col, **bind_vars):
"""This method is used to map any entity id to sharding key.
Args:
entity_id_column: Non-sharding key indexes that can be used for query routing.
entity_id_column: Non-sharding key indexes that can be used for query
routing.
entity_id: entity id value.
Returns:
@ -376,18 +394,19 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
@classmethod
def delete_sharding_key_entity_id_lookup(class_, cursor_method,
sharding_key):
sharding_key_lookup_column = class_.get_lookup_column_name(class_.sharding_key_column_name)
sharding_key_lookup_column = class_.get_lookup_column_name(
class_.sharding_key_column_name)
for lookup_class in class_.entity_id_lookup_map.values():
lookup_class.delete(cursor_method,
sharding_key_lookup_column,
sharding_key)
@classmethod
def update_sharding_key_entity_id_lookup(class_, cursor_method,
sharding_key, entity_id_column,
new_entity_id):
sharding_key_lookup_column = class_.get_lookup_column_name(class_.sharding_key_column_name)
sharding_key_lookup_column = class_.get_lookup_column_name(
class_.sharding_key_column_name)
entity_id_lookup_column = class_.get_lookup_column_name(entity_id_column)
lookup_class = class_.entity_id_lookup_map[entity_id_column]
return lookup_class.update(cursor_method,
@ -400,28 +419,29 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
@db_object.write_db_class_method
def insert_primary(class_, cursor, **bind_vars):
if class_.columns_list is None:
raise dbexceptions.ProgrammingError("DB class should define columns_list")
raise dbexceptions.ProgrammingError('DB class should define columns_list')
query, bind_vars = class_.create_insert_query(**bind_vars)
cursor.execute(query, bind_vars)
return cursor.lastrowid
@classmethod
def insert(class_, cursor_method, **bind_vars):
""" This method creates the lookup relationship as well as the insert
in the primary table. The creation of the lookup entry also creates the
primary key for the row in the primary table.
"""Creates the lookup relationship and inserts in the primary table.
The lookup relationship is determined by class_.column_lookup_name_map and the bind
variables passed in. There are two types of entities -
1. Table for which the entity that is also the primary sharding key for this keyspace.
2. Entity table that creates a new entity and needs to create a lookup between
that entity and sharding key.
The creation of the lookup entry also creates the primary key for
the row in the primary table.
The lookup relationship is determined by class_.column_lookup_name_map
and the bind variables passed in. There are two types of entities -
1. Table for which the entity that is also the primary sharding key for
this keyspace.
2. Entity table that creates a new entity and needs to create a lookup
between that entity and sharding key.
"""
if class_.sharding_key_column_name is None:
raise dbexceptions.ProgrammingError(
"sharding_key_column_name empty for DBObjectEntityRangeSharded")
'sharding_key_column_name empty for DBObjectEntityRangeSharded')
# Used for insert into class_.table_name
new_inserted_key = None
@ -431,7 +451,7 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
if (not class_.entity_id_lookup_map
or not isinstance(class_.entity_id_lookup_map, dict)):
raise dbexceptions.ProgrammingError(
"Invalid entity_id_lookup_map %s" % class_.entity_id_lookup_map)
'Invalid entity_id_lookup_map %s' % class_.entity_id_lookup_map)
entity_col = class_.entity_id_lookup_map.keys()[0]
# Create the lookup entry first
@ -472,7 +492,7 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
update_column_value_pairs):
sharding_key = cursor.routing.sharding_key
if sharding_key is None:
raise dbexceptions.ProgrammingError("sharding_key cannot be empty")
raise dbexceptions.ProgrammingError('sharding_key cannot be empty')
# update the primary table first.
query, bind_vars = class_.create_update_query(
@ -498,23 +518,26 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
limit=None):
sharding_key = cursor.routing.sharding_key
if sharding_key is None:
raise dbexceptions.ProgrammingError("sharding_key cannot be empty")
raise dbexceptions.ProgrammingError('sharding_key cannot be empty')
if not where_column_value_pairs:
raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")
raise dbexceptions.ProgrammingError(
'deleting the whole table is not allowed')
query, bind_vars = sql_builder.delete_by_columns_query(class_.table_name,
where_column_value_pairs,
limit=limit)
query, bind_vars = sql_builder.delete_by_columns_query(
class_.table_name,
where_column_value_pairs,
limit=limit)
cursor.execute(query, bind_vars)
if cursor.rowcount == 0:
raise dbexceptions.DatabaseError("DB Row not found")
raise dbexceptions.DatabaseError('DB Row not found')
rowcount = cursor.rowcount
#delete the lookup map.
# delete the lookup map.
lookup_cursor_method = functools.partial(
db_object.create_cursor_from_old_cursor, cursor)
class_.delete_sharding_key_entity_id_lookup(lookup_cursor_method, sharding_key)
class_.delete_sharding_key_entity_id_lookup(
lookup_cursor_method, sharding_key)
return rowcount

Просмотреть файл

@ -4,9 +4,6 @@ DBObjectUnsharded inherits from DBObjectBase, the implementation
for the common database operations is defined in DBObjectBase.
DBObjectUnsharded defines the cursor creation methods for the same.
"""
import functools
import struct
from vtdb import db_object
from vtdb import dbexceptions
from vtdb import keyrange
@ -31,11 +28,13 @@ class DBObjectUnsharded(db_object.DBObjectBase):
@classmethod
def create_shard_routing(class_, *pargs, **kwargs):
routing = db_object.ShardRouting(class_.keyspace)
routing.keyrange = keyrange.KeyRange(keyrange_constants.NON_PARTIAL_KEYRANGE)
routing.keyrange = keyrange.KeyRange(
keyrange_constants.NON_PARTIAL_KEYRANGE)
return routing
@classmethod
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
def create_vtgate_cursor(
class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
routing = class_.create_shard_routing(**cursor_kargs)
if routing.keyrange is not None:
keyranges = [routing.keyrange,]

Просмотреть файл

@ -1,10 +1,12 @@
from vtdb import dbexceptions
# A simple class to trap and re-export only variables referenced from
# the sql statement since bind dictionaries can be *very* noisy. This
# is a by-product of converting the DB-API %(name)s syntax to our
# :name syntax.
class BindVarsProxy(object):
def __init__(self, bind_vars):
self.bind_vars = bind_vars
self.accessed_keys = set()
@ -27,7 +29,7 @@ class BindVarsProxy(object):
def prepare_query_bind_vars(query, bind_vars):
bind_vars_proxy = BindVarsProxy(bind_vars)
try:
query = query % bind_vars_proxy
query %= bind_vars_proxy
except KeyError as e:
raise dbexceptions.InterfaceError(e[0], query, bind_vars)

Просмотреть файл

@ -1,41 +1,53 @@
import exceptions
class Error(exceptions.StandardError):
pass
class DatabaseError(exceptions.StandardError):
pass
class DataError(DatabaseError):
pass
class Warning(exceptions.StandardError):
pass
class InterfaceError(Error):
pass
class InternalError(DatabaseError):
pass
class OperationalError(DatabaseError):
pass
class ProgrammingError(DatabaseError):
pass
class NotSupportedError(ProgrammingError):
pass
class IntegrityError(DatabaseError):
pass
class PartialCommitError(IntegrityError):
pass
# Below errors are VT specific
# Retry means a simple and immediate reconnect to the same host/port
# will likely fix things. This is initiated by a graceful restart on
# the server side. In general this can be handled transparently
@ -51,8 +63,8 @@ class FatalError(OperationalError):
pass
# This failure is operational in the sense that we must teardown the connection to
# ensure future RPCs are handled correctly.
# This failure is operational in the sense that we must teardown the
# connection to ensure future RPCs are handled correctly.
class TimeoutError(OperationalError):
pass
@ -68,4 +80,3 @@ class RequestBacklog(DatabaseError):
# ThrottledError is raised when client exceeds allocated quota on the server
class ThrottledError(DatabaseError):
pass

Просмотреть файл

@ -32,14 +32,12 @@ VT_VAR_STRING = 253
VT_STRING = 254
VT_GEOMETRY = 255
# FIXME(msolomon) intended for MySQL emulation, but seems more dangerous
# to keep this around. This doesn't seem to even be used right now.
def Binary(x):
return array('c', x)
class DBAPITypeObject:
class DBAPITypeObject(object):
def __init__(self, *values):
self.values = values
def __cmp__(self, other):
if other in self.values:
return 0
@ -47,30 +45,34 @@ class DBAPITypeObject:
# FIXME(msolomon) why do we have these values if they aren't referenced?
STRING = DBAPITypeObject(VT_ENUM, VT_VAR_STRING, VT_STRING)
BINARY = DBAPITypeObject(VT_TINY_BLOB, VT_MEDIUM_BLOB, VT_LONG_BLOB, VT_BLOB)
NUMBER = DBAPITypeObject(VT_DECIMAL, VT_TINY, VT_SHORT, VT_LONG, VT_FLOAT, VT_DOUBLE, VT_LONGLONG, VT_INT24, VT_YEAR, VT_NEWDECIMAL)
DATETIME = DBAPITypeObject(VT_TIMESTAMP, VT_DATE, VT_TIME, VT_DATETIME, VT_NEWDATE)
ROWID = DBAPITypeObject()
STRING = DBAPITypeObject(VT_ENUM, VT_VAR_STRING, VT_STRING)
BINARY = DBAPITypeObject(VT_TINY_BLOB, VT_MEDIUM_BLOB, VT_LONG_BLOB, VT_BLOB)
NUMBER = DBAPITypeObject(
VT_DECIMAL, VT_TINY, VT_SHORT, VT_LONG, VT_FLOAT, VT_DOUBLE, VT_LONGLONG,
VT_INT24, VT_YEAR, VT_NEWDECIMAL)
DATETIME = DBAPITypeObject(
VT_TIMESTAMP, VT_DATE, VT_TIME, VT_DATETIME, VT_NEWDATE)
ROWID = DBAPITypeObject()
conversions = {
VT_DECIMAL : Decimal,
VT_TINY : int,
VT_SHORT : int,
VT_LONG : long,
VT_FLOAT : float,
VT_DOUBLE : float,
VT_TIMESTAMP : times.DateTimeOrNone,
VT_LONGLONG : long,
VT_INT24 : int,
VT_DATE : times.DateOrNone,
VT_TIME : times.TimeDeltaOrNone,
VT_DATETIME : times.DateTimeOrNone,
VT_YEAR : int,
VT_NEWDATE : times.DateOrNone,
VT_NEWDECIMAL : Decimal,
VT_DECIMAL: Decimal,
VT_TINY: int,
VT_SHORT: int,
VT_LONG: long,
VT_FLOAT: float,
VT_DOUBLE: float,
VT_TIMESTAMP: times.DateTimeOrNone,
VT_LONGLONG: long,
VT_INT24: int,
VT_DATE: times.DateOrNone,
VT_TIME: times.TimeDeltaOrNone,
VT_DATETIME: times.DateTimeOrNone,
VT_YEAR: int,
VT_NEWDATE: times.DateOrNone,
VT_NEWDECIMAL: Decimal,
}
# This is a temporary workaround till we figure out how to support
# native lists in our API.
class List(list):
@ -82,6 +84,7 @@ NoneType = type(None)
# That doesn't seem dramatically better than __sql_literal__ but it might
# be move self-documenting.
def convert_bind_vars(bind_variables):
new_vars = {}
if bind_variables is None:

Просмотреть файл

@ -1,12 +1,12 @@
# Copyright 2015, Google Inc. All rights reserved.
# Use of this source code is governed by a BSD-style license that can
# be found in the LICENSE file.
from itertools import izip
import logging
from net import gorpc
from net import bsonrpc
from net import gorpc
from vtdb import dbexceptions
from vtdb import field_types
from vtdb import update_stream
@ -26,8 +26,8 @@ def _make_row(row, conversions):
class GoRpcUpdateStreamConnection(update_stream.UpdateStreamConnection):
"""GoRpcUpdateStreamConnection is the go rpc implementation of
UpdateStreamConnection.
"""The go rpc implementation of UpdateStreamConnection.
It is registered as 'gorpc' protocol.
"""
@ -61,7 +61,7 @@ class GoRpcUpdateStreamConnection(update_stream.UpdateStreamConnection):
"""Note this implementation doesn't honor the timeout."""
try:
self.client.stream_call('UpdateStream.ServeUpdateStream',
{"Position": position})
{'Position': position})
while True:
response = self.client.stream_next()
if response is None:

Просмотреть файл

@ -3,16 +3,14 @@
# be found in the LICENSE file.
from itertools import izip
import logging
from urlparse import urlparse
from vtdb import dbexceptions
from vtdb import field_types
from vtdb import update_stream
from vtproto import binlogdata_pb2
from vtproto import binlogservice_pb2
from vtproto import replicationdata_pb2
from vtdb import field_types
from vtdb import update_stream
def _make_row(row, conversions):
converted_row = []
@ -28,8 +26,8 @@ def _make_row(row, conversions):
class GRPCUpdateStreamConnection(update_stream.UpdateStreamConnection):
"""GRPCUpdateStreamConnection is the gRPC implementation of
UpdateStreamConnection.
"""The gRPC implementation of UpdateStreamConnection.
It is registered as 'grpc' protocol.
"""
@ -50,7 +48,7 @@ class GRPCUpdateStreamConnection(update_stream.UpdateStreamConnection):
self.stub = None
def is_closed(self):
return self.stub == None
return self.stub is None
def stream_update(self, position, timeout=3600.0):
req = binlogdata_pb2.StreamUpdateRequest(position=position)
@ -72,13 +70,14 @@ class GRPCUpdateStreamConnection(update_stream.UpdateStreamConnection):
rows.append(row)
try:
yield update_stream.StreamEvent(category=int(stream_event.category),
table_name=stream_event.table_name,
fields=fields,
rows=rows,
sql=stream_event.sql,
timestamp=stream_event.timestamp,
transaction_id=stream_event.transaction_id)
yield update_stream.StreamEvent(
category=int(stream_event.category),
table_name=stream_event.table_name,
fields=fields,
rows=rows,
sql=stream_event.sql,
timestamp=stream_event.timestamp,
transaction_id=stream_event.transaction_id)
except GeneratorExit:
# if the loop is interrupted for any reason, we need to
# cancel the iterator, so we close the RPC connection,

Просмотреть файл

@ -34,17 +34,20 @@ class KeyRange(codec.BSONCoding):
else:
kr = kr.split('-')
if not isinstance(kr, tuple) and not isinstance(kr, list) or len(kr) != 2:
raise dbexceptions.ProgrammingError("keyrange must be a list or tuple or a '-' separated str %s" % kr)
raise dbexceptions.ProgrammingError(
    "keyrange must be a list or tuple or a '-' separated str %s" % kr)
self.Start = kr[0].strip().decode('hex')
self.End = kr[1].strip().decode('hex')
def __str__(self):
if self.Start == keyrange_constants.MIN_KEY and self.End == keyrange_constants.MAX_KEY:
if (self.Start == keyrange_constants.MIN_KEY and
self.End == keyrange_constants.MAX_KEY):
return keyrange_constants.NON_PARTIAL_KEYRANGE
return '%s-%s' % (self.Start.encode('hex'), self.End.encode('hex'))
def __repr__(self):
if self.Start == keyrange_constants.MIN_KEY and self.End == keyrange_constants.MAX_KEY:
if (self.Start == keyrange_constants.MIN_KEY and
self.End == keyrange_constants.MAX_KEY):
return 'KeyRange(%r)' % keyrange_constants.NON_PARTIAL_KEYRANGE
return 'KeyRange(%r-%r)' % (self.Start, self.End)
@ -56,7 +59,7 @@ class KeyRange(codec.BSONCoding):
{"Start": start, "End": end}
"""
return {"Start": self.Start, "End": self.End}
return {'Start': self.Start, 'End': self.End}
def bson_init(self, raw_values):
"""Bson initialize the object with start and end dict.
@ -64,5 +67,5 @@ class KeyRange(codec.BSONCoding):
Args:
raw_values: Dictionary of start and end values for keyrange.
"""
self.Start = raw_values["Start"]
self.End = raw_values["End"]
self.Start = raw_values['Start']
self.End = raw_values['End']

Просмотреть файл

@ -4,14 +4,14 @@
# This is the shard name for when the keyrange covers the entire space
# for unsharded database.
SHARD_ZERO = "0"
SHARD_ZERO = '0'
# Keyrange that spans the entire space, used
# for unsharded database.
NON_PARTIAL_KEYRANGE = ""
NON_PARTIAL_KEYRANGE = ''
MIN_KEY = ''
MAX_KEY = ''
KIT_UNSET = ""
KIT_UINT64 = "uint64"
KIT_BYTES = "bytes"
KIT_UNSET = ''
KIT_UINT64 = 'uint64'
KIT_BYTES = 'bytes'

Просмотреть файл

@ -10,9 +10,12 @@ from vtdb import keyrange_constants
pack_keyspace_id = struct.Struct('!Q').pack
# Represent the SrvKeyspace object from the toposerver, and provide functions
# to extract sharding information from the same.
class Keyspace(object):
"""Represent the SrvKeyspace object from the toposerver.
Provide functions to extract sharding information from the same.
"""
name = None
partitions = None
sharding_col_name = None
@ -23,8 +26,9 @@ class Keyspace(object):
def __init__(self, name, data):
self.name = name
self.partitions = data.get('Partitions', {})
self.sharding_col_name = data.get('ShardingColumnName', "")
self.sharding_col_type = data.get('ShardingColumnType', keyrange_constants.KIT_UNSET)
self.sharding_col_name = data.get('ShardingColumnName', '')
self.sharding_col_type = data.get(
'ShardingColumnType', keyrange_constants.KIT_UNSET)
self.served_from = data.get('ServedFrom', None)
def get_shards(self, db_type):
@ -51,6 +55,12 @@ class Keyspace(object):
"""Finds the shard for a keyspace_id.
WARNING: this only works for KIT_UINT64 keyspace ids.
Returns:
Shard name.
Raises:
ValueError on invalid keyspace_id.
"""
if not keyspace_id:
raise ValueError('keyspace_id is not set')
@ -64,11 +74,12 @@ class Keyspace(object):
shard['KeyRange']['Start'],
shard['KeyRange']['End']):
return shard['Name']
raise ValueError('cannot find shard for keyspace_id %s in %s' % (keyspace_id, shards))
raise ValueError(
'cannot find shard for keyspace_id %s in %s' % (keyspace_id, shards))
def _shard_contain_kid(pkid, start, end):
return start <= pkid and (end == keyrange_constants.MAX_KEY or pkid < end)
return start <= pkid and (end == keyrange_constants.MAX_KEY or pkid < end)
def read_keyspace(topo_client, keyspace_name):

Просмотреть файл

@ -4,9 +4,9 @@ Different sharding schemes govern different routing strategies
that are computed while create the correct cursor.
"""
UNSHARDED = "UNSHARDED"
RANGE_SHARDED = "RANGE"
CUSTOM_SHARDED = "CUSTOM"
UNSHARDED = 'UNSHARDED'
RANGE_SHARDED = 'RANGE'
CUSTOM_SHARDED = 'CUSTOM'
TABLET_TYPE_MASTER = 'master'
TABLET_TYPE_REPLICA = 'replica'

Просмотреть файл

@ -13,7 +13,7 @@ from vtdb import field_types
from vtdb import vtdb_logger
_errno_pattern = re.compile('\(errno (\d+)\)')
_errno_pattern = re.compile(r'\(errno (\d+)\)')
def handle_app_error(exc_args):
@ -60,10 +60,13 @@ def convert_exception(exc, *args):
return exc
# A simple, direct connection to the vttablet query server.
# This is shard-unaware and only handles the most basic communication.
# If something goes wrong, this object should be thrown away and a new one instantiated.
class TabletConnection(object):
"""A simple, direct connection to the vttablet query server.
This is shard-unaware and only handles the most basic communication.
If something goes wrong, this object should be thrown away and a new
one instantiated.
"""
transaction_id = 0
session_id = 0
_stream_fields = None
@ -71,7 +74,9 @@ class TabletConnection(object):
_stream_result = None
_stream_result_index = None
def __init__(self, addr, tablet_type, keyspace, shard, timeout, user=None, password=None, keyfile=None, certfile=None):
def __init__(
self, addr, tablet_type, keyspace, shard, timeout, user=None,
password=None, keyfile=None, certfile=None):
self.addr = addr
self.tablet_type = tablet_type
self.keyspace = keyspace
@ -82,7 +87,8 @@ class TabletConnection(object):
self.logger_object = vtdb_logger.get_logger()
def __str__(self):
return '<TabletConnection %s %s %s/%s>' % (self.addr, self.tablet_type, self.keyspace, self.shard)
return '<TabletConnection %s %s %s/%s>' % (
self.addr, self.tablet_type, self.keyspace, self.shard)
def dial(self):
try:
@ -92,11 +98,13 @@ class TabletConnection(object):
# redial will succeed. This is more a hint that you are doing
# it wrong and misunderstanding the life cycle of a
# TabletConnection.
#raise dbexceptions.ProgrammingError('attempting to reuse TabletConnection')
# raise dbexceptions.ProgrammingError(
# 'attempting to reuse TabletConnection')
self.client.dial()
params = {'Keyspace': self.keyspace, 'Shard': self.shard}
response = self.rpc_call_and_extract_error('SqlQuery.GetSessionId', params)
response = self.rpc_call_and_extract_error(
'SqlQuery.GetSessionId', params)
self.session_id = response.reply['SessionId']
except gorpc.GoRpcError as e:
raise convert_exception(e, str(self))
@ -163,11 +171,14 @@ class TabletConnection(object):
raise convert_exception(e, str(self))
def rpc_call_and_extract_error(self, method_name, request):
"""Makes an RPC call, and extracts any app error that's embedded in the reply.
"""Makes an RPC, extracts any app error that's embedded in the reply.
Args:
method_name - RPC method name, as a string, to call
request - request to send to the RPC method call
method_name: RPC method name, as a string, to call.
request: Request to send to the RPC method call.
Returns:
Response from RPC.
Raises:
gorpc.AppError if there is an app error embedded in the reply
@ -272,11 +283,13 @@ class TabletConnection(object):
reply = first_response.reply
if reply.get('Err'):
self.__drain_conn_after_streaming_app_error()
raise gorpc.AppError(reply['Err'].get('Message', 'Missing error message'))
raise gorpc.AppError(reply['Err'].get(
'Message', 'Missing error message'))
for field in reply['Fields']:
self._stream_fields.append((field['Name'], field['Type']))
self._stream_conversions.append(field_types.conversions.get(field['Type']))
self._stream_conversions.append(
field_types.conversions.get(field['Type']))
except gorpc.GoRpcError as e:
self.logger_object.log_private_data(bind_variables)
raise convert_exception(e, str(self), sql)
@ -305,11 +318,13 @@ class TabletConnection(object):
reply = first_response.reply
if reply.get('Err'):
self.__drain_conn_after_streaming_app_error()
raise gorpc.AppError(reply['Err'].get('Message', 'Missing error message'))
raise gorpc.AppError(reply['Err'].get(
'Message', 'Missing error message'))
for field in reply['Fields']:
self._stream_fields.append((field['Name'], field['Type']))
self._stream_conversions.append(field_types.conversions.get(field['Type']))
self._stream_conversions.append(
field_types.conversions.get(field['Type']))
except gorpc.GoRpcError as e:
self.logger_object.log_private_data(bind_variables)
raise convert_exception(e, str(self), sql)
@ -324,7 +339,7 @@ class TabletConnection(object):
return None
# See if we need to read more or whether we just pop the next row.
if self._stream_result is None :
if self._stream_result is None:
try:
self._stream_result = self.client.stream_next()
if self._stream_result is None:
@ -332,14 +347,17 @@ class TabletConnection(object):
return None
if self._stream_result.reply.get('Err'):
self.__drain_conn_after_streaming_app_error()
raise gorpc.AppError(self._stream_result.reply['Err'].get('Message', 'Missing error message'))
raise gorpc.AppError(self._stream_result.reply['Err'].get(
'Message', 'Missing error message'))
except gorpc.GoRpcError as e:
raise convert_exception(e, str(self))
except:
logging.exception('gorpc low-level error')
raise
row = tuple(_make_row(self._stream_result.reply['Rows'][self._stream_result_index], self._stream_conversions))
row = tuple(
_make_row(self._stream_result.reply['Rows'][self._stream_result_index],
self._stream_conversions))
# If we are reading the last row, set us up to read more data.
self._stream_result_index += 1
if self._stream_result_index == len(self._stream_result.reply['Rows']):
@ -351,21 +369,25 @@ class TabletConnection(object):
def __drain_conn_after_streaming_app_error(self):
"""Drains the connection of all incoming streaming packets (ignoring them).
This is necessary for streaming calls which return application errors inside
the RPC response (instead of through the usual GoRPC error return).
This is because GoRPC always expects the last packet to be an error; either
the usual GoRPC application error return, or a special "end-of-stream" error.
This is necessary for streaming calls which return application
errors inside the RPC response (instead of through the usual GoRPC
error return). This is because GoRPC always expects the last
packet to be an error; either the usual GoRPC application error
return, or a special "end-of-stream" error.
If an application error is returned with the RPC response, there will still be
at least one more packet coming, as GoRPC has not seen anything that it
considers to be an error. If the connection is not drained of this last
packet, future reads from the wire will be off by one and will return errors.
If an application error is returned with the RPC response, there
will still be at least one more packet coming, as GoRPC has not
seen anything that it considers to be an error. If the connection
is not drained of this last packet, future reads from the wire
will be off by one and will return errors.
"""
next_result = self.client.stream_next()
if next_result is not None:
self.client.close()
raise gorpc.GoRpcError("Connection should only have one packet remaining"
" after streaming app error in RPC response.")
raise gorpc.GoRpcError(
'Connection should only have one packet remaining'
' after streaming app error in RPC response.')
def _make_row(row, conversions):
converted_row = []

Просмотреть файл

@ -4,7 +4,10 @@
#
# Use Python datetime module to handle date and time columns.
from datetime import date, datetime, time, timedelta
from datetime import date
from datetime import datetime
from datetime import time
from datetime import timedelta
from math import modf
from time import localtime
@ -17,18 +20,22 @@ Timestamp = datetime
DateTimeDeltaType = timedelta
DateTimeType = datetime
# Convert UNIX ticks into a date instance.
def DateFromTicks(ticks):
return date(*localtime(ticks)[:3])
# Convert UNIX ticks into a time instance.
def TimeFromTicks(ticks):
return time(*localtime(ticks)[3:6])
# Convert UNIX ticks into a datetime instance.
def TimestampFromTicks(ticks):
return datetime(*localtime(ticks)[:6])
def DateTimeOrNone(s):
if ' ' in s:
sep = ' '
@ -39,34 +46,45 @@ def DateTimeOrNone(s):
try:
d, t = s.split(sep, 1)
return datetime(*[ int(x) for x in d.split('-')+t.split(':') ])
except:
return datetime(*[int(x) for x in d.split('-')+t.split(':')])
except Exception:
return DateOrNone(s)
def TimeDeltaOrNone(s):
try:
h, m, s = s.split(':')
td = timedelta(hours=int(h), minutes=int(m), seconds=int(float(s)), microseconds=int(modf(float(s))[0]*1000000))
td = timedelta(
hours=int(h), minutes=int(m), seconds=int(float(s)),
microseconds=int(modf(float(s))[0]*1000000))
if h < 0:
return -td
else:
return td
except:
except Exception:
return None
def TimeOrNone(s):
try:
h, m, s = s.split(':')
return time(hour=int(h), minute=int(m), second=int(float(s)), microsecond=int(modf(float(s))[0]*1000000))
except:
return time(
hour=int(h), minute=int(m), second=int(float(s)),
microsecond=int(modf(float(s))[0]*1000000))
except Exception:
return None
def DateOrNone(s):
try: return date(*[ int(x) for x in s.split('-',2)])
except: return None
try:
return date(*[int(x) for x in s.split('-', 2)])
except Exception:
return None
def DateToString(d):
return d.isoformat()
def DateTimeToString(dt):
return dt.isoformat(' ')

Просмотреть файл

@ -4,9 +4,9 @@
import random
from zk import zkocc
from vtdb import topology
from vtdb import vtdb_logger
from zk import zkocc
class VTConnParams(object):
@ -29,9 +29,10 @@ class VTConnParams(object):
self.password = password
def get_db_params_for_tablet_conn(topo_client, keyspace_name, shard, db_type, timeout, user, password):
def get_db_params_for_tablet_conn(
topo_client, keyspace_name, shard, db_type, timeout, user, password):
db_params_list = []
db_key = "%s.%s.%s:vt" % (keyspace_name, shard, db_type)
db_key = '%s.%s.%s:vt' % (keyspace_name, shard, db_type)
# This will read the cached keyspace.
keyspace_object = topology.get_keyspace(keyspace_name)
@ -44,17 +45,20 @@ def get_db_params_for_tablet_conn(topo_client, keyspace_name, shard, db_type, ti
keyspace_name = new_keyspace
try:
end_points_data = topo_client.get_end_points('local', keyspace_name, shard, db_type)
end_points_data = topo_client.get_end_points(
'local', keyspace_name, shard, db_type)
except zkocc.ZkOccError as e:
vtdb_logger.get_logger().topo_zkocc_error('do data', db_key, e)
return []
except Exception as e:
vtdb_logger.get_logger().topo_exception('failed to get or parse topo data', db_key, e)
vtdb_logger.get_logger().topo_exception(
'failed to get or parse topo data', db_key, e)
return []
host_port_list = []
if 'Entries' not in end_points_data:
vtdb_logger.get_logger().topo_exception('topo server returned: ' + str(end_points_data), db_key, e)
vtdb_logger.get_logger().topo_exception(
'topo server returned: ' + str(end_points_data), db_key, e)
raise Exception('zkocc returned: %s' % str(end_points_data))
for entry in end_points_data['Entries']:
if 'vt' in entry['PortMap']:
@ -63,6 +67,8 @@ def get_db_params_for_tablet_conn(topo_client, keyspace_name, shard, db_type, ti
random.shuffle(host_port_list)
for host, port in host_port_list:
vt_params = VTConnParams(keyspace_name, shard, db_type, "%s:%s" % (host, port), timeout, user, password).__dict__
vt_params = VTConnParams(
keyspace_name, shard, db_type, '%s:%s' %
(host, port), timeout, user, password).__dict__
db_params_list.append(vt_params)
return db_params_list

Просмотреть файл

@ -25,7 +25,9 @@ from zk import zkocc
# keeps a global version of the topology
# This is a map of keyspace_name: (keyspace object, time when keyspace was last fetched)
# This is a map of keyspace_name: (keyspace object, time when keyspace
# was last fetched)
#
# eg - {'keyspace_name': (keyspace_object, time_of_last_fetch)}
# keyspace object is defined at py/vtdb/keyspace.py:Keyspace
__keyspace_map = {}
@ -100,7 +102,7 @@ def read_topology(zkocc_client, read_fqdb_keys=True):
db_keys = []
keyspace_list = zkocc_client.get_srv_keyspace_names('local')
# validate step
if len(keyspace_list) == 0:
if not keyspace_list:
vtdb_logger.get_logger().topo_empty_keyspace_list()
raise Exception('zkocc returned empty keyspace list')
for keyspace_name in keyspace_list:
@ -140,10 +142,12 @@ def get_host_port_by_name(topo_client, db_key):
vtdb_logger.get_logger().topo_zkocc_error('do data', db_key, e)
return []
except Exception as e:
vtdb_logger.get_logger().topo_exception('failed to get or parse topo data', db_key, e)
vtdb_logger.get_logger().topo_exception(
'failed to get or parse topo data', db_key, e)
return []
if 'Entries' not in data:
vtdb_logger.get_logger().topo_exception('topo server returned: ' + str(data), db_key, e)
vtdb_logger.get_logger().topo_exception(
'topo server returned: ' + str(data), db_key, e)
raise Exception('zkocc returned: %s' % str(data))
host_port_list = []

Просмотреть файл

@ -18,19 +18,21 @@ def register_conn_class(protocol, c):
def connect(protocol, *pargs, **kargs):
"""connect will return a dialed UpdateStreamConnection connection to
an update stream server.
"""Return a dialed UpdateStreamConnection to an update stream server.
Args:
protocol: the registered protocol to use.
args: passed to the registered protocol __init__ method.
protocol: The registered protocol to use.
*pargs: Passed to the registered protocol __init__ method.
**kargs: Passed to the registered protocol __init__ method.
Returns:
A dialed UpdateStreamConnection.
Raises:
ValueError: On bad protocol.
"""
if not protocol in update_stream_conn_classes:
raise Exception('Unknown update stream protocol', protocol)
if protocol not in update_stream_conn_classes:
raise ValueError('Unknown update stream protocol', protocol)
conn = update_stream_conn_classes[protocol](*pargs, **kargs)
conn.dial()
return conn
@ -59,11 +61,10 @@ class StreamEvent(object):
class UpdateStreamConnection(object):
"""UpdateStreamConnection is the interface for the update stream
client implementations.
All implementations must implement all these methods.
If something goes wrong with the connection, this object will be thrown out.
"""Te interface for the update stream client implementations.
All implementations must implement all these methods. If something
goes wrong with the connection, this object will be thrown out.
"""
def __init__(self, addr, timeout):

Просмотреть файл

@ -38,16 +38,17 @@ def reconnect(method):
if attempt >= self.max_attempts or self.in_txn:
self.close()
vtdb_logger.get_logger().vtclient_exception(self.keyspace, self.shard, self.db_type, e)
vtdb_logger.get_logger().vtclient_exception(
self.keyspace, self.shard, self.db_type, e)
raise dbexceptions.FatalError(*e.args)
if method.__name__ == 'begin':
time.sleep(BEGIN_RECONNECT_DELAY)
else:
time.sleep(RECONNECT_DELAY)
logging.info("Attempting to reconnect, %d", attempt)
logging.info('Attempting to reconnect, %d', attempt)
self.close()
self.connect()
logging.info("Successfully reconnected to %s", str(self.conn))
logging.info('Successfully reconnected to %s', str(self.conn))
return _run_with_reconnect
@ -89,11 +90,12 @@ class VtOCCConnection(object):
try:
return self._connect()
except dbexceptions.OperationalError as e:
vtdb_logger.get_logger().vtclient_exception(self.keyspace, self.shard, self.db_type, e)
vtdb_logger.get_logger().vtclient_exception(
self.keyspace, self.shard, self.db_type, e)
raise
def _connect(self):
db_key = "%s.%s.%s" % (self.keyspace, self.shard, self.db_type)
db_key = '%s.%s.%s' % (self.keyspace, self.shard, self.db_type)
db_params_list = get_vt_connection_params_list(self.topo_client,
self.keyspace,
self.shard,
@ -104,7 +106,8 @@ class VtOCCConnection(object):
if not db_params_list:
# no valid end-points were found, re-read the keyspace
self.resolve_topology()
raise dbexceptions.OperationalError("empty db params list - no db instance available for key %s" % db_key)
raise dbexceptions.OperationalError(
'empty db params list - no db instance available for key %s' % db_key)
db_exception = None
host_addr = None
# no retries here, since there is a higher level retry with reconnect.
@ -124,7 +127,7 @@ class VtOCCConnection(object):
self.resolve_topology()
raise dbexceptions.OperationalError(
'unable to create vt connection', db_key, host_addr, db_exception)
'unable to create vt connection', db_key, host_addr, db_exception)
def cursor(self, cursorclass=None, **kargs):
return (cursorclass or self.cursorclass)(self, **kargs)
@ -160,12 +163,14 @@ class VtOCCConnection(object):
sane_sql_list = []
sane_bind_vars_list = []
for sql, bind_variables in zip(sql_list, bind_variables_list):
sane_sql, sane_bind_vars = dbapi.prepare_query_bind_vars(sql, bind_variables)
sane_sql, sane_bind_vars = dbapi.prepare_query_bind_vars(
sql, bind_variables)
sane_sql_list.append(sane_sql)
sane_bind_vars_list.append(sane_bind_vars)
try:
result = self.conn._execute_batch(sane_sql_list, sane_bind_vars_list, as_transaction)
result = self.conn._execute_batch(
sane_sql_list, sane_bind_vars_list, as_transaction)
except dbexceptions.IntegrityError as e:
vtdb_logger.get_logger().integrity_error(e)
raise

Просмотреть файл

@ -21,7 +21,8 @@ class VtdbLogger(object):
# topo_keyspace_fetch is called when we successfully get a SrvKeyspace object.
def topo_keyspace_fetch(self, keyspace_name, topo_rtt):
logging.info("Fetched keyspace %s from topo_client in %f secs", keyspace_name, topo_rtt)
logging.info('Fetched keyspace %s from topo_client in %f secs',
keyspace_name, topo_rtt)
# topo_empty_keyspace_list is called when we get an empty list of
# keyspaces from topo server.
@ -32,7 +33,7 @@ class VtdbLogger(object):
# when reading a keyspace. This is within an exception handler.
def topo_bad_keyspace_data(self, keyspace_name):
logging.exception('error getting or parsing keyspace data for %s',
keyspace_name)
keyspace_name)
# topo_zkocc_error is called whenever we get a zkocc.ZkOccError
# when trying to resolve an endpoint.
@ -72,7 +73,7 @@ class VtdbLogger(object):
logging.warning('vtgatev2_exception: %s', e)
def log_private_data(self, private_data):
logging.info("Additional exception data %s", private_data)
logging.info('Additional exception data %s', private_data)
# registration mechanism for VtdbLogger

Просмотреть файл

@ -3,14 +3,18 @@
# be found in the LICENSE file.
import itertools
import operator
import re
from vtdb import cursor
from vtdb import dbexceptions
from vtdb import keyrange_constants
write_sql_pattern = re.compile('\s*(insert|update|delete)', re.IGNORECASE)
write_sql_pattern = re.compile(r'\s*(insert|update|delete)', re.IGNORECASE)
def ascii_lower(string):
"""Lower-case, but only in the ASCII range."""
return string.encode('utf8').lower().decode('utf8')
class VTGateCursor(object):
@ -28,7 +32,9 @@ class VTGateCursor(object):
_writable = None
routing = None
def __init__(self, connection, keyspace, tablet_type, keyspace_ids=None, keyranges=None, writable=False):
def __init__(
self, connection, keyspace, tablet_type, keyspace_ids=None,
keyranges=None, writable=False):
self._conn = connection
self.keyspace = keyspace
self.tablet_type = tablet_type
@ -74,20 +80,22 @@ class VTGateCursor(object):
return
write_query = bool(write_sql_pattern.match(sql))
# NOTE: This check may also be done at high-layers but adding it here for completion.
# NOTE: This check may also be done at high-layers but adding it
# here for completion.
if write_query:
if not self.is_writable():
raise dbexceptions.DatabaseError('DML on a non-writable cursor', sql)
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges,
not_in_transaction=(not self.is_writable()),
effective_caller_id=effective_caller_id)
self.results, self.rowcount, self.lastrowid, self.description = (
self._conn._execute(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges,
not_in_transaction=(not self.is_writable()),
effective_caller_id=effective_caller_id))
self.index = 0
return self.rowcount
@ -102,21 +110,22 @@ class VTGateCursor(object):
# This is by definition a scatter query, so raise exception.
write_query = bool(write_sql_pattern.match(sql))
if write_query:
raise dbexceptions.DatabaseError('execute_entity_ids is not allowed for write queries')
raise dbexceptions.DatabaseError(
'execute_entity_ids is not allowed for write queries')
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute_entity_ids(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
entity_keyspace_id_map,
entity_column_name,
not_in_transaction=(not self.is_writable()),
effective_caller_id=effective_caller_id)
self.results, self.rowcount, self.lastrowid, self.description = (
self._conn._execute_entity_ids(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
entity_keyspace_id_map,
entity_column_name,
not_in_transaction=(not self.is_writable()),
effective_caller_id=effective_caller_id))
self.index = 0
return self.rowcount
def fetchone(self):
if self.results is None:
raise dbexceptions.ProgrammingError('fetch called before execute')
@ -159,7 +168,8 @@ class VTGateCursor(object):
# sort the rows and then trim off the prepended sort columns
if sort_columns:
sorted_rows = list(sort_row_list_by_columns(self.fetchall(), sort_columns, desc_columns))[:limit]
sorted_rows = list(sort_row_list_by_columns(
self.fetchall(), sort_columns, desc_columns))[:limit]
else:
sorted_rows = itertools.islice(self.fetchall(), limit)
neutered_rows = [row[len(order_by_columns):] for row in sorted_rows]
@ -169,6 +179,7 @@ class VTGateCursor(object):
raise dbexceptions.NotSupportedError
def executemany(self, *pargs):
_ = pargs
raise dbexceptions.NotSupportedError
def nextset(self):
@ -203,6 +214,7 @@ class BatchVTGateCursor(VTGateCursor):
This only supports keyspace_ids right now since that is what
the underlying vtgate server supports.
"""
def __init__(self, connection, tablet_type, writable=False):
# rowset is [(results, rowcount, lastrowid, fields),]
self.rowsets = None
@ -210,7 +222,7 @@ class BatchVTGateCursor(VTGateCursor):
self.bind_vars_list = []
self.keyspace_list = []
self.keyspace_ids_list = []
VTGateCursor.__init__(self, connection, "", tablet_type, writable=writable)
VTGateCursor.__init__(self, connection, '', tablet_type, writable=writable)
def execute(self, sql, bind_variables, keyspace, keyspace_ids):
self.query_list.append(sql)
@ -240,8 +252,12 @@ class StreamVTGateCursor(VTGateCursor):
index = None
fetchmany_done = False
def __init__(self, connection, keyspace, tablet_type, keyspace_ids=None, keyranges=None, writable=False):
VTGateCursor.__init__(self, connection, keyspace, tablet_type, keyspace_ids=keyspace_ids, keyranges=keyranges)
def __init__(
self, connection, keyspace, tablet_type, keyspace_ids=None,
keyranges=None, writable=False):
VTGateCursor.__init__(
self, connection, keyspace, tablet_type, keyspace_ids=keyspace_ids,
keyranges=keyranges)
# pass kargs here in case higher level APIs need to push more data through
# for instance, a key value for shard mapping
@ -250,7 +266,7 @@ class StreamVTGateCursor(VTGateCursor):
raise dbexceptions.ProgrammingError('Streaming query cannot be writable')
self.description = None
x, y, z, self.description = self._conn._stream_execute(
_, _, _, self.description = self._conn._stream_execute(
sql,
bind_variables,
self.keyspace,
@ -279,7 +295,7 @@ class StreamVTGateCursor(VTGateCursor):
if self.fetchmany_done:
self.fetchmany_done = False
return result
for i in xrange(size):
for _ in xrange(size):
row = self.fetchone()
if row is None:
self.fetchmany_done = True
@ -327,7 +343,8 @@ class StreamVTGateCursor(VTGateCursor):
# assumes the leading columns are used for sorting
def sort_row_list_by_columns(row_list, sort_columns=(), desc_columns=()):
for column_index, column_name in reversed([x for x in enumerate(sort_columns)]):
for column_index, column_name in reversed(
[x for x in enumerate(sort_columns)]):
og = operator.itemgetter(column_index)
if type(row_list) != list:
row_list = sorted(

Просмотреть файл

@ -36,16 +36,21 @@ def exponential_backoff_retry(
num_retries=NUM_RETRIES,
backoff_multiplier=BACKOFF_MULTIPLIER,
max_delay_ms=MAX_DELAY_MS):
"""decorator for exponential backoff retry
"""Decorator for exponential backoff retry.
Log and raise exception if unsuccessful
Do not retry while in a session
Log and raise exception if unsuccessful.
Do not retry while in a session.
retry_exceptions: tuple of exceptions to check
initial_delay_ms: initial delay between retries in ms
num_retries: number max number of retries
backoff_multipler: multiplier for each retry e.g. 2 will double the retry delay
max_delay_ms: upper bound on retry delay
Args:
retry_exceptions: tuple of exceptions to check.
initial_delay_ms: initial delay between retries in ms.
num_retries: number max number of retries.
backoff_multiplier: multiplier for each retry e.g. 2 will double the
retry delay.
max_delay_ms: upper bound on retry delay.
Returns:
A decorator method that returns wrapped method.
"""
def decorator(method):
def wrapper(self, *args, **kwargs):
@ -62,7 +67,9 @@ def exponential_backoff_retry(
# and tablet_type from exception.
log_exception(e)
raise e
logging.error("retryable error: %s, retrying in %d ms, attempt %d of %d", e, delay, attempt, num_retries)
logging.error(
"retryable error: %s, retrying in %d ms, attempt %d of %d", e,
delay, attempt, num_retries)
time.sleep(delay/1000.0)
delay *= backoff_multiplier
delay = min(max_delay_ms, delay)

Просмотреть файл

@ -12,14 +12,13 @@ from net import gorpc
from vtdb import dbapi
from vtdb import dbexceptions
from vtdb import field_types
from vtdb import keyrange
from vtdb import keyspace
from vtdb import vtdb_logger
from vtdb import vtgate_client
from vtdb import vtgate_cursor
from vtdb import vtgate_utils
_errno_pattern = re.compile('\(errno (\d+)\)')
_errno_pattern = re.compile(r'\(errno (\d+)\)')
def handle_app_error(exc_args):
@ -41,6 +40,7 @@ def handle_app_error(exc_args):
def convert_exception(exc, *args, **kwargs):
"""This parses the protocol exceptions to the api interface exceptions.
This also logs the exception and increments the appropriate error counters.
Args:
@ -66,48 +66,53 @@ def convert_exception(exc, *args, **kwargs):
elif isinstance(exc, gorpc.GoRpcError):
new_exc = dbexceptions.FatalError(new_args)
keyspace = kwargs.get("keyspace", None)
tablet_type = kwargs.get("tablet_type", None)
keyspace_name = kwargs.get('keyspace', None)
tablet_type = kwargs.get('tablet_type', None)
vtgate_utils.log_exception(new_exc, keyspace=keyspace,
vtgate_utils.log_exception(new_exc, keyspace=keyspace_name,
tablet_type=tablet_type)
return new_exc
def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspace_ids, not_in_transaction):
def _create_req_with_keyspace_ids(
sql, new_binds, keyspace, tablet_type, keyspace_ids, not_in_transaction):
# keyspace_ids are Keyspace Ids packed to byte[]
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
new_binds = field_types.convert_bind_vars(new_binds)
req = {
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyspaceIds': keyspace_ids,
'NotInTransaction': not_in_transaction,
}
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyspaceIds': keyspace_ids,
'NotInTransaction': not_in_transaction,
}
return req
def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges, not_in_transaction):
def _create_req_with_keyranges(
sql, new_binds, keyspace, tablet_type, keyranges, not_in_transaction):
# keyranges are keyspace.KeyRange objects with start/end packed to byte[]
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
new_binds = field_types.convert_bind_vars(new_binds)
req = {
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyRanges': keyranges,
'NotInTransaction': not_in_transaction,
}
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyRanges': keyranges,
'NotInTransaction': not_in_transaction,
}
return req
# A simple, direct connection to the vttablet query server.
# This is shard-unaware and only handles the most basic communication.
# If something goes wrong, this object should be thrown away and a new one instantiated.
class VTGateConnection(vtgate_client.VTGateClient):
"""A simple, direct connection to the vttablet query server.
This is shard-unaware and only handles the most basic communication.
If something goes wrong, this object should be thrown away and a new
one instantiated.
"""
session = None
_stream_fields = None
_stream_conversions = None
@ -118,7 +123,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
keyfile=None, certfile=None):
self.addr = addr
self.timeout = timeout
self.client = bsonrpc.BsonRpcClient(addr, timeout, user, password, keyfile=keyfile, certfile=certfile)
self.client = bsonrpc.BsonRpcClient(
addr, timeout, user, password, keyfile=keyfile, certfile=certfile)
self.logger_object = vtdb_logger.get_logger()
def __str__(self):
@ -204,13 +210,18 @@ class VTGateConnection(vtgate_client.VTGateClient):
exec_method = None
req = None
if keyspace_ids is not None:
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
req = _create_req_with_keyspace_ids(
sql, bind_variables, keyspace, tablet_type, keyspace_ids,
not_in_transaction)
exec_method = 'VTGate.ExecuteKeyspaceIds'
elif keyranges is not None:
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
req = _create_req_with_keyranges(
sql, bind_variables, keyspace, tablet_type, keyranges,
not_in_transaction)
exec_method = 'VTGate.ExecuteKeyRanges'
else:
raise dbexceptions.ProgrammingError('_execute called without specifying keyspace_ids or keyranges')
raise dbexceptions.ProgrammingError(
'_execute called without specifying keyspace_ids or keyranges')
self._add_caller_id(req, effective_caller_id)
self._add_session(req)
@ -301,13 +312,13 @@ class VTGateConnection(vtgate_client.VTGateClient):
raise
return results, rowcount, lastrowid, fields
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
def _execute_batch(
self, sql_list, bind_variables_list, keyspace_list, keyspace_ids_list,
tablet_type, as_transaction, effective_caller_id=None):
query_list = []
for sql, bind_vars, keyspace, keyspace_ids in zip(sql_list, bind_variables_list, keyspace_list, keyspace_ids_list):
for sql, bind_vars, keyspace, keyspace_ids in zip(
sql_list, bind_variables_list, keyspace_list, keyspace_ids_list):
sql, bind_vars = dbapi.prepare_query_bind_vars(sql, bind_vars)
query = {}
query['Sql'] = sql
@ -329,7 +340,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
response = self.client.call('VTGate.ExecuteBatchKeyspaceIds', req)
self._update_session(response)
if 'Error' in response.reply and response.reply['Error']:
raise gorpc.AppError(response.reply['Error'], 'VTGate.ExecuteBatchKeyspaceIds')
raise gorpc.AppError(
response.reply['Error'], 'VTGate.ExecuteBatchKeyspaceIds')
for reply in response.reply['List']:
fields = []
conversions = []
@ -365,13 +377,18 @@ class VTGateConnection(vtgate_client.VTGateClient):
exec_method = None
req = None
if keyspace_ids is not None:
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
req = _create_req_with_keyspace_ids(
sql, bind_variables, keyspace, tablet_type, keyspace_ids,
not_in_transaction)
exec_method = 'VTGate.StreamExecuteKeyspaceIds'
elif keyranges is not None:
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
req = _create_req_with_keyranges(
sql, bind_variables, keyspace, tablet_type, keyranges,
not_in_transaction)
exec_method = 'VTGate.StreamExecuteKeyRanges'
else:
raise dbexceptions.ProgrammingError('_stream_execute called without specifying keyspace_ids or keyranges')
raise dbexceptions.ProgrammingError(
'_stream_execute called without specifying keyspace_ids or keyranges')
self._add_caller_id(req, effective_caller_id)
self._add_session(req)
@ -387,7 +404,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
for field in reply['Fields']:
self._stream_fields.append((field['Name'], field['Type']))
self._stream_conversions.append(field_types.conversions.get(field['Type']))
self._stream_conversions.append(
field_types.conversions.get(field['Type']))
except gorpc.GoRpcError as e:
self.logger_object.log_private_data(bind_variables)
raise convert_exception(e, str(self), sql, keyspace_ids, keyranges,
@ -410,7 +428,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
self._stream_result_index = None
return None
# A session message, if any comes separately with no rows
if 'Session' in self._stream_result.reply and self._stream_result.reply['Session']:
if ('Session' in self._stream_result.reply and
self._stream_result.reply['Session']):
self.session = self._stream_result.reply['Session']
self._stream_result = None
continue
@ -420,11 +439,14 @@ class VTGateConnection(vtgate_client.VTGateClient):
logging.exception('gorpc low-level error')
raise
row = tuple(_make_row(self._stream_result.reply['Result']['Rows'][self._stream_result_index], self._stream_conversions))
row = tuple(_make_row(
self._stream_result.reply['Result']['Rows'][self._stream_result_index],
self._stream_conversions))
# If we are reading the last row, set us up to read more data.
self._stream_result_index += 1
if self._stream_result_index == len(self._stream_result.reply['Result']['Rows']):
if (self._stream_result_index ==
len(self._stream_result.reply['Result']['Rows'])):
self._stream_result = None
self._stream_result_index = 0
@ -468,7 +490,7 @@ def get_params_for_vtgate_conn(vtgate_addrs, timeout, user=None, password=None):
random.shuffle(vtgate_addrs)
addrs = vtgate_addrs
else:
raise dbexceptions.Error("Wrong type for vtgate addrs %s" % vtgate_addrs)
raise dbexceptions.Error('Wrong type for vtgate addrs %s' % vtgate_addrs)
for addr in addrs:
vt_params = dict()
@ -485,7 +507,9 @@ def connect(vtgate_addrs, timeout, user=None, password=None):
user=user, password=password)
if not db_params_list:
raise dbexceptions.OperationalError("empty db params list - no db instance available for vtgate_addrs %s" % vtgate_addrs)
raise dbexceptions.OperationalError(
'empty db params list - no db instance available for vtgate_addrs %s' %
vtgate_addrs)
db_exception = None
host_addr = None
@ -501,6 +525,6 @@ def connect(vtgate_addrs, timeout, user=None, password=None):
logging.warning('db connection failed: %s, %s', host_addr, e)
raise dbexceptions.OperationalError(
'unable to create vt connection', host_addr, db_exception)
'unable to create vt connection', host_addr, db_exception)
vtgate_client.register_conn_class('gorpc', VTGateConnection)

Просмотреть файл

@ -4,18 +4,17 @@
from itertools import izip
import logging
import random
import re
from net import bsonrpc
from net import gorpc
from vtdb import cursorv3
from vtdb import dbexceptions
from vtdb import field_types
from vtdb import vtdb_logger
from vtdb import cursorv3
_errno_pattern = re.compile('\(errno (\d+)\)')
_errno_pattern = re.compile(r'\(errno (\d+)\)')
def log_exception(method):
@ -26,8 +25,8 @@ def log_exception(method):
method based on the exception raised.
Args:
exc: exception raised by calling code
args: additional args for the exception.
method: Method that takes exc, *args, where exc is an exception raised
by calling code, args are additional args for the exception.
Returns:
Decorated method.
@ -79,16 +78,16 @@ def convert_exception(exc, *args):
def _create_req(sql, new_binds, tablet_type, not_in_transaction):
new_binds = field_types.convert_bind_vars(new_binds)
req = {
'Sql': sql,
'BindVariables': new_binds,
'TabletType': tablet_type,
'NotInTransaction': not_in_transaction,
}
'Sql': sql,
'BindVariables': new_binds,
'TabletType': tablet_type,
'NotInTransaction': not_in_transaction,
}
return req
# This utilizes the V3 API of VTGate.
class VTGateConnection(object):
"""This utilizes the V3 API of VTGate."""
session = None
_stream_fields = None
_stream_conversions = None
@ -163,7 +162,8 @@ class VTGateConnection(object):
if 'Session' in response.reply and response.reply['Session']:
self.session = response.reply['Session']
def _execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
def _execute(
self, sql, bind_variables, tablet_type, not_in_transaction=False):
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
self._add_session(req)
@ -198,8 +198,8 @@ class VTGateConnection(object):
raise
return results, rowcount, lastrowid, fields
def _execute_batch(self, sql_list, bind_variables_list, tablet_type, as_transaction):
def _execute_batch(
self, sql_list, bind_variables_list, tablet_type, as_transaction):
query_list = []
for sql, bind_vars in zip(sql_list, bind_variables_list):
query = {}
@ -247,7 +247,8 @@ class VTGateConnection(object):
# we return the fields for the response, and the column conversions
# the conversions will need to be passed back to _stream_next
# (that way we avoid using a member variable here for such a corner case)
def _stream_execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
def _stream_execute(
self, sql, bind_variables, tablet_type, not_in_transaction=False):
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
self._add_session(req)
@ -262,7 +263,8 @@ class VTGateConnection(object):
for field in reply['Fields']:
self._stream_fields.append((field['Name'], field['Type']))
self._stream_conversions.append(field_types.conversions.get(field['Type']))
self._stream_conversions.append(
field_types.conversions.get(field['Type']))
except gorpc.GoRpcError as e:
self.logger_object.log_private_data(bind_variables)
raise convert_exception(e, str(self), sql)
@ -284,7 +286,8 @@ class VTGateConnection(object):
self._stream_result_index = None
return None
# A session message, if any comes separately with no rows
if 'Session' in self._stream_result.reply and self._stream_result.reply['Session']:
if ('Session' in self._stream_result.reply and
self._stream_result.reply['Session']):
self.session = self._stream_result.reply['Session']
self._stream_result = None
continue
@ -298,11 +301,14 @@ class VTGateConnection(object):
logging.exception('gorpc low-level error')
raise
row = tuple(_make_row(self._stream_result.reply['Result']['Rows'][self._stream_result_index], self._stream_conversions))
row = tuple(_make_row(
self._stream_result.reply['Result']['Rows'][self._stream_result_index],
self._stream_conversions))
# If we are reading the last row, set us up to read more data.
self._stream_result_index += 1
if self._stream_result_index == len(self._stream_result.reply['Result']['Rows']):
if (self._stream_result_index ==
len(self._stream_result.reply['Result']['Rows'])):
self._stream_result = None
self._stream_result_index = 0

Просмотреть файл

@ -229,7 +229,7 @@ def _create_where_clause_for_str_keyspace(key_range, keyspace_col_name):
i += 1
bind_vars[bind_name] = kr_min
if kr_max != keyrange_constants.MAX_KEY:
if where_clause != '':
if where_clause:
where_clause += ' AND '
bind_name = '%s%d' % (keyspace_col_name, i)
where_clause += 'hex(%s) < ' % keyspace_col_name + '%(' + bind_name + ')s'
@ -262,7 +262,7 @@ def _create_where_clause_for_int_keyspace(key_range, keyspace_col_name):
i += 1
bind_vars[bind_name] = kr_min
if kr_max is not None:
if where_clause != '':
if where_clause:
where_clause += ' AND '
bind_name = '%s%d' % (keyspace_col_name, i)
where_clause += '%s < ' % keyspace_col_name + '%(' + bind_name + ')s'

Просмотреть файл

@ -1,8 +1,11 @@
# Implement a sensible wrapper that treats python objects as dictionaries
# with sensible restrictions on serialization.
"""Implement a sensible wrapper that treats python objects as dictionaries.
Has sensible restrictions on serialization.
"""
import json
def _default(o):
if hasattr(o, '_serializable_attributes'):
return dict([(k, v)
@ -13,13 +16,15 @@ def _default(o):
_default_kargs = {'default': _default,
'sort_keys': True,
'indent': 2,
}
}
def dump(*pargs, **kargs):
  """Serialize to a file object via json.dump using the module defaults.

  Keyword arguments supplied by the caller take precedence over the
  defaults in _default_kargs.
  """
  options = dict(_default_kargs, **kargs)
  return json.dump(*pargs, **options)
def dumps(*pargs, **kargs):
_kargs = _default_kargs.copy()
_kargs.update(kargs)

Просмотреть файл

@ -1,10 +1,11 @@
# zkns - naming service.
#
# This uses zookeeper to resolve a list of servers to answer a
# particular query type.
#
# Additionally, config information can be embedded nearby for future updates
# to the python client.
"""zkns - naming service.
This uses zookeeper to resolve a list of servers to answer a
particular query type.
Additionally, config information can be embedded nearby for future updates
to the python client.
"""
import collections
import json
@ -16,6 +17,7 @@ from zk import zkjson
SrvEntry = collections.namedtuple('SrvEntry',
('host', 'port', 'priority', 'weight'))
class ZknsError(Exception):
  """Raised when a zkns name lookup or resolution fails."""
  pass
@ -28,12 +30,13 @@ class ZknsAddr(zkjson.ZkJsonObject):
class ZknsAddrs(zkjson.ZkJsonObject):
  """JSON-serializable container holding a list of zkns address entries."""
  # NOTE: attributes match Go implementation, hence capitalization
  _serializable_attributes = ('entries',)
  def __init__(self):
    # Entries are appended by callers (presumably ZknsAddr objects —
    # see _get_addrs); starts empty.
    self.entries = []
def _sorted_by_srv_priority(entries):
# Priority is ascending, weight is descending.
"""Priority is ascending, weight is descending."""
entries.sort(key=lambda x: (x.priority, -x.weight))
priority_map = collections.defaultdict(list)
@ -41,7 +44,7 @@ def _sorted_by_srv_priority(entries):
priority_map[entry.priority].append(entry)
shuffled_entries = []
for priority, priority_entries in sorted(priority_map.iteritems()):
for unused_priority, priority_entries in sorted(priority_map.iteritems()):
if len(priority_entries) <= 1:
shuffled_entries.extend(priority_entries)
continue
@ -62,6 +65,7 @@ def _sorted_by_srv_priority(entries):
return shuffled_entries
def _get_addrs(zconn, zk_path):
data = zconn.get_data(zk_path)
addrs = ZknsAddrs()
@ -72,6 +76,7 @@ def _get_addrs(zconn, zk_path):
addrs.entries.append(addr)
return addrs
# zkns_name: /zk/cell/vt/ns/path:_port - port is optional
def lookup_name(zconn, zkns_name):
if ':' in zkns_name:

Просмотреть файл

@ -7,6 +7,7 @@ import threading
from net import bsonrpc
from net import gorpc
class ZkOccError(Exception):
  """Raised for zkocc connection failures and failed path lookups."""
  pass
@ -35,9 +36,12 @@ class ZkOccError(Exception):
# pzxid long
#
# A simple, direct connection to a single zkocc server. Doesn't retry.
# You probably want to use ZkOccConnection instead.
class SimpleZkOccConnection(object):
"""A simple, direct connection to a single zkocc server.
Doesn't retry. You probably want to use ZkOccConnection instead.
"""
def __init__(self, addr, timeout, user=None, password=None):
self.client = bsonrpc.BsonRpcClient(addr, timeout, user, password)
@ -75,13 +79,18 @@ class SimpleZkOccConnection(object):
return self._call('TopoReader.GetSrvKeyspace', cell=cell, keyspace=keyspace)
def get_end_points(self, cell, keyspace, shard, tablet_type):
return self._call('TopoReader.GetEndPoints', cell=cell, keyspace=keyspace, shard=shard, tablet_type=tablet_type)
return self._call(
'TopoReader.GetEndPoints', cell=cell, keyspace=keyspace, shard=shard,
tablet_type=tablet_type)
# A meta-connection that can connect to multiple alternate servers, and will
# retry a couple times. Calling dial before get/getv/children is optional,
# and will only do anything at all if authentication is enabled.
class ZkOccConnection(object):
"""A meta-connection that can connect to multiple alternate servers.
This will retry a couple times. Calling dial before
get/getv/children is optional, and will only do anything at all if
authentication is enabled.
"""
max_attempts = 2
max_dial_attempts = 10
@ -92,7 +101,8 @@ class ZkOccConnection(object):
self.local_cell = local_cell
if bool(user) != bool(password):
raise ValueError("You must provide either both or none of user and password.")
raise ValueError(
'You must provide either both or none of user and password.')
self.user = user
self.password = password
@ -100,7 +110,7 @@ class ZkOccConnection(object):
self.lock = threading.Lock()
def _resolve_path(self, zk_path):
# Maps a 'meta-path' to a cell specific path.
"""Maps a 'meta-path' to a cell specific path."""
# '/zk/local/blah' -> '/zk/vb/blah'
parts = zk_path.split('/')
@ -123,9 +133,11 @@ class ZkOccConnection(object):
if self.simple_conn:
self.simple_conn.close()
addrs = random.sample(self.addrs, min(self.max_dial_attempts, len(self.addrs)))
addrs = random.sample(
self.addrs, min(self.max_dial_attempts, len(self.addrs)))
for a in addrs:
self.simple_conn = SimpleZkOccConnection(a, self.timeout, self.user, self.password)
self.simple_conn = SimpleZkOccConnection(
a, self.timeout, self.user, self.password)
try:
self.simple_conn.dial()
return
@ -133,7 +145,7 @@ class ZkOccConnection(object):
pass
self.simple_conn = None
raise ZkOccError("Cannot dial to any server, tried: %s" % addrs)
raise ZkOccError('Cannot dial to any server, tried: %s' % addrs)
def close(self):
if self.simple_conn:
@ -151,9 +163,13 @@ class ZkOccConnection(object):
return getattr(self.simple_conn, client_method)(*args, **kwargs)
except Exception as e:
attempt += 1
logging.warning('zkocc: %s command failed %d times: %s', client_method, attempt, e)
logging.warning(
'zkocc: %s command failed %d times: %s',
client_method, attempt, e)
if attempt >= self.max_attempts:
raise ZkOccError('zkocc %s command failed %d times: %s' % (client_method, attempt, e))
raise ZkOccError(
'zkocc %s command failed %d times: %s' %
(client_method, attempt, e))
# try the next server if there is one, or retry our only server
self.dial()
@ -190,10 +206,15 @@ class ZkOccConnection(object):
def children(self, path):
return self._call('children', self._resolve_path(path))
# use this class for faking out a zkocc client. The startup config values
# can be loaded from a json file. After that, they can be mass-altered
# to replace default values with test-specific values, for instance.
class FakeZkOccConnection(object):
"""Use this class for faking out a zkocc client.
The startup config values can be loaded from a json file. After
that, they can be mass-altered to replace default values with
test-specific values, for instance.
"""
def __init__(self, local_cell):
self.data = {}
self.local_cell = local_cell
@ -215,7 +236,7 @@ class FakeZkOccConnection(object):
self.data[key] = data.replace(before, after)
def _resolve_path(self, zk_path):
# Maps a 'meta-path' to a cell specific path.
"""Maps a 'meta-path' to a cell specific path."""
# '/zk/local/blah' -> '/zk/vb/blah'
parts = zk_path.split('/')
@ -238,25 +259,25 @@ class FakeZkOccConnection(object):
def get(self, path):
path = self._resolve_path(path)
if not path in self.data:
raise ZkOccError("FakeZkOccConnection: not found: " + path)
if path not in self.data:
raise ZkOccError('FakeZkOccConnection: not found: ' + path)
return {
'Data':self.data[path],
'Children':[]
'Data': self.data[path],
'Children': []
}
def getv(self, paths):
raise ZkOccError("FakeZkOccConnection: not found: " + " ".join(paths))
raise ZkOccError('FakeZkOccConnection: not found: ' + ' '.join(paths))
def children(self, path):
path = self._resolve_path(path)
children = [os.path.basename(node) for node in self.data
if os.path.dirname(node) == path]
if len(children) == 0:
raise ZkOccError("FakeZkOccConnection: not found: " + path)
if not children:
raise ZkOccError('FakeZkOccConnection: not found: ' + path)
return {
'Data':'',
'Children':children
'Data': '',
'Children': children
}
# New API. For this fake object, it is based on the old API.
@ -271,7 +292,7 @@ class FakeZkOccConnection(object):
try:
data = self.get(keyspace_path)['Data']
if not data:
raise ZkOccError("FakeZkOccConnection: empty keyspace: " + keyspace)
raise ZkOccError('FakeZkOccConnection: empty keyspace: ' + keyspace)
result = json.loads(data)
# for convenience, we store the KeyRange as hex, but we need to
# decode it here, as BSON RPC sends it as binary.
@ -289,7 +310,7 @@ class FakeZkOccConnection(object):
try:
data = self.get(zk_path)['Data']
if not data:
raise ZkOccError("FakeZkOccConnection: empty end point: " + zk_path)
raise ZkOccError('FakeZkOccConnection: empty end point: ' + zk_path)
return json.loads(data)
except Exception as e:
raise ZkOccError('FakeZkOccConnection: invalid end point', zk_path, e)