зеркало из https://github.com/github/vitess-gh.git
Fix some basic lint issues in vitess/py.
This commit is contained in:
Родитель
2dfda59b5a
Коммит
a78bbf14ab
|
@ -1,9 +1,10 @@
|
|||
from distutils.core import setup, Extension
|
||||
from distutils.core import Extension
|
||||
from distutils.core import setup
|
||||
|
||||
cbson = Extension('cbson',
|
||||
sources = ['cbson.c'])
|
||||
sources=['cbson.c'])
|
||||
|
||||
setup(name = 'cbson',
|
||||
version = '0.1',
|
||||
description = 'Fast BSON decoding via C',
|
||||
ext_modules = [cbson])
|
||||
setup(name='cbson',
|
||||
version='0.1',
|
||||
description='Fast BSON decoding via C',
|
||||
ext_modules=[cbson])
|
||||
|
|
|
@ -12,16 +12,20 @@ import sys
|
|||
sys.path.append(os.path.split(os.path.abspath(__file__))[0])
|
||||
import cbson
|
||||
|
||||
|
||||
def test_load_empty():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> cbson.loads('')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
BSONError: empty buffer
|
||||
"""
|
||||
|
||||
|
||||
def test_load_binary():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'world': 'hello'})
|
||||
>>> s
|
||||
'\x16\x00\x00\x00\x05world\x00\x05\x00\x00\x00\x00hello\x00'
|
||||
|
@ -29,8 +33,10 @@ def test_load_binary():
|
|||
{'world': 'hello'}
|
||||
"""
|
||||
|
||||
|
||||
def test_load_string():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'world': u'hello \u00fc'})
|
||||
>>> s
|
||||
'\x19\x00\x00\x00\x02world\x00\t\x00\x00\x00hello \xc3\xbc\x00\x00'
|
||||
|
@ -38,8 +44,10 @@ def test_load_string():
|
|||
{'world': u'hello \xfc'}
|
||||
"""
|
||||
|
||||
|
||||
def test_load_int():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'int': 1334})
|
||||
>>> s
|
||||
'\x0e\x00\x00\x00\x10int\x006\x05\x00\x00\x00'
|
||||
|
@ -47,9 +55,11 @@ def test_load_int():
|
|||
{'int': 1334}
|
||||
"""
|
||||
|
||||
|
||||
# the library doesn't allow creation of these, so just check the unpacking
|
||||
def test_unpack_uint64():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = '\x12\x00\x00\x00\x3fint\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
>>> cbson.loads(s)
|
||||
{'int': 1L}
|
||||
|
@ -58,16 +68,20 @@ def test_unpack_uint64():
|
|||
{'int': 10376293541461622784L}
|
||||
"""
|
||||
|
||||
|
||||
def test_bool():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> cbson.loads(cbson.dumps({'yes': True, 'no': False}))['yes']
|
||||
True
|
||||
>>> cbson.loads(cbson.dumps({'yes': True, 'no': False}))['no']
|
||||
False
|
||||
"""
|
||||
|
||||
|
||||
def test_none():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'none': None})
|
||||
>>> s
|
||||
'\x0b\x00\x00\x00\nnone\x00\x00'
|
||||
|
@ -75,8 +89,10 @@ def test_none():
|
|||
{'none': None}
|
||||
"""
|
||||
|
||||
|
||||
def test_array():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'lst': [1, 2, 3, 4, 'hello', None]})
|
||||
>>> s
|
||||
';\x00\x00\x00\x04lst\x001\x00\x00\x00\x100\x00\x01\x00\x00\x00\x101\x00\x02\x00\x00\x00\x102\x00\x03\x00\x00\x00\x103\x00\x04\x00\x00\x00\x054\x00\x05\x00\x00\x00\x00hello\n5\x00\x00\x00'
|
||||
|
@ -84,15 +100,19 @@ def test_array():
|
|||
{'lst': [1, 2, 3, 4, 'hello', None]}
|
||||
"""
|
||||
|
||||
|
||||
def test_dict():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> d = {'a': None, 'b': 2, 'c': 'hello', 'd': 1.02, 'e': ['a', 'b', 'c']}
|
||||
>>> cbson.loads(cbson.dumps(d)) == d
|
||||
True
|
||||
"""
|
||||
|
||||
|
||||
def test_nested_dict():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'a': {'b': {'c': 0}}})
|
||||
>>> s
|
||||
'\x1c\x00\x00\x00\x03a\x00\x14\x00\x00\x00\x03b\x00\x0c\x00\x00\x00\x10c\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
|
@ -100,8 +120,10 @@ def test_nested_dict():
|
|||
{'a': {'b': {'c': 0}}}
|
||||
"""
|
||||
|
||||
|
||||
def test_dict_in_array():
|
||||
r"""
|
||||
r"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'a': [{'b': 0}, {'c': 1}]})
|
||||
>>> s
|
||||
'+\x00\x00\x00\x04a\x00#\x00\x00\x00\x030\x00\x0c\x00\x00\x00\x10b\x00\x00\x00\x00\x00\x00\x031\x00\x0c\x00\x00\x00\x10c\x00\x01\x00\x00\x00\x00\x00\x00'
|
||||
|
@ -109,32 +131,40 @@ def test_dict_in_array():
|
|||
{'a': [{'b': 0}, {'c': 1}]}
|
||||
"""
|
||||
|
||||
|
||||
def test_eob1():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> cbson.loads(BSON(BSON_TAG(1), NULL_BYTE))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
BSONError: unexpected end of buffer: wanted 8 bytes at buffer[6] for double
|
||||
"""
|
||||
|
||||
|
||||
def test_eob2():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> cbson.loads(BSON(BSON_TAG(2), NULL_BYTE))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
BSONError: unexpected end of buffer: wanted 4 bytes at buffer[6] for string-length
|
||||
"""
|
||||
|
||||
|
||||
def test_eob_cstring():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> cbson.loads(BSON(BSON_TAG(1)))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
BSONError: unexpected end of buffer: non-terminated cstring at buffer[5] for element-name
|
||||
"""
|
||||
|
||||
|
||||
def test_decode_next():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> s = cbson.dumps({'a': 1}) + cbson.dumps({'b': 2.0}) + cbson.dumps({'c': [None]})
|
||||
>>> cbson.decode_next(s)
|
||||
(12, {'a': 1})
|
||||
|
@ -144,8 +174,10 @@ def test_decode_next():
|
|||
(44, {'c': [None]})
|
||||
"""
|
||||
|
||||
|
||||
def test_decode_next_eob():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> s_full = cbson.dumps({'a': 1}) + cbson.dumps({'b': 2.0})
|
||||
>>> s = s_full[:-1]
|
||||
>>> cbson.decode_next(s)
|
||||
|
@ -164,8 +196,10 @@ def test_decode_next_eob():
|
|||
BSONBufferTooShort: ('buffer too short: buffer[12:] does not contain 12 bytes for document', 2)
|
||||
"""
|
||||
|
||||
|
||||
def test_encode_recursive():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> a = []
|
||||
>>> a.append(a)
|
||||
>>> cbson.dumps({"x": a})
|
||||
|
@ -174,8 +208,10 @@ def test_encode_recursive():
|
|||
ValueError: object too deeply nested to BSON encode
|
||||
"""
|
||||
|
||||
|
||||
def test_encode_invalid():
|
||||
"""
|
||||
"""Doctest.
|
||||
|
||||
>>> cbson.dumps(object())
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
|
@ -186,24 +222,26 @@ def test_encode_invalid():
|
|||
TypeError: unsupported type for BSON encode
|
||||
"""
|
||||
|
||||
|
||||
def BSON(*l):
|
||||
buf = ''.join(l)
|
||||
return struct.pack('i', len(buf)+4) + buf
|
||||
|
||||
|
||||
def BSON_TAG(type_id):
|
||||
return struct.pack('b', type_id)
|
||||
|
||||
NULL_BYTE = '\x00'
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
cbson
|
||||
|
||||
import doctest
|
||||
print "Running selftest:"
|
||||
print 'Running selftest:'
|
||||
status = doctest.testmod(sys.modules[__name__])
|
||||
if status[0]:
|
||||
print "*** %s tests of %d failed." % status
|
||||
print '*** %s tests of %d failed.' % status
|
||||
sys.exit(1)
|
||||
else:
|
||||
print "--- %s tests passed." % status[1]
|
||||
print '--- %s tests passed.' % status[1]
|
||||
sys.exit(0)
|
||||
|
|
|
@ -12,15 +12,18 @@ import cbson
|
|||
# crash the decoder, before we fixed the issues.
|
||||
# Only base64 encoded so they aren't a million characters long and/or full
|
||||
# of escape sequences.
|
||||
KNOWN_BAD = ["VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
|
||||
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAB0A",
|
||||
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAADTMwAEAAAAEDQABQAAAAU"
|
||||
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAAAA",
|
||||
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
|
||||
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQ2gDAAAARFMAAAAA",
|
||||
]
|
||||
KNOWN_BAD = [
|
||||
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
|
||||
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAB0A",
|
||||
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAADTMwAEAAAAEDQABQAAAAU"
|
||||
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQwADAAAARFMAAAAA",
|
||||
"VAAAAARBAEwAAAAQMAABAAAAEDEAAgAAABAyAAMAAAAQMwAEAAAAEDQABQAAAAU"
|
||||
"1AAEAAAAANgI2AAIAAAA3AAM3AA8AAAACQ2gDAAAARFMAAAAA",
|
||||
]
|
||||
|
||||
|
||||
class CbsonTest(unittest.TestCase):
|
||||
|
||||
def test_short_string_segfaults(self):
|
||||
a = cbson.dumps({"A": [1, 2, 3, 4, 5, "6", u"7", {"C": u"DS"}]})
|
||||
for i in range(len(a))[1:]:
|
||||
|
@ -40,7 +43,7 @@ class CbsonTest(unittest.TestCase):
|
|||
def test_random_segfaults(self):
|
||||
a = cbson.dumps({"A": [1, 2, 3, 4, 5, "6", u"7", {"C": u"DS"}]})
|
||||
sys.stdout.write("\nQ: %s\n" % (binascii.hexlify(a),))
|
||||
for i in range(1000):
|
||||
for _ in range(1000):
|
||||
l = [c for c in a]
|
||||
l[random.randint(4, len(a)-1)] = chr(random.randint(0, 255))
|
||||
try:
|
||||
|
@ -54,4 +57,3 @@ class CbsonTest(unittest.TestCase):
|
|||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
|
||||
# Go-style RPC client using BSON as the codec.
|
||||
|
||||
import bson
|
||||
import hmac
|
||||
import struct
|
||||
|
||||
import bson
|
||||
try:
|
||||
# use optimized cbson which has slightly different API
|
||||
import cbson
|
||||
|
@ -25,11 +26,14 @@ len_struct = struct.Struct('<i')
|
|||
unpack_length = len_struct.unpack_from
|
||||
len_struct_size = len_struct.size
|
||||
|
||||
|
||||
class BsonRpcClient(gorpc.GoRpcClient):
|
||||
|
||||
def __init__(self, addr, timeout, user=None, password=None,
|
||||
keyfile=None, certfile=None):
|
||||
if bool(user) != bool(password):
|
||||
raise ValueError("You must provide either both or none of user and password.")
|
||||
raise ValueError(
|
||||
'You must provide either both or none of user and password.')
|
||||
self.addr = addr
|
||||
self.user = user
|
||||
self.password = password
|
||||
|
@ -37,7 +41,8 @@ class BsonRpcClient(gorpc.GoRpcClient):
|
|||
uri = 'http://%s/_bson_rpc_/auth' % self.addr
|
||||
else:
|
||||
uri = 'http://%s/_bson_rpc_' % self.addr
|
||||
gorpc.GoRpcClient.__init__(self, uri, timeout, keyfile=keyfile, certfile=certfile)
|
||||
gorpc.GoRpcClient.__init__(
|
||||
self, uri, timeout, keyfile=keyfile, certfile=certfile)
|
||||
|
||||
def dial(self):
|
||||
gorpc.GoRpcClient.dial(self)
|
||||
|
@ -49,10 +54,11 @@ class BsonRpcClient(gorpc.GoRpcClient):
|
|||
raise
|
||||
|
||||
def authenticate(self):
|
||||
challenge = self.call('AuthenticatorCRAMMD5.GetNewChallenge', "").reply['Challenge']
|
||||
challenge = self.call(
|
||||
'AuthenticatorCRAMMD5.GetNewChallenge', '').reply['Challenge']
|
||||
# CRAM-MD5 authentication.
|
||||
proof = self.user + " " + hmac.HMAC(self.password, challenge).hexdigest()
|
||||
self.call('AuthenticatorCRAMMD5.Authenticate', {"Proof": proof})
|
||||
proof = self.user + ' ' + hmac.HMAC(self.password, challenge).hexdigest()
|
||||
self.call('AuthenticatorCRAMMD5.Authenticate', {'Proof': proof})
|
||||
|
||||
def encode_request(self, req):
|
||||
try:
|
||||
|
@ -81,7 +87,7 @@ class BsonRpcClient(gorpc.GoRpcClient):
|
|||
# decode the payload length and see if we have enough
|
||||
body_len = unpack_length(data, header_len)[0]
|
||||
if data_len < header_len + body_len:
|
||||
return None, header_len + body_len - data_len
|
||||
return None, header_len + body_len - data_len
|
||||
|
||||
# we have enough data, decode it all
|
||||
try:
|
||||
|
|
|
@ -9,16 +9,18 @@
|
|||
|
||||
import errno
|
||||
import select
|
||||
import ssl
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
import urlparse
|
||||
|
||||
_lastStreamResponseError = 'EOS'
|
||||
|
||||
|
||||
class GoRpcError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(GoRpcError):
|
||||
pass
|
||||
|
||||
|
@ -40,8 +42,8 @@ def make_header(method, sequence_id):
|
|||
|
||||
|
||||
class GoRpcRequest(object):
|
||||
header = None # standard fields that route the request on the server side
|
||||
body = None # the actual request object - usually a dictionary
|
||||
header = None # standard fields that route the request on the server side
|
||||
body = None # the actual request object - usually a dictionary
|
||||
|
||||
def __init__(self, header, args):
|
||||
self.header = header
|
||||
|
@ -58,7 +60,7 @@ class GoRpcResponse(object):
|
|||
# 'Seq': sequence_id,
|
||||
# 'Error': error_string}
|
||||
header = None
|
||||
reply = None # the decoded object - usually a dictionary
|
||||
reply = None # the decoded object - usually a dictionary
|
||||
|
||||
@property
|
||||
def error(self):
|
||||
|
@ -70,9 +72,11 @@ class GoRpcResponse(object):
|
|||
|
||||
default_read_buffer_size = 8192
|
||||
|
||||
|
||||
# A single socket wrapper to handle request/response conversation for this
|
||||
# protocol. Internal, use GoRpcClient instead.
|
||||
class _GoRpcConn(object):
|
||||
|
||||
def __init__(self, timeout):
|
||||
self.conn = None
|
||||
# NOTE(msolomon) since the deadlines are approximate in the code, set
|
||||
|
@ -88,7 +92,8 @@ class _GoRpcConn(object):
|
|||
conip = socket.gethostbyname(conhost)
|
||||
except NameError:
|
||||
conip = socket.getaddrinfo(conhost, None)[0][4][0]
|
||||
self.conn = socket.create_connection((conip, int(conport)), self.socket_timeout)
|
||||
self.conn = socket.create_connection(
|
||||
(conip, int(conport)), self.socket_timeout)
|
||||
if parts.scheme == 'https':
|
||||
self.conn = ssl.wrap_socket(self.conn, keyfile=keyfile, certfile=certfile)
|
||||
self.conn.sendall('CONNECT %s HTTP/1.0\n\n' % parts.path)
|
||||
|
@ -101,7 +106,9 @@ class _GoRpcConn(object):
|
|||
continue
|
||||
raise
|
||||
if not d:
|
||||
raise GoRpcError('Unexpected EOF in handshake to %s:%s %s' % (str(conip), str(conport), parts.path))
|
||||
raise GoRpcError(
|
||||
'Unexpected EOF in handshake to %s:%s %s' %
|
||||
(str(conip), str(conport), parts.path))
|
||||
data += d
|
||||
if '\n\n' in data:
|
||||
return
|
||||
|
@ -159,6 +166,7 @@ class _GoRpcConn(object):
|
|||
|
||||
|
||||
class GoRpcClient(object):
|
||||
|
||||
def __init__(self, uri, timeout, certfile=None, keyfile=None):
|
||||
self.uri = uri
|
||||
self.timeout = timeout
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
|
||||
from vtctl import vtctl_client, gorpc_vtctl_client
|
||||
|
||||
|
||||
def init():
|
||||
vtctl_client.register_conn_class('gorpc', gorpc_vtctl_client.GoRpcVtctlClient)
|
||||
vtctl_client.register_conn_class(
|
||||
'gorpc', gorpc_vtctl_client.GoRpcVtctlClient)
|
||||
|
||||
init()
|
||||
|
|
|
@ -7,9 +7,9 @@
|
|||
import datetime
|
||||
from urlparse import urlparse
|
||||
|
||||
import vtctl_client
|
||||
from vtproto import vtctldata_pb2
|
||||
from vtproto import vtctlservice_pb2
|
||||
import vtctl_client
|
||||
|
||||
|
||||
class GRPCVtctlClient(vtctl_client.VtctlClient):
|
||||
|
|
|
@ -5,13 +5,13 @@
|
|||
# PEP 249 complient db api for Vitess
|
||||
|
||||
apilevel = '2.0'
|
||||
threadsafety = 0 # Threads may not share the module because multi_client is not thread safe
|
||||
# Threads may not share the module because multi_client is not thread safe.
|
||||
threadsafety = 0
|
||||
paramstyle = 'named'
|
||||
|
||||
from vtdb.dbexceptions import *
|
||||
from vtdb.times import Date, Time, Timestamp, DateFromTicks, TimeFromTicks, TimestampFromTicks
|
||||
from vtdb.field_types import Binary
|
||||
from vtdb.field_types import STRING, BINARY, NUMBER, DATETIME, ROWID
|
||||
from vtdb.vtgatev3 import *
|
||||
from vtdb.cursorv3 import *
|
||||
from vtdb.dbexceptions import *
|
||||
from vtdb.field_types import STRING, BINARY, NUMBER, DATETIME, ROWID
|
||||
from vtdb.times import Date, Time, Timestamp, DateFromTicks, TimeFromTicks, TimestampFromTicks
|
||||
from vtdb.vtgatev2 import *
|
||||
from vtdb.vtgatev3 import *
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
from vtdb import dbexceptions
|
||||
|
||||
|
||||
class BaseCursor(object):
|
||||
arraysize = 1
|
||||
lastrowid = None
|
||||
|
@ -39,7 +40,8 @@ class BaseCursor(object):
|
|||
self.connection.rollback()
|
||||
return
|
||||
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self.connection._execute(sql, bind_variables, **kargs)
|
||||
self.results, self.rowcount, self.lastrowid, self.description = (
|
||||
self.connection._execute(sql, bind_variables, **kargs))
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -97,13 +99,16 @@ class BaseCursor(object):
|
|||
raise StopIteration
|
||||
return val
|
||||
|
||||
|
||||
# A simple cursor intended for attaching to a single tablet server.
|
||||
class TabletCursor(BaseCursor):
|
||||
|
||||
def execute(self, sql, bind_variables=None):
|
||||
return self._execute(sql, bind_variables)
|
||||
|
||||
|
||||
class BatchCursor(BaseCursor):
|
||||
|
||||
def __init__(self, connection):
|
||||
# rowset is [(results, rowcount, lastrowid, fields),]
|
||||
self.rowsets = None
|
||||
|
@ -125,12 +130,14 @@ class BatchCursor(BaseCursor):
|
|||
|
||||
# just used for batch items
|
||||
class BatchQueryItem(object):
|
||||
|
||||
def __init__(self, sql, bind_variables, key, keys):
|
||||
self.sql = sql
|
||||
self.bind_variables = bind_variables
|
||||
self.key = key
|
||||
self.keys = keys
|
||||
|
||||
|
||||
class StreamCursor(object):
|
||||
arraysize = 1
|
||||
conversions = None
|
||||
|
@ -149,7 +156,8 @@ class StreamCursor(object):
|
|||
# for instance, a key value for shard mapping
|
||||
def execute(self, sql, bind_variables, **kargs):
|
||||
self.description = None
|
||||
x, y, z, self.description = self.connection._stream_execute(sql, bind_variables, **kargs)
|
||||
x, y, z, self.description = self.connection._stream_execute(
|
||||
sql, bind_variables, **kargs)
|
||||
self.index = 0
|
||||
return 0
|
||||
|
||||
|
|
|
@ -49,10 +49,9 @@ class Cursor(object):
|
|||
self.rollback()
|
||||
return
|
||||
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.tablet_type)
|
||||
self.results, self.rowcount, self.lastrowid, self.description = (
|
||||
self._conn._execute(
|
||||
sql, bind_variables, self.tablet_type))
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -121,7 +120,7 @@ class StreamCursor(Cursor):
|
|||
|
||||
def execute(self, sql, bind_variables, **kargs):
|
||||
self.description = None
|
||||
x, y, z, self.description = self._conn._stream_execute(
|
||||
_, _, _, self.description = self._conn._stream_execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.tablet_type)
|
||||
|
@ -145,7 +144,7 @@ class StreamCursor(Cursor):
|
|||
if self.fetchmany_done:
|
||||
self.fetchmany_done = False
|
||||
return result
|
||||
for i in xrange(size):
|
||||
for _ in xrange(size):
|
||||
row = self.fetchone()
|
||||
if row is None:
|
||||
self.fetchmany_done = True
|
||||
|
|
|
@ -15,7 +15,6 @@ management.
|
|||
# Use of this source code is governed by a BSD-style license that can
|
||||
# be found in the LICENSE file.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
|
||||
from vtdb import dbexceptions
|
||||
|
@ -24,13 +23,13 @@ from vtdb import vtdb_logger
|
|||
from vtdb import vtgatev2
|
||||
|
||||
|
||||
#TODO: verify that these values make sense.
|
||||
# TODO: verify that these values make sense.
|
||||
DEFAULT_CONNECTION_TIMEOUT = 5.0
|
||||
|
||||
__app_read_only_mode_method = lambda:False
|
||||
__app_read_only_mode_method = lambda: False
|
||||
__vtgate_connect_method = vtgatev2.connect
|
||||
#TODO: perhaps make vtgate addrs also a registeration mechanism ?
|
||||
#TODO: add mechansim to refresh vtgate addrs.
|
||||
# TODO: perhaps make vtgate addrs also a registeration mechanism ?
|
||||
# TODO: add mechansim to refresh vtgate addrs.
|
||||
|
||||
|
||||
class DatabaseContext(object):
|
||||
|
@ -44,18 +43,20 @@ class DatabaseContext(object):
|
|||
|
||||
Attributes:
|
||||
lag_tolerant_mode: This directs all replica traffic to batch replicas.
|
||||
This is done for applications that have a OLAP workload and also higher tolerance
|
||||
for replication lag.
|
||||
This is done for applications that have a OLAP workload and also higher
|
||||
tolerance for replication lag.
|
||||
vtgate_addrs: vtgate server endpoints
|
||||
master_access_disabled: Disallow master access for application running in non-master
|
||||
capable cells.
|
||||
master_access_disabled: Disallow master access for application running
|
||||
in non-master` capable cells.
|
||||
event_logger: Logs events and errors of note. Defaults to vtdb_logger.
|
||||
transaction_stack_depth: This allows nesting of transactions and makes
|
||||
commit rpc to VTGate when the outer-most commits.
|
||||
vtgate_connection: Connection to VTGate.
|
||||
"""
|
||||
|
||||
def __init__(self, vtgate_addrs=None, lag_tolerant_mode=False, master_access_disabled=False):
|
||||
def __init__(
|
||||
self, vtgate_addrs=None, lag_tolerant_mode=False,
|
||||
master_access_disabled=False):
|
||||
self.vtgate_addrs = vtgate_addrs
|
||||
self.lag_tolerant_mode = lag_tolerant_mode
|
||||
self.master_access_disabled = master_access_disabled
|
||||
|
@ -83,14 +84,20 @@ class DatabaseContext(object):
|
|||
|
||||
Transactions and some of the consistency guarantees rely on vtgate
|
||||
connections being sticky hence this class caches the connection.
|
||||
|
||||
Returns:
|
||||
A vtgate_connection.
|
||||
"""
|
||||
if self.vtgate_connection is not None and not self.vtgate_connection.is_closed():
|
||||
if (self.vtgate_connection is not None and
|
||||
not self.vtgate_connection.is_closed()):
|
||||
return self.vtgate_connection
|
||||
|
||||
#TODO: the connect method needs to be extended to include query n txn timeouts as well
|
||||
#FIXME: what is the best way of passing other params ?
|
||||
# TODO: the connect method needs to be extended to include query
|
||||
# n txn timeouts as well
|
||||
# FIXME: what is the best way of passing other params ?
|
||||
connect_method = get_vtgate_connect_method()
|
||||
self.vtgate_connection = connect_method(self.vtgate_addrs, self.connection_timeout)
|
||||
self.vtgate_connection = connect_method(
|
||||
self.vtgate_addrs, self.connection_timeout)
|
||||
return self.vtgate_connection
|
||||
|
||||
def degrade_master_read_to_replica(self):
|
||||
|
@ -174,19 +181,31 @@ class DBOperationBase(object):
|
|||
dc: database context object.
|
||||
writable: Indicates whether this is part of write transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, db_context):
|
||||
self.dc = db_context
|
||||
self.writable = False
|
||||
|
||||
def get_cursor(self, **cursor_kargs):
|
||||
"""This returns the create_cursor method of DatabaseContext with
|
||||
the writable attribute from the instance of DBOperationBase's
|
||||
derived classes."""
|
||||
return functools.partial(self.dc.create_cursor, self.writable, **cursor_kargs)
|
||||
"""This returns the create_cursor method of DatabaseContext.
|
||||
|
||||
DatabaseContext has the writable attribute from the instance of
|
||||
DBOperationBase's derived classes.
|
||||
|
||||
Args:
|
||||
**cursor_kargs: Arguments to be passed to create_cursor.
|
||||
|
||||
Returns:
|
||||
The create_cursor method with self.writable as first argument and
|
||||
**cursor_kargs passed through.
|
||||
"""
|
||||
return functools.partial(
|
||||
self.dc.create_cursor, self.writable, **cursor_kargs)
|
||||
|
||||
|
||||
class ReadFromMaster(DBOperationBase):
|
||||
"""Context Manager for reading from master."""
|
||||
|
||||
def __enter__(self):
|
||||
self.dc.read_from_master_setup()
|
||||
return self
|
||||
|
@ -203,6 +222,7 @@ class ReadFromMaster(DBOperationBase):
|
|||
|
||||
class ReadFromReplica(DBOperationBase):
|
||||
"""Context Manager for reading from lag-sensitive or lag-tolerant replica."""
|
||||
|
||||
def __enter__(self):
|
||||
self.dc.read_from_replica_setup()
|
||||
return self
|
||||
|
@ -219,6 +239,7 @@ class ReadFromReplica(DBOperationBase):
|
|||
|
||||
class WriteTransaction(DBOperationBase):
|
||||
"""Context Manager for write transactions."""
|
||||
|
||||
def __enter__(self):
|
||||
self.writable = True
|
||||
self.dc.write_transaction_setup()
|
||||
|
@ -293,6 +314,7 @@ def register_create_vtgate_connection_method(connect_method):
|
|||
global __vtgate_connect_method
|
||||
__vtgate_connect_method = connect_method
|
||||
|
||||
|
||||
def get_vtgate_connect_method():
|
||||
"""Returns the vtgate connection creation method."""
|
||||
global __vtgate_connect_method
|
||||
|
@ -301,6 +323,7 @@ def get_vtgate_connect_method():
|
|||
# The global object is for legacy application.
|
||||
__database_context = None
|
||||
|
||||
|
||||
def open_context(*pargs, **kargs):
|
||||
"""Returns the existing global database context or creates a new one."""
|
||||
global __database_context
|
||||
|
|
|
@ -1,26 +1,24 @@
|
|||
"""Module containing the base class for database classes and decorator for db method.
|
||||
"""Tthe base class for database classes and decorator for db method.
|
||||
|
||||
The base class DBObjectBase is the base class for all other database base classes.
|
||||
It has methods for common database operations like select, insert, update and delete.
|
||||
This module also contains the definition for ShardRouting which is used for determining
|
||||
the routing of a query during cursor creation.
|
||||
The module also has the db_class_method decorator and db_wrapper method which are
|
||||
used for cursor creation and calling the database method.
|
||||
The base class DBObjectBase is the base class for all other database
|
||||
base classes. It has methods for common database operations like
|
||||
select, insert, update and delete. This module also contains the
|
||||
definition for ShardRouting which is used for determining the routing
|
||||
of a query during cursor creation. The module also has the
|
||||
db_class_method decorator and db_wrapper method which are used for
|
||||
cursor creation and calling the database method.
|
||||
"""
|
||||
import functools
|
||||
import struct
|
||||
|
||||
from vtdb import database_context
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import db_validator
|
||||
from vtdb import keyrange
|
||||
from vtdb import keyrange_constants
|
||||
from vtdb import shard_constants
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import sql_builder
|
||||
from vtdb import vtgate_cursor
|
||||
|
||||
|
||||
class __EmptyBindVariables(frozenset):
|
||||
pass
|
||||
pass
|
||||
|
||||
EmptyBindVariables = __EmptyBindVariables()
|
||||
|
||||
|
||||
|
@ -35,6 +33,7 @@ class ShardRouting(object):
|
|||
entity_id_sharding_key_map: this map is used for in clause queries.
|
||||
shard_name: this is used to route queries for custom sharded keyspaces.
|
||||
"""
|
||||
|
||||
def __init__(self, keyspace):
|
||||
# keyspace of the table.
|
||||
self.keyspace = keyspace
|
||||
|
@ -51,9 +50,9 @@ def _is_iterable_container(x):
|
|||
return hasattr(x, '__iter__')
|
||||
|
||||
|
||||
INSERT_KW = "insert"
|
||||
UPDATE_KW = "update"
|
||||
DELETE_KW = "delete"
|
||||
INSERT_KW = 'insert'
|
||||
UPDATE_KW = 'update'
|
||||
DELETE_KW = 'delete'
|
||||
|
||||
|
||||
def is_dml(sql):
|
||||
|
@ -103,8 +102,7 @@ def create_cursor_from_old_cursor(old_cursor, table_class):
|
|||
|
||||
|
||||
def create_stream_cursor_from_cursor(original_cursor):
|
||||
"""
|
||||
This method creates streaming cursor from a regular cursor.
|
||||
"""Creates streaming cursor from a regular cursor.
|
||||
|
||||
Args:
|
||||
original_cursor: Cursor of VTGateCursor type
|
||||
|
@ -114,7 +112,7 @@ def create_stream_cursor_from_cursor(original_cursor):
|
|||
"""
|
||||
if not isinstance(original_cursor, vtgate_cursor.VTGateCursor):
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"Original cursor should be of VTGateCursor type.")
|
||||
'Original cursor should be of VTGateCursor type.')
|
||||
stream_cursor = vtgate_cursor.StreamVTGateCursor(
|
||||
original_cursor._conn, original_cursor.keyspace,
|
||||
original_cursor.tablet_type,
|
||||
|
@ -130,13 +128,14 @@ def create_batch_cursor_from_cursor(original_cursor, writable=False):
|
|||
|
||||
Args:
|
||||
original_cursor: Cursor of VTGateCursor type
|
||||
writable: Bool flag.
|
||||
|
||||
Returns:
|
||||
Returns BatchVTGateCursor that has same attributes as original_cursor.
|
||||
"""
|
||||
if not isinstance(original_cursor, vtgate_cursor.VTGateCursor):
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"Original cursor should be of VTGateCursor type.")
|
||||
'Original cursor should be of VTGateCursor type.')
|
||||
batch_cursor = vtgate_cursor.BatchVTGateCursor(
|
||||
original_cursor._conn,
|
||||
original_cursor.tablet_type,
|
||||
|
@ -145,8 +144,7 @@ def create_batch_cursor_from_cursor(original_cursor, writable=False):
|
|||
|
||||
|
||||
def db_wrapper(method, **decorator_kwargs):
|
||||
"""Decorator that is used to create the appropriate cursor
|
||||
for the table and call the database method with it.
|
||||
"""Decorator to call the database method with the appropriate cursor.
|
||||
|
||||
Args:
|
||||
method: Method to decorate.
|
||||
|
@ -158,7 +156,7 @@ def db_wrapper(method, **decorator_kwargs):
|
|||
@functools.wraps(method)
|
||||
def _db_wrapper(*pargs, **kwargs):
|
||||
table_class = pargs[0]
|
||||
write_method = kwargs.get("write_method", False)
|
||||
write_method = kwargs.get('write_method', False)
|
||||
if not issubclass(table_class, DBObjectBase):
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"table class '%s' is not inherited from DBObjectBase" % table_class)
|
||||
|
@ -171,10 +169,10 @@ def db_wrapper(method, **decorator_kwargs):
|
|||
if write_method:
|
||||
if not cursor.is_writable():
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"Executing dmls on a non-writable cursor is not allowed.")
|
||||
'Executing dmls on a non-writable cursor is not allowed.')
|
||||
if table_class.is_mysql_view:
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"writes disabled on view", table_class)
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'writes disabled on view', table_class)
|
||||
|
||||
if pargs[2:]:
|
||||
return method(table_class, cursor, *pargs[2:], **kwargs)
|
||||
|
@ -185,7 +183,7 @@ def db_wrapper(method, **decorator_kwargs):
|
|||
|
||||
def write_db_class_method(*pargs, **kwargs):
|
||||
"""Used for DML methods. Calls db_class_method."""
|
||||
kwargs["write_method"] = True
|
||||
kwargs['write_method'] = True
|
||||
return db_class_method(*pargs, **kwargs)
|
||||
|
||||
|
||||
|
@ -220,11 +218,9 @@ class DBObjectBase(object):
|
|||
is_mysql_view = False
|
||||
utf8_columns = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_shard_routing(class_, *pargs, **kwargs):
|
||||
"""This method is used to create ShardRouting object which is
|
||||
used for determining routing attributes for the vtgate cursor.
|
||||
"""Create ShardRouting that determines vtgate cursor routing attributes.
|
||||
|
||||
Returns:
|
||||
ShardRouting object.
|
||||
|
@ -232,24 +228,25 @@ class DBObjectBase(object):
|
|||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kwargs):
|
||||
"""This creates the VTGateCursor object which is used to make
|
||||
all the rpc calls to VTGate.
|
||||
def create_vtgate_cursor(
|
||||
class_, vtgate_conn, tablet_type, is_dml, **cursor_kwargs):
|
||||
"""Creates the VTGateCursor which is used to make RPCs to VTGate.
|
||||
|
||||
Args:
|
||||
vtgate_conn: connection to vtgate.
|
||||
tablet_type: tablet type to connect to.
|
||||
is_dml: Makes the cursor writable, enforces appropriate constraints.
|
||||
vtgate_conn: connection to vtgate.
|
||||
tablet_type: tablet type to connect to.
|
||||
is_dml: Makes the cursor writable, enforces appropriate constraints.
|
||||
**cursor_kwargs: More args.
|
||||
|
||||
Returns:
|
||||
VTGateCursor for the query.
|
||||
VTGateCursor for the query.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_column_value_pairs_for_write(class_, **column_values):
|
||||
invalid_columns = db_validator.invalid_utf8_columns(class_.utf8_columns or [],
|
||||
column_values)
|
||||
invalid_columns = db_validator.invalid_utf8_columns(
|
||||
class_.utf8_columns or [], column_values)
|
||||
if invalid_columns:
|
||||
exc = InvalidUtf8DbWrite(class_.table_name, invalid_columns)
|
||||
raise exc
|
||||
|
@ -296,23 +293,24 @@ class DBObjectBase(object):
|
|||
def create_select_query(class_, where_column_value_pairs, columns_list=None,
|
||||
order_by=None, group_by=None, limit=None):
|
||||
if class_.columns_list is None:
|
||||
raise dbexceptions.ProgrammingError("DB class should define columns_list")
|
||||
raise dbexceptions.ProgrammingError('DB class should define columns_list')
|
||||
|
||||
if columns_list is None:
|
||||
columns_list = class_.columns_list
|
||||
|
||||
query, bind_vars = sql_builder.select_by_columns_query(columns_list,
|
||||
class_.table_name,
|
||||
where_column_value_pairs,
|
||||
order_by=order_by,
|
||||
group_by=group_by,
|
||||
limit=limit)
|
||||
query, bind_vars = sql_builder.select_by_columns_query(
|
||||
columns_list,
|
||||
class_.table_name,
|
||||
where_column_value_pairs,
|
||||
order_by=order_by,
|
||||
group_by=group_by,
|
||||
limit=limit)
|
||||
return query, bind_vars
|
||||
|
||||
@write_db_class_method
|
||||
def insert(class_, cursor, **bind_vars):
|
||||
if class_.columns_list is None:
|
||||
raise dbexceptions.ProgrammingError("DB class should define columns_list")
|
||||
raise dbexceptions.ProgrammingError('DB class should define columns_list')
|
||||
|
||||
query, bind_vars = class_.create_insert_query(**bind_vars)
|
||||
cursor.execute(query, bind_vars)
|
||||
|
@ -330,19 +328,21 @@ class DBObjectBase(object):
|
|||
@write_db_class_method
|
||||
def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None):
|
||||
if not where_column_value_pairs:
|
||||
raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'deleting the whole table is not allowed')
|
||||
|
||||
query, bind_vars = class_.create_delete_query(where_column_value_pairs,
|
||||
limit=limit)
|
||||
cursor.execute(query, bind_vars)
|
||||
if cursor.rowcount == 0:
|
||||
raise dbexceptions.DatabaseError("DB Row not found")
|
||||
raise dbexceptions.DatabaseError('DB Row not found')
|
||||
return cursor.rowcount
|
||||
|
||||
@db_class_method
|
||||
def select_by_columns_streaming(class_, cursor, where_column_value_pairs,
|
||||
columns_list=None, order_by=None, group_by=None,
|
||||
limit=None, fetch_size=100):
|
||||
def select_by_columns_streaming(
|
||||
class_, cursor, where_column_value_pairs,
|
||||
columns_list=None, order_by=None, group_by=None,
|
||||
limit=None, fetch_size=100):
|
||||
query, bind_vars = class_.create_select_query(where_column_value_pairs,
|
||||
columns_list=columns_list,
|
||||
order_by=order_by,
|
||||
|
@ -382,7 +382,7 @@ class DBObjectBase(object):
|
|||
@db_class_method
|
||||
def get_min(class_, cursor):
|
||||
if class_.id_column_name is None:
|
||||
raise dbexceptions.ProgrammingError("id_column_name not set.")
|
||||
raise dbexceptions.ProgrammingError('id_column_name not set.')
|
||||
|
||||
query, bind_vars = sql_builder.build_aggregate_query(
|
||||
class_.table_name, class_.id_column_name, is_asc=True)
|
||||
|
@ -392,7 +392,7 @@ class DBObjectBase(object):
|
|||
@db_class_method
|
||||
def get_max(class_, cursor):
|
||||
if class_.id_column_name is None:
|
||||
raise dbexceptions.ProgrammingError("id_column_name not set.")
|
||||
raise dbexceptions.ProgrammingError('id_column_name not set.')
|
||||
|
||||
query, bind_vars = sql_builder.build_aggregate_query(
|
||||
class_.table_name, class_.id_column_name, is_asc=False)
|
||||
|
|
|
@ -13,9 +13,10 @@ from vtdb import vtgate_cursor
|
|||
class DBObjectCustomSharded(db_object.DBObjectBase):
|
||||
"""Base class for custom-sharded db classes.
|
||||
|
||||
This class is intended to support a custom sharding scheme, where the user
|
||||
controls the routing of their queries by passing in the shard_name
|
||||
explicitly.This provides helper methods for common database access operations.
|
||||
This class is intended to support a custom sharding scheme, where
|
||||
the user controls the routing of their queries by passing in the
|
||||
shard_name explicitly. This provides helper methods for common
|
||||
database access operations.
|
||||
"""
|
||||
keyspace = None
|
||||
sharding = shard_constants.CUSTOM_SHARDED
|
||||
|
@ -25,19 +26,21 @@ class DBObjectCustomSharded(db_object.DBObjectBase):
|
|||
|
||||
@classmethod
|
||||
def create_shard_routing(class_, *pargs, **kwargs):
|
||||
routing = db_object.ShardRouting(keyspace)
|
||||
routing.shard_name = kargs.get('shard_name')
|
||||
routing = db_object.ShardRouting(class_.keyspace)
|
||||
routing.shard_name = kwargs.get('shard_name')
|
||||
if routing.shard_name is None:
|
||||
dbexceptions.InternalError("For custom sharding, shard_name cannot be None.")
|
||||
dbexceptions.InternalError(
|
||||
'For custom sharding, shard_name cannot be None.')
|
||||
|
||||
if (_is_iterable_container(routing.shard_name)
|
||||
and is_dml):
|
||||
raise dbexceptions.InternalError(
|
||||
"Writes are not allowed on multiple shards.")
|
||||
'Writes are not allowed on multiple shards.')
|
||||
return routing
|
||||
|
||||
@classmethod
|
||||
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
|
||||
def create_vtgate_cursor(
|
||||
class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
|
||||
# FIXME:extend VTGateCursor's api to accept shard_names
|
||||
# and allow queries based on that.
|
||||
routing = class_.create_shard_routing(**cursor_kargs)
|
||||
|
|
|
@ -5,26 +5,17 @@ relevant methods. LookupDBObject inherits from DBObjectUnsharded and
|
|||
extends the functionality for getting, creating, updating and deleting
|
||||
the lookup relationship.
|
||||
"""
|
||||
import functools
|
||||
import struct
|
||||
|
||||
from vtdb import db_object
|
||||
from vtdb import db_object_unsharded
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import keyrange
|
||||
from vtdb import keyrange_constants
|
||||
from vtdb import shard_constants
|
||||
from vtdb import vtgate_cursor
|
||||
|
||||
|
||||
class LookupDBObject(db_object_unsharded.DBObjectUnsharded):
|
||||
"""This is an example implementation of lookup class where it is stored
|
||||
in unsharded db.
|
||||
"""
|
||||
"""A example implementation of a lookup class stored in an unsharded db."""
|
||||
|
||||
@classmethod
|
||||
def get(class_, cursor, entity_id_column, entity_id):
|
||||
where_column_value_pairs = [(entity_id_column, entity_id),]
|
||||
rows = class_.select_by_columns(cursor, where_column_value_pairs)
|
||||
rows = class_.select_by_columns(cursor, where_column_value_pairs)
|
||||
return [row.__dict__ for row in rows]
|
||||
|
||||
@classmethod
|
||||
|
@ -35,7 +26,7 @@ class LookupDBObject(db_object_unsharded.DBObjectUnsharded):
|
|||
def update(class_, cursor, sharding_key_column_name, sharding_key,
|
||||
entity_id_column, new_entity_id):
|
||||
where_column_value_pairs = [(sharding_key_column_name, sharding_key),]
|
||||
update_column_value_pairs = [(entity_id_column,new_entity_id),]
|
||||
update_column_value_pairs = [(entity_id_column, new_entity_id),]
|
||||
return class_.update_columns(cursor, where_column_value_pairs,
|
||||
update_column_value_pairs)
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
"""Module containing base classes for range-sharded database objects.
|
||||
|
||||
There are two base classes for tables that live in range-sharded keyspace -
|
||||
1. DBObjectRangeSharded - This should be used for tables that only reference lookup entities
|
||||
but don't create or manage them. Please see examples in test/clientlib_tests/db_class_sharded.py.
|
||||
2. DBObjectEntityRangeSharded - This inherits from DBObjectRangeSharded and is used for tables
|
||||
and also create new lookup relationships.
|
||||
This module also contains helper methods for cursor creation for accessing lookup tables
|
||||
and methods for dml and select for the above mentioned base classes.
|
||||
1. DBObjectRangeSharded - This should be used for tables that only reference
|
||||
lookup entities but don't create or manage them. Please see examples in
|
||||
test/clientlib_tests/db_class_sharded.py.
|
||||
2. DBObjectEntityRangeSharded - This inherits from DBObjectRangeSharded and
|
||||
is used for tables and also create new lookup relationships.
|
||||
This module also contains helper methods for cursor creation for
|
||||
accessing lookup tables and methods for dml and select for the above
|
||||
mentioned base classes.
|
||||
"""
|
||||
import functools
|
||||
import struct
|
||||
|
@ -14,7 +16,6 @@ import struct
|
|||
from vtdb import db_object
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import keyrange
|
||||
from vtdb import keyrange_constants
|
||||
from vtdb import shard_constants
|
||||
from vtdb import sql_builder
|
||||
from vtdb import vtgate_cursor
|
||||
|
@ -29,7 +30,7 @@ pack_keyspace_id = struct.Struct('!Q').pack
|
|||
# This unpacks the keyspace_id so that it can be used
|
||||
# in bind variables.
|
||||
def unpack_keyspace_id(kid):
|
||||
return struct.Struct('!Q').unpack(kid)[0]
|
||||
return struct.Struct('!Q').unpack(kid)[0]
|
||||
|
||||
|
||||
class DBObjectRangeSharded(db_object.DBObjectBase):
|
||||
|
@ -51,7 +52,7 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
# List of columns on the database table. This is used in query construction.
|
||||
columns_list = None
|
||||
|
||||
#FIXME: is this needed ?
|
||||
# FIXME: is this needed ?
|
||||
id_column_name = None
|
||||
|
||||
# sharding_key_column_name defines column name for sharding key for this
|
||||
|
@ -70,7 +71,7 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
column_lookup_name_map = None
|
||||
|
||||
@classmethod
|
||||
def create_shard_routing(class_, *pargs, **kargs):
|
||||
def create_shard_routing(class_, *pargs, **kargs):
|
||||
"""This creates the ShardRouting object based on the kargs.
|
||||
This prunes the routing kargs so as not to interfere with the
|
||||
actual database method.
|
||||
|
@ -79,9 +80,10 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
*pargs: Positional arguments
|
||||
**kargs: Routing key-value params. These are used to determine routing.
|
||||
There are two mutually exclusive mechanisms to indicate routing.
|
||||
1. entity_id_map {"entity_id_column": entity_id_value} where entity_id_column
|
||||
could be the sharding key or a lookup based entity column of this table. This
|
||||
helps determine the keyspace_ids for the cursor.
|
||||
1. entity_id_map {"entity_id_column": entity_id_value} where
|
||||
entity_id_column could be the sharding key or a lookup based entity
|
||||
column of this table. This helps determine the keyspace_ids for the
|
||||
cursor.
|
||||
2. keyrange - This helps determine the keyrange for the cursor.
|
||||
|
||||
Returns:
|
||||
|
@ -91,10 +93,10 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
routing = db_object.ShardRouting(class_.keyspace)
|
||||
entity_id_map = None
|
||||
|
||||
entity_id_map = kargs.get("entity_id_map", None)
|
||||
entity_id_map = kargs.get('entity_id_map', None)
|
||||
if entity_id_map is None:
|
||||
kr = None
|
||||
key_range = kargs.get("keyrange", None)
|
||||
key_range = kargs.get('keyrange', None)
|
||||
if isinstance(key_range, keyrange.KeyRange):
|
||||
kr = key_range
|
||||
else:
|
||||
|
@ -106,36 +108,42 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
|
||||
# entity_id_map is not None
|
||||
if len(entity_id_map) != 1:
|
||||
dbexceptions.ProgrammingError("Invalid entity_id_map '%s'" % entity_id_map)
|
||||
dbexceptions.ProgrammingError(
|
||||
"Invalid entity_id_map '%s'" % entity_id_map)
|
||||
|
||||
entity_id_col = entity_id_map.keys()[0]
|
||||
entity_id = entity_id_map[entity_id_col]
|
||||
|
||||
#TODO: the current api means that if a table doesn't have the sharding key column name
|
||||
# then it cannot pass in sharding key for routing purposes. Will this cause
|
||||
# extra load on lookup db/cache ? This is cleaner from a design perspective.
|
||||
#TODO: the current api means that if a table doesn't have the
|
||||
# sharding key column name then it cannot pass in sharding key for
|
||||
# routing purposes. Will this cause extra load on lookup db/cache
|
||||
# ? This is cleaner from a design perspective.
|
||||
if entity_id_col == class_.sharding_key_column_name:
|
||||
# Routing using sharding key.
|
||||
routing.sharding_key = entity_id
|
||||
if not class_.is_sharding_key_valid(routing.sharding_key):
|
||||
raise dbexceptions.InternalError("Invalid sharding_key %s" % routing.sharding_key)
|
||||
raise dbexceptions.InternalError(
|
||||
'Invalid sharding_key %s' % routing.sharding_key)
|
||||
else:
|
||||
# Routing using lookup based entity.
|
||||
routing.entity_column_name = entity_id_col
|
||||
routing.entity_id_sharding_key_map = class_.lookup_sharding_key_from_entity_id(
|
||||
lookup_cursor_method, entity_id_col, entity_id)
|
||||
routing.entity_id_sharding_key_map = (
|
||||
class_.lookup_sharding_key_from_entity_id(
|
||||
lookup_cursor_method, entity_id_col, entity_id))
|
||||
|
||||
return routing
|
||||
|
||||
@classmethod
|
||||
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
|
||||
def create_vtgate_cursor(
|
||||
class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
|
||||
cursor_method = functools.partial(db_object.create_cursor_from_params,
|
||||
vtgate_conn, tablet_type, False)
|
||||
routing = class_.create_shard_routing(cursor_method, **cursor_kargs)
|
||||
if is_dml:
|
||||
if routing.sharding_key is None or db_object._is_iterable_container(routing.sharding_key):
|
||||
if (routing.sharding_key is None or
|
||||
db_object._is_iterable_container(routing.sharding_key)):
|
||||
dbexceptions.InternalError(
|
||||
"Writes require unique sharding_key")
|
||||
'Writes require unique sharding_key')
|
||||
|
||||
keyspace_ids = None
|
||||
keyranges = None
|
||||
|
@ -151,7 +159,8 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
elif routing.entity_id_sharding_key_map is not None:
|
||||
keyspace_ids = []
|
||||
for sharding_key in routing.entity_id_sharding_key_map.values():
|
||||
keyspace_ids.append(pack_keyspace_id(class_.sharding_key_to_keyspace_id(sharding_key)))
|
||||
keyspace_ids.append(
|
||||
pack_keyspace_id(class_.sharding_key_to_keyspace_id(sharding_key)))
|
||||
elif routing.keyrange:
|
||||
keyranges = [routing.keyrange,]
|
||||
|
||||
|
@ -174,11 +183,14 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
return class_.column_lookup_name_map.get(column_name, column_name)
|
||||
|
||||
@classmethod
|
||||
def lookup_sharding_key_from_entity_id(class_, cursor_method, entity_id_column, entity_id):
|
||||
def lookup_sharding_key_from_entity_id(
|
||||
class_, cursor_method, entity_id_column, entity_id):
|
||||
"""This method is used to map any entity id to sharding key.
|
||||
|
||||
Args:
|
||||
entity_id_column: Non-sharding key indexes that can be used for query routing.
|
||||
cursor_method: Cursor method.
|
||||
entity_id_column: Non-sharding key indexes that can be used for query
|
||||
routing.
|
||||
entity_id: entity id value.
|
||||
|
||||
Returns:
|
||||
|
@ -189,20 +201,22 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
rows = lookup_class.get(cursor_method, entity_lookup_column, entity_id)
|
||||
|
||||
entity_id_sharding_key_map = {}
|
||||
if len(rows) == 0:
|
||||
#return entity_id_sharding_key_map
|
||||
raise dbexceptions.DatabaseError("LookupRow not found")
|
||||
if not rows:
|
||||
# return entity_id_sharding_key_map
|
||||
raise dbexceptions.DatabaseError('LookupRow not found')
|
||||
|
||||
if class_.sharding_key_column_name is not None:
|
||||
sk_lookup_column = class_.get_lookup_column_name(class_.sharding_key_column_name)
|
||||
sk_lookup_column = class_.get_lookup_column_name(
|
||||
class_.sharding_key_column_name)
|
||||
else:
|
||||
# This is needed since the table may not have a sharding key column name
|
||||
# but the lookup map will have it.
|
||||
lookup_column_names = rows[0].keys()
|
||||
if len(lookup_column_names) != 2:
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"lookup table has more than two columns.")
|
||||
sk_lookup_column = list(set(lookup_column_names) - set(list(entity_lookup_column)))[0]
|
||||
'lookup table has more than two columns.')
|
||||
sk_lookup_column = list(
|
||||
set(lookup_column_names) - set(list(entity_lookup_column)))[0]
|
||||
for row in rows:
|
||||
en_id = row[entity_lookup_column]
|
||||
sk = row[sk_lookup_column]
|
||||
|
@ -212,8 +226,8 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
|
||||
@db_object.db_class_method
|
||||
def select_by_ids(class_, cursor, where_column_value_pairs,
|
||||
columns_list=None, order_by=None, group_by=None,
|
||||
limit=None, **kwargs):
|
||||
columns_list=None, order_by=None, group_by=None,
|
||||
limit=None, **kwargs):
|
||||
"""This method is used to perform in-clause queries.
|
||||
|
||||
Such queries can cause vtgate to scatter over multiple shards.
|
||||
|
@ -238,17 +252,20 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
entity_col_name = class_.sharding_key_column_name
|
||||
if db_object._is_iterable_container(cursor.routing.sharding_key):
|
||||
for sk in list(cursor.routing.sharding_key):
|
||||
entity_id_keyspace_id_map[sk] = pack_keyspace_id(class_.sharding_key_to_keyspace_id(sk))
|
||||
entity_id_keyspace_id_map[sk] = pack_keyspace_id(
|
||||
class_.sharding_key_to_keyspace_id(sk))
|
||||
else:
|
||||
sk = cursor.routing.sharding_key
|
||||
entity_id_keyspace_id_map[sk] = pack_keyspace_id(class_.sharding_key_to_keyspace_id(sk))
|
||||
entity_id_keyspace_id_map[sk] = pack_keyspace_id(
|
||||
class_.sharding_key_to_keyspace_id(sk))
|
||||
elif cursor.routing.entity_id_sharding_key_map is not None:
|
||||
# If the in-clause is based on entity column
|
||||
entity_col_name = cursor.routing.entity_column_name
|
||||
for en_id, sk in cursor.routing.entity_id_sharding_key_map.iteritems():
|
||||
entity_id_keyspace_id_map[en_id] = pack_keyspace_id(class_.sharding_key_to_keyspace_id(sk))
|
||||
entity_id_keyspace_id_map[en_id] = pack_keyspace_id(
|
||||
class_.sharding_key_to_keyspace_id(sk))
|
||||
else:
|
||||
dbexceptions.ProgrammingError("Invalid routing method used.")
|
||||
dbexceptions.ProgrammingError('Invalid routing method used.')
|
||||
|
||||
# cursor.routing.entity_column_name is set while creating shard routing.
|
||||
rowcount = cursor.execute_entity_ids(query, bind_vars,
|
||||
|
@ -284,7 +301,7 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
@db_object.db_class_method
|
||||
def insert(class_, cursor, **bind_vars):
|
||||
if class_.columns_list is None:
|
||||
raise dbexceptions.ProgrammingError("DB class should define columns_list")
|
||||
raise dbexceptions.ProgrammingError('DB class should define columns_list')
|
||||
|
||||
keyspace_id = bind_vars.get('keyspace_id', None)
|
||||
if keyspace_id is None:
|
||||
|
@ -331,35 +348,36 @@ class DBObjectRangeSharded(db_object.DBObjectBase):
|
|||
def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None):
|
||||
|
||||
if not where_column_value_pairs:
|
||||
raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'deleting the whole table is not allowed')
|
||||
|
||||
where_column_value_pairs = class_._add_keyspace_id(
|
||||
unpack_keyspace_id(cursor.keyspace_ids[0]), where_column_value_pairs)
|
||||
|
||||
query, bind_vars = sql_builder.delete_by_columns_query(class_.table_name,
|
||||
where_column_value_pairs,
|
||||
limit=limit)
|
||||
query, bind_vars = sql_builder.delete_by_columns_query(
|
||||
class_.table_name, where_column_value_pairs, limit=limit)
|
||||
cursor.execute(query, bind_vars)
|
||||
if cursor.rowcount == 0:
|
||||
raise dbexceptions.DatabaseError("DB Row not found")
|
||||
raise dbexceptions.DatabaseError('DB Row not found')
|
||||
|
||||
return cursor.rowcount
|
||||
|
||||
|
||||
class DBObjectEntityRangeSharded(DBObjectRangeSharded):
|
||||
"""Base class for sharded tables that also needs to create and manage lookup
|
||||
entities.
|
||||
"""Base class for sharded tables that create and manage lookup entities.
|
||||
|
||||
This provides default implementation of routing helper methods, cursor
|
||||
creation and common database access operations.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_insert_id_from_lookup(class_, cursor_method, entity_id_col, **bind_vars):
|
||||
def get_insert_id_from_lookup(
|
||||
class_, cursor_method, entity_id_col, **bind_vars):
|
||||
"""This method is used to map any entity id to sharding key.
|
||||
|
||||
Args:
|
||||
entity_id_column: Non-sharding key indexes that can be used for query routing.
|
||||
entity_id_column: Non-sharding key indexes that can be used for query
|
||||
routing.
|
||||
entity_id: entity id value.
|
||||
|
||||
Returns:
|
||||
|
@ -376,18 +394,19 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
|
|||
@classmethod
|
||||
def delete_sharding_key_entity_id_lookup(class_, cursor_method,
|
||||
sharding_key):
|
||||
sharding_key_lookup_column = class_.get_lookup_column_name(class_.sharding_key_column_name)
|
||||
sharding_key_lookup_column = class_.get_lookup_column_name(
|
||||
class_.sharding_key_column_name)
|
||||
for lookup_class in class_.entity_id_lookup_map.values():
|
||||
lookup_class.delete(cursor_method,
|
||||
sharding_key_lookup_column,
|
||||
sharding_key)
|
||||
|
||||
|
||||
@classmethod
|
||||
def update_sharding_key_entity_id_lookup(class_, cursor_method,
|
||||
sharding_key, entity_id_column,
|
||||
new_entity_id):
|
||||
sharding_key_lookup_column = class_.get_lookup_column_name(class_.sharding_key_column_name)
|
||||
sharding_key_lookup_column = class_.get_lookup_column_name(
|
||||
class_.sharding_key_column_name)
|
||||
entity_id_lookup_column = class_.get_lookup_column_name(entity_id_column)
|
||||
lookup_class = class_.entity_id_lookup_map[entity_id_column]
|
||||
return lookup_class.update(cursor_method,
|
||||
|
@ -400,28 +419,29 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
|
|||
@db_object.write_db_class_method
|
||||
def insert_primary(class_, cursor, **bind_vars):
|
||||
if class_.columns_list is None:
|
||||
raise dbexceptions.ProgrammingError("DB class should define columns_list")
|
||||
raise dbexceptions.ProgrammingError('DB class should define columns_list')
|
||||
|
||||
query, bind_vars = class_.create_insert_query(**bind_vars)
|
||||
cursor.execute(query, bind_vars)
|
||||
return cursor.lastrowid
|
||||
|
||||
|
||||
@classmethod
|
||||
def insert(class_, cursor_method, **bind_vars):
|
||||
""" This method creates the lookup relationship as well as the insert
|
||||
in the primary table. The creation of the lookup entry also creates the
|
||||
primary key for the row in the primary table.
|
||||
"""Creates the lookup relationship and inserts in the primary table.
|
||||
|
||||
The lookup relationship is determined by class_.column_lookup_name_map and the bind
|
||||
variables passed in. There are two types of entities -
|
||||
1. Table for which the entity that is also the primary sharding key for this keyspace.
|
||||
2. Entity table that creates a new entity and needs to create a lookup between
|
||||
that entity and sharding key.
|
||||
The creation of the lookup entry also creates the primary key for
|
||||
the row in the primary table.
|
||||
|
||||
The lookup relationship is determined by class_.column_lookup_name_map
|
||||
and the bind variables passed in. There are two types of entities -
|
||||
1. Table for which the entity that is also the primary sharding key for
|
||||
this keyspace.
|
||||
2. Entity table that creates a new entity and needs to create a lookup
|
||||
between that entity and sharding key.
|
||||
"""
|
||||
if class_.sharding_key_column_name is None:
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"sharding_key_column_name empty for DBObjectEntityRangeSharded")
|
||||
'sharding_key_column_name empty for DBObjectEntityRangeSharded')
|
||||
|
||||
# Used for insert into class_.table_name
|
||||
new_inserted_key = None
|
||||
|
@ -431,7 +451,7 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
|
|||
if (not class_.entity_id_lookup_map
|
||||
or not isinstance(class_.entity_id_lookup_map, dict)):
|
||||
raise dbexceptions.ProgrammingError(
|
||||
"Invalid entity_id_lookup_map %s" % class_.entity_id_lookup_map)
|
||||
'Invalid entity_id_lookup_map %s' % class_.entity_id_lookup_map)
|
||||
entity_col = class_.entity_id_lookup_map.keys()[0]
|
||||
|
||||
# Create the lookup entry first
|
||||
|
@ -472,7 +492,7 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
|
|||
update_column_value_pairs):
|
||||
sharding_key = cursor.routing.sharding_key
|
||||
if sharding_key is None:
|
||||
raise dbexceptions.ProgrammingError("sharding_key cannot be empty")
|
||||
raise dbexceptions.ProgrammingError('sharding_key cannot be empty')
|
||||
|
||||
# update the primary table first.
|
||||
query, bind_vars = class_.create_update_query(
|
||||
|
@ -498,23 +518,26 @@ class DBObjectEntityRangeSharded(DBObjectRangeSharded):
|
|||
limit=None):
|
||||
sharding_key = cursor.routing.sharding_key
|
||||
if sharding_key is None:
|
||||
raise dbexceptions.ProgrammingError("sharding_key cannot be empty")
|
||||
raise dbexceptions.ProgrammingError('sharding_key cannot be empty')
|
||||
|
||||
if not where_column_value_pairs:
|
||||
raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'deleting the whole table is not allowed')
|
||||
|
||||
query, bind_vars = sql_builder.delete_by_columns_query(class_.table_name,
|
||||
where_column_value_pairs,
|
||||
limit=limit)
|
||||
query, bind_vars = sql_builder.delete_by_columns_query(
|
||||
class_.table_name,
|
||||
where_column_value_pairs,
|
||||
limit=limit)
|
||||
cursor.execute(query, bind_vars)
|
||||
if cursor.rowcount == 0:
|
||||
raise dbexceptions.DatabaseError("DB Row not found")
|
||||
raise dbexceptions.DatabaseError('DB Row not found')
|
||||
|
||||
rowcount = cursor.rowcount
|
||||
|
||||
#delete the lookup map.
|
||||
# delete the lookup map.
|
||||
lookup_cursor_method = functools.partial(
|
||||
db_object.create_cursor_from_old_cursor, cursor)
|
||||
class_.delete_sharding_key_entity_id_lookup(lookup_cursor_method, sharding_key)
|
||||
class_.delete_sharding_key_entity_id_lookup(
|
||||
lookup_cursor_method, sharding_key)
|
||||
|
||||
return rowcount
|
||||
|
|
|
@ -4,9 +4,6 @@ DBObjectUnsharded inherits from DBObjectBase, the implementation
|
|||
for the common database operations is defined in DBObjectBase.
|
||||
DBObjectUnsharded defines the cursor creation methods for the same.
|
||||
"""
|
||||
import functools
|
||||
import struct
|
||||
|
||||
from vtdb import db_object
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import keyrange
|
||||
|
@ -31,11 +28,13 @@ class DBObjectUnsharded(db_object.DBObjectBase):
|
|||
@classmethod
|
||||
def create_shard_routing(class_, *pargs, **kwargs):
|
||||
routing = db_object.ShardRouting(class_.keyspace)
|
||||
routing.keyrange = keyrange.KeyRange(keyrange_constants.NON_PARTIAL_KEYRANGE)
|
||||
routing.keyrange = keyrange.KeyRange(
|
||||
keyrange_constants.NON_PARTIAL_KEYRANGE)
|
||||
return routing
|
||||
|
||||
@classmethod
|
||||
def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
|
||||
def create_vtgate_cursor(
|
||||
class_, vtgate_conn, tablet_type, is_dml, **cursor_kargs):
|
||||
routing = class_.create_shard_routing(**cursor_kargs)
|
||||
if routing.keyrange is not None:
|
||||
keyranges = [routing.keyrange,]
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from vtdb import dbexceptions
|
||||
|
||||
|
||||
# A simple class to trap and re-export only variables referenced from
|
||||
# the sql statement since bind dictionaries can be *very* noisy. This
|
||||
# is a by-product of converting the DB-API %(name)s syntax to our
|
||||
# :name syntax.
|
||||
class BindVarsProxy(object):
|
||||
|
||||
def __init__(self, bind_vars):
|
||||
self.bind_vars = bind_vars
|
||||
self.accessed_keys = set()
|
||||
|
@ -27,7 +29,7 @@ class BindVarsProxy(object):
|
|||
def prepare_query_bind_vars(query, bind_vars):
|
||||
bind_vars_proxy = BindVarsProxy(bind_vars)
|
||||
try:
|
||||
query = query % bind_vars_proxy
|
||||
query %= bind_vars_proxy
|
||||
except KeyError as e:
|
||||
raise dbexceptions.InterfaceError(e[0], query, bind_vars)
|
||||
|
||||
|
|
|
@ -1,41 +1,53 @@
|
|||
import exceptions
|
||||
|
||||
|
||||
class Error(exceptions.StandardError):
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(exceptions.StandardError):
|
||||
pass
|
||||
|
||||
|
||||
class DataError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class Warning(exceptions.StandardError):
|
||||
pass
|
||||
|
||||
|
||||
class InterfaceError(Error):
|
||||
pass
|
||||
|
||||
|
||||
class InternalError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class OperationalError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class ProgrammingError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class NotSupportedError(ProgrammingError):
|
||||
pass
|
||||
|
||||
|
||||
class IntegrityError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class PartialCommitError(IntegrityError):
|
||||
pass
|
||||
|
||||
|
||||
# Below errors are VT specific
|
||||
|
||||
|
||||
# Retry means a simple and immediate reconnect to the same host/port
|
||||
# will likely fix things. This is initiated by a graceful restart on
|
||||
# the server side. In general this can be handled transparently
|
||||
|
@ -51,8 +63,8 @@ class FatalError(OperationalError):
|
|||
pass
|
||||
|
||||
|
||||
# This failure is operational in the sense that we must teardown the connection to
|
||||
# ensure future RPCs are handled correctly.
|
||||
# This failure is operational in the sense that we must teardown the
|
||||
# connection to ensure future RPCs are handled correctly.
|
||||
class TimeoutError(OperationalError):
|
||||
pass
|
||||
|
||||
|
@ -68,4 +80,3 @@ class RequestBacklog(DatabaseError):
|
|||
# ThrottledError is raised when client exceeds allocated quota on the server
|
||||
class ThrottledError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
|
|
@ -32,14 +32,12 @@ VT_VAR_STRING = 253
|
|||
VT_STRING = 254
|
||||
VT_GEOMETRY = 255
|
||||
|
||||
# FIXME(msolomon) intended for MySQL emulation, but seems more dangerous
|
||||
# to keep this around. This doesn't seem to even be used right now.
|
||||
def Binary(x):
|
||||
return array('c', x)
|
||||
|
||||
class DBAPITypeObject:
|
||||
class DBAPITypeObject(object):
|
||||
|
||||
def __init__(self, *values):
|
||||
self.values = values
|
||||
|
||||
def __cmp__(self, other):
|
||||
if other in self.values:
|
||||
return 0
|
||||
|
@ -47,30 +45,34 @@ class DBAPITypeObject:
|
|||
|
||||
|
||||
# FIXME(msolomon) why do we have these values if they aren't referenced?
|
||||
STRING = DBAPITypeObject(VT_ENUM, VT_VAR_STRING, VT_STRING)
|
||||
BINARY = DBAPITypeObject(VT_TINY_BLOB, VT_MEDIUM_BLOB, VT_LONG_BLOB, VT_BLOB)
|
||||
NUMBER = DBAPITypeObject(VT_DECIMAL, VT_TINY, VT_SHORT, VT_LONG, VT_FLOAT, VT_DOUBLE, VT_LONGLONG, VT_INT24, VT_YEAR, VT_NEWDECIMAL)
|
||||
DATETIME = DBAPITypeObject(VT_TIMESTAMP, VT_DATE, VT_TIME, VT_DATETIME, VT_NEWDATE)
|
||||
ROWID = DBAPITypeObject()
|
||||
STRING = DBAPITypeObject(VT_ENUM, VT_VAR_STRING, VT_STRING)
|
||||
BINARY = DBAPITypeObject(VT_TINY_BLOB, VT_MEDIUM_BLOB, VT_LONG_BLOB, VT_BLOB)
|
||||
NUMBER = DBAPITypeObject(
|
||||
VT_DECIMAL, VT_TINY, VT_SHORT, VT_LONG, VT_FLOAT, VT_DOUBLE, VT_LONGLONG,
|
||||
VT_INT24, VT_YEAR, VT_NEWDECIMAL)
|
||||
DATETIME = DBAPITypeObject(
|
||||
VT_TIMESTAMP, VT_DATE, VT_TIME, VT_DATETIME, VT_NEWDATE)
|
||||
ROWID = DBAPITypeObject()
|
||||
|
||||
conversions = {
|
||||
VT_DECIMAL : Decimal,
|
||||
VT_TINY : int,
|
||||
VT_SHORT : int,
|
||||
VT_LONG : long,
|
||||
VT_FLOAT : float,
|
||||
VT_DOUBLE : float,
|
||||
VT_TIMESTAMP : times.DateTimeOrNone,
|
||||
VT_LONGLONG : long,
|
||||
VT_INT24 : int,
|
||||
VT_DATE : times.DateOrNone,
|
||||
VT_TIME : times.TimeDeltaOrNone,
|
||||
VT_DATETIME : times.DateTimeOrNone,
|
||||
VT_YEAR : int,
|
||||
VT_NEWDATE : times.DateOrNone,
|
||||
VT_NEWDECIMAL : Decimal,
|
||||
VT_DECIMAL: Decimal,
|
||||
VT_TINY: int,
|
||||
VT_SHORT: int,
|
||||
VT_LONG: long,
|
||||
VT_FLOAT: float,
|
||||
VT_DOUBLE: float,
|
||||
VT_TIMESTAMP: times.DateTimeOrNone,
|
||||
VT_LONGLONG: long,
|
||||
VT_INT24: int,
|
||||
VT_DATE: times.DateOrNone,
|
||||
VT_TIME: times.TimeDeltaOrNone,
|
||||
VT_DATETIME: times.DateTimeOrNone,
|
||||
VT_YEAR: int,
|
||||
VT_NEWDATE: times.DateOrNone,
|
||||
VT_NEWDECIMAL: Decimal,
|
||||
}
|
||||
|
||||
|
||||
# This is a temporary workaround till we figure out how to support
|
||||
# native lists in our API.
|
||||
class List(list):
|
||||
|
@ -82,6 +84,7 @@ NoneType = type(None)
|
|||
# That doesn't seem dramatically better than __sql_literal__ but it might
|
||||
# be move self-documenting.
|
||||
|
||||
|
||||
def convert_bind_vars(bind_variables):
|
||||
new_vars = {}
|
||||
if bind_variables is None:
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
|
||||
# Copyright 2015, Google Inc. All rights reserved.
|
||||
# Use of this source code is governed by a BSD-style license that can
|
||||
# be found in the LICENSE file.
|
||||
|
||||
from itertools import izip
|
||||
import logging
|
||||
|
||||
from net import gorpc
|
||||
from net import bsonrpc
|
||||
from net import gorpc
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import field_types
|
||||
from vtdb import update_stream
|
||||
|
@ -26,8 +26,8 @@ def _make_row(row, conversions):
|
|||
|
||||
|
||||
class GoRpcUpdateStreamConnection(update_stream.UpdateStreamConnection):
|
||||
"""GoRpcUpdateStreamConnection is the go rpc implementation of
|
||||
UpdateStreamConnection.
|
||||
"""The go rpc implementation of UpdateStreamConnection.
|
||||
|
||||
It is registered as 'gorpc' protocol.
|
||||
"""
|
||||
|
||||
|
@ -61,7 +61,7 @@ class GoRpcUpdateStreamConnection(update_stream.UpdateStreamConnection):
|
|||
"""Note this implementation doesn't honor the timeout."""
|
||||
try:
|
||||
self.client.stream_call('UpdateStream.ServeUpdateStream',
|
||||
{"Position": position})
|
||||
{'Position': position})
|
||||
while True:
|
||||
response = self.client.stream_next()
|
||||
if response is None:
|
||||
|
|
|
@ -3,16 +3,14 @@
|
|||
# be found in the LICENSE file.
|
||||
|
||||
from itertools import izip
|
||||
import logging
|
||||
from urlparse import urlparse
|
||||
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import field_types
|
||||
from vtdb import update_stream
|
||||
|
||||
from vtproto import binlogdata_pb2
|
||||
from vtproto import binlogservice_pb2
|
||||
from vtproto import replicationdata_pb2
|
||||
|
||||
from vtdb import field_types
|
||||
from vtdb import update_stream
|
||||
|
||||
|
||||
def _make_row(row, conversions):
|
||||
converted_row = []
|
||||
|
@ -28,8 +26,8 @@ def _make_row(row, conversions):
|
|||
|
||||
|
||||
class GRPCUpdateStreamConnection(update_stream.UpdateStreamConnection):
|
||||
"""GRPCUpdateStreamConnection is the gRPC implementation of
|
||||
UpdateStreamConnection.
|
||||
"""The gRPC implementation of UpdateStreamConnection.
|
||||
|
||||
It is registered as 'grpc' protocol.
|
||||
"""
|
||||
|
||||
|
@ -50,7 +48,7 @@ class GRPCUpdateStreamConnection(update_stream.UpdateStreamConnection):
|
|||
self.stub = None
|
||||
|
||||
def is_closed(self):
|
||||
return self.stub == None
|
||||
return self.stub is None
|
||||
|
||||
def stream_update(self, position, timeout=3600.0):
|
||||
req = binlogdata_pb2.StreamUpdateRequest(position=position)
|
||||
|
@ -72,13 +70,14 @@ class GRPCUpdateStreamConnection(update_stream.UpdateStreamConnection):
|
|||
rows.append(row)
|
||||
|
||||
try:
|
||||
yield update_stream.StreamEvent(category=int(stream_event.category),
|
||||
table_name=stream_event.table_name,
|
||||
fields=fields,
|
||||
rows=rows,
|
||||
sql=stream_event.sql,
|
||||
timestamp=stream_event.timestamp,
|
||||
transaction_id=stream_event.transaction_id)
|
||||
yield update_stream.StreamEvent(
|
||||
category=int(stream_event.category),
|
||||
table_name=stream_event.table_name,
|
||||
fields=fields,
|
||||
rows=rows,
|
||||
sql=stream_event.sql,
|
||||
timestamp=stream_event.timestamp,
|
||||
transaction_id=stream_event.transaction_id)
|
||||
except GeneratorExit:
|
||||
# if the loop is interrupted for any reason, we need to
|
||||
# cancel the iterator, so we close the RPC connection,
|
||||
|
|
|
@ -34,17 +34,20 @@ class KeyRange(codec.BSONCoding):
|
|||
else:
|
||||
kr = kr.split('-')
|
||||
if not isinstance(kr, tuple) and not isinstance(kr, list) or len(kr) != 2:
|
||||
raise dbexceptions.ProgrammingError("keyrange must be a list or tuple or a '-' separated str %s" % kr)
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'keyrange must be a list or tuple or a '-' separated str %s' % kr)
|
||||
self.Start = kr[0].strip().decode('hex')
|
||||
self.End = kr[1].strip().decode('hex')
|
||||
|
||||
def __str__(self):
|
||||
if self.Start == keyrange_constants.MIN_KEY and self.End == keyrange_constants.MAX_KEY:
|
||||
if (self.Start == keyrange_constants.MIN_KEY and
|
||||
self.End == keyrange_constants.MAX_KEY):
|
||||
return keyrange_constants.NON_PARTIAL_KEYRANGE
|
||||
return '%s-%s' % (self.Start.encode('hex'), self.End.encode('hex'))
|
||||
|
||||
def __repr__(self):
|
||||
if self.Start == keyrange_constants.MIN_KEY and self.End == keyrange_constants.MAX_KEY:
|
||||
if (self.Start == keyrange_constants.MIN_KEY and
|
||||
self.End == keyrange_constants.MAX_KEY):
|
||||
return 'KeyRange(%r)' % keyrange_constants.NON_PARTIAL_KEYRANGE
|
||||
return 'KeyRange(%r-%r)' % (self.Start, self.End)
|
||||
|
||||
|
@ -56,7 +59,7 @@ class KeyRange(codec.BSONCoding):
|
|||
{"Start": start, "End": end}
|
||||
"""
|
||||
|
||||
return {"Start": self.Start, "End": self.End}
|
||||
return {'Start': self.Start, 'End': self.End}
|
||||
|
||||
def bson_init(self, raw_values):
|
||||
"""Bson initialize the object with start and end dict.
|
||||
|
@ -64,5 +67,5 @@ class KeyRange(codec.BSONCoding):
|
|||
Args:
|
||||
raw_values: Dictionary of start and end values for keyrange.
|
||||
"""
|
||||
self.Start = raw_values["Start"]
|
||||
self.End = raw_values["End"]
|
||||
self.Start = raw_values['Start']
|
||||
self.End = raw_values['End']
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
|
||||
# This is the shard name for when the keyrange covers the entire space
|
||||
# for unsharded database.
|
||||
SHARD_ZERO = "0"
|
||||
SHARD_ZERO = '0'
|
||||
|
||||
# Keyrange that spans the entire space, used
|
||||
# for unsharded database.
|
||||
NON_PARTIAL_KEYRANGE = ""
|
||||
NON_PARTIAL_KEYRANGE = ''
|
||||
MIN_KEY = ''
|
||||
MAX_KEY = ''
|
||||
|
||||
KIT_UNSET = ""
|
||||
KIT_UINT64 = "uint64"
|
||||
KIT_BYTES = "bytes"
|
||||
KIT_UNSET = ''
|
||||
KIT_UINT64 = 'uint64'
|
||||
KIT_BYTES = 'bytes'
|
||||
|
|
|
@ -10,9 +10,12 @@ from vtdb import keyrange_constants
|
|||
|
||||
pack_keyspace_id = struct.Struct('!Q').pack
|
||||
|
||||
# Represent the SrvKeyspace object from the toposerver, and provide functions
|
||||
# to extract sharding information from the same.
|
||||
|
||||
class Keyspace(object):
|
||||
"""Represent the SrvKeyspace object from the toposerver.
|
||||
|
||||
Provide functions to extract sharding information from the same.
|
||||
"""
|
||||
name = None
|
||||
partitions = None
|
||||
sharding_col_name = None
|
||||
|
@ -23,8 +26,9 @@ class Keyspace(object):
|
|||
def __init__(self, name, data):
|
||||
self.name = name
|
||||
self.partitions = data.get('Partitions', {})
|
||||
self.sharding_col_name = data.get('ShardingColumnName', "")
|
||||
self.sharding_col_type = data.get('ShardingColumnType', keyrange_constants.KIT_UNSET)
|
||||
self.sharding_col_name = data.get('ShardingColumnName', '')
|
||||
self.sharding_col_type = data.get(
|
||||
'ShardingColumnType', keyrange_constants.KIT_UNSET)
|
||||
self.served_from = data.get('ServedFrom', None)
|
||||
|
||||
def get_shards(self, db_type):
|
||||
|
@ -51,6 +55,12 @@ class Keyspace(object):
|
|||
"""Finds the shard for a keyspace_id.
|
||||
|
||||
WARNING: this only works for KIT_UINT64 keyspace ids.
|
||||
|
||||
Returns:
|
||||
Shard name.
|
||||
|
||||
Raises:
|
||||
ValueError on invalid keyspace_id.
|
||||
"""
|
||||
if not keyspace_id:
|
||||
raise ValueError('keyspace_id is not set')
|
||||
|
@ -64,11 +74,12 @@ class Keyspace(object):
|
|||
shard['KeyRange']['Start'],
|
||||
shard['KeyRange']['End']):
|
||||
return shard['Name']
|
||||
raise ValueError('cannot find shard for keyspace_id %s in %s' % (keyspace_id, shards))
|
||||
raise ValueError(
|
||||
'cannot find shard for keyspace_id %s in %s' % (keyspace_id, shards))
|
||||
|
||||
|
||||
def _shard_contain_kid(pkid, start, end):
|
||||
return start <= pkid and (end == keyrange_constants.MAX_KEY or pkid < end)
|
||||
return start <= pkid and (end == keyrange_constants.MAX_KEY or pkid < end)
|
||||
|
||||
|
||||
def read_keyspace(topo_client, keyspace_name):
|
||||
|
|
|
@ -4,9 +4,9 @@ Different sharding schemes govern different routing strategies
|
|||
that are computed while create the correct cursor.
|
||||
"""
|
||||
|
||||
UNSHARDED = "UNSHARDED"
|
||||
RANGE_SHARDED = "RANGE"
|
||||
CUSTOM_SHARDED = "CUSTOM"
|
||||
UNSHARDED = 'UNSHARDED'
|
||||
RANGE_SHARDED = 'RANGE'
|
||||
CUSTOM_SHARDED = 'CUSTOM'
|
||||
|
||||
TABLET_TYPE_MASTER = 'master'
|
||||
TABLET_TYPE_REPLICA = 'replica'
|
||||
|
|
|
@ -13,7 +13,7 @@ from vtdb import field_types
|
|||
from vtdb import vtdb_logger
|
||||
|
||||
|
||||
_errno_pattern = re.compile('\(errno (\d+)\)')
|
||||
_errno_pattern = re.compile(r'\(errno (\d+)\)')
|
||||
|
||||
|
||||
def handle_app_error(exc_args):
|
||||
|
@ -60,10 +60,13 @@ def convert_exception(exc, *args):
|
|||
return exc
|
||||
|
||||
|
||||
# A simple, direct connection to the vttablet query server.
|
||||
# This is shard-unaware and only handles the most basic communication.
|
||||
# If something goes wrong, this object should be thrown away and a new one instantiated.
|
||||
class TabletConnection(object):
|
||||
"""A simple, direct connection to the vttablet query server.
|
||||
|
||||
This is shard-unaware and only handles the most basic communication.
|
||||
If something goes wrong, this object should be thrown away and a new
|
||||
one instantiated.
|
||||
"""
|
||||
transaction_id = 0
|
||||
session_id = 0
|
||||
_stream_fields = None
|
||||
|
@ -71,7 +74,9 @@ class TabletConnection(object):
|
|||
_stream_result = None
|
||||
_stream_result_index = None
|
||||
|
||||
def __init__(self, addr, tablet_type, keyspace, shard, timeout, user=None, password=None, keyfile=None, certfile=None):
|
||||
def __init__(
|
||||
self, addr, tablet_type, keyspace, shard, timeout, user=None,
|
||||
password=None, keyfile=None, certfile=None):
|
||||
self.addr = addr
|
||||
self.tablet_type = tablet_type
|
||||
self.keyspace = keyspace
|
||||
|
@ -82,7 +87,8 @@ class TabletConnection(object):
|
|||
self.logger_object = vtdb_logger.get_logger()
|
||||
|
||||
def __str__(self):
|
||||
return '<TabletConnection %s %s %s/%s>' % (self.addr, self.tablet_type, self.keyspace, self.shard)
|
||||
return '<TabletConnection %s %s %s/%s>' % (
|
||||
self.addr, self.tablet_type, self.keyspace, self.shard)
|
||||
|
||||
def dial(self):
|
||||
try:
|
||||
|
@ -92,11 +98,13 @@ class TabletConnection(object):
|
|||
# redial will succeed. This is more a hint that you are doing
|
||||
# it wrong and misunderstanding the life cycle of a
|
||||
# TabletConnection.
|
||||
#raise dbexceptions.ProgrammingError('attempting to reuse TabletConnection')
|
||||
# raise dbexceptions.ProgrammingError(
|
||||
# 'attempting to reuse TabletConnection')
|
||||
|
||||
self.client.dial()
|
||||
params = {'Keyspace': self.keyspace, 'Shard': self.shard}
|
||||
response = self.rpc_call_and_extract_error('SqlQuery.GetSessionId', params)
|
||||
response = self.rpc_call_and_extract_error(
|
||||
'SqlQuery.GetSessionId', params)
|
||||
self.session_id = response.reply['SessionId']
|
||||
except gorpc.GoRpcError as e:
|
||||
raise convert_exception(e, str(self))
|
||||
|
@ -163,11 +171,14 @@ class TabletConnection(object):
|
|||
raise convert_exception(e, str(self))
|
||||
|
||||
def rpc_call_and_extract_error(self, method_name, request):
|
||||
"""Makes an RPC call, and extracts any app error that's embedded in the reply.
|
||||
"""Makes an RPC, extracts any app error that's embedded in the reply.
|
||||
|
||||
Args:
|
||||
method_name - RPC method name, as a string, to call
|
||||
request - request to send to the RPC method call
|
||||
method_name: RPC method name, as a string, to call.
|
||||
request: Request to send to the RPC method call.
|
||||
|
||||
Returns:
|
||||
Response from RPC.
|
||||
|
||||
Raises:
|
||||
gorpc.AppError if there is an app error embedded in the reply
|
||||
|
@ -272,11 +283,13 @@ class TabletConnection(object):
|
|||
reply = first_response.reply
|
||||
if reply.get('Err'):
|
||||
self.__drain_conn_after_streaming_app_error()
|
||||
raise gorpc.AppError(reply['Err'].get('Message', 'Missing error message'))
|
||||
raise gorpc.AppError(reply['Err'].get(
|
||||
'Message', 'Missing error message'))
|
||||
|
||||
for field in reply['Fields']:
|
||||
self._stream_fields.append((field['Name'], field['Type']))
|
||||
self._stream_conversions.append(field_types.conversions.get(field['Type']))
|
||||
self._stream_conversions.append(
|
||||
field_types.conversions.get(field['Type']))
|
||||
except gorpc.GoRpcError as e:
|
||||
self.logger_object.log_private_data(bind_variables)
|
||||
raise convert_exception(e, str(self), sql)
|
||||
|
@ -305,11 +318,13 @@ class TabletConnection(object):
|
|||
reply = first_response.reply
|
||||
if reply.get('Err'):
|
||||
self.__drain_conn_after_streaming_app_error()
|
||||
raise gorpc.AppError(reply['Err'].get('Message', 'Missing error message'))
|
||||
raise gorpc.AppError(reply['Err'].get(
|
||||
'Message', 'Missing error message'))
|
||||
|
||||
for field in reply['Fields']:
|
||||
self._stream_fields.append((field['Name'], field['Type']))
|
||||
self._stream_conversions.append(field_types.conversions.get(field['Type']))
|
||||
self._stream_conversions.append(
|
||||
field_types.conversions.get(field['Type']))
|
||||
except gorpc.GoRpcError as e:
|
||||
self.logger_object.log_private_data(bind_variables)
|
||||
raise convert_exception(e, str(self), sql)
|
||||
|
@ -324,7 +339,7 @@ class TabletConnection(object):
|
|||
return None
|
||||
|
||||
# See if we need to read more or whether we just pop the next row.
|
||||
if self._stream_result is None :
|
||||
if self._stream_result is None:
|
||||
try:
|
||||
self._stream_result = self.client.stream_next()
|
||||
if self._stream_result is None:
|
||||
|
@ -332,14 +347,17 @@ class TabletConnection(object):
|
|||
return None
|
||||
if self._stream_result.reply.get('Err'):
|
||||
self.__drain_conn_after_streaming_app_error()
|
||||
raise gorpc.AppError(self._stream_result.reply['Err'].get('Message', 'Missing error message'))
|
||||
raise gorpc.AppError(self._stream_result.reply['Err'].get(
|
||||
'Message', 'Missing error message'))
|
||||
except gorpc.GoRpcError as e:
|
||||
raise convert_exception(e, str(self))
|
||||
except:
|
||||
logging.exception('gorpc low-level error')
|
||||
raise
|
||||
|
||||
row = tuple(_make_row(self._stream_result.reply['Rows'][self._stream_result_index], self._stream_conversions))
|
||||
row = tuple(
|
||||
_make_row(self._stream_result.reply['Rows'][self._stream_result_index],
|
||||
self._stream_conversions))
|
||||
# If we are reading the last row, set us up to read more data.
|
||||
self._stream_result_index += 1
|
||||
if self._stream_result_index == len(self._stream_result.reply['Rows']):
|
||||
|
@ -351,21 +369,25 @@ class TabletConnection(object):
|
|||
def __drain_conn_after_streaming_app_error(self):
|
||||
"""Drains the connection of all incoming streaming packets (ignoring them).
|
||||
|
||||
This is necessary for streaming calls which return application errors inside
|
||||
the RPC response (instead of through the usual GoRPC error return).
|
||||
This is because GoRPC always expects the last packet to be an error; either
|
||||
the usual GoRPC application error return, or a special "end-of-stream" error.
|
||||
This is necessary for streaming calls which return application
|
||||
errors inside the RPC response (instead of through the usual GoRPC
|
||||
error return). This is because GoRPC always expects the last
|
||||
packet to be an error; either the usual GoRPC application error
|
||||
return, or a special "end-of-stream" error.
|
||||
|
||||
If an application error is returned with the RPC response, there will still be
|
||||
at least one more packet coming, as GoRPC has not seen anything that it
|
||||
considers to be an error. If the connection is not drained of this last
|
||||
packet, future reads from the wire will be off by one and will return errors.
|
||||
If an application error is returned with the RPC response, there
|
||||
will still be at least one more packet coming, as GoRPC has not
|
||||
seen anything that it considers to be an error. If the connection
|
||||
is not drained of this last packet, future reads from the wire
|
||||
will be off by one and will return errors.
|
||||
"""
|
||||
next_result = self.client.stream_next()
|
||||
if next_result is not None:
|
||||
self.client.close()
|
||||
raise gorpc.GoRpcError("Connection should only have one packet remaining"
|
||||
" after streaming app error in RPC response.")
|
||||
raise gorpc.GoRpcError(
|
||||
'Connection should only have one packet remaining'
|
||||
' after streaming app error in RPC response.')
|
||||
|
||||
|
||||
def _make_row(row, conversions):
|
||||
converted_row = []
|
||||
|
|
|
@ -4,7 +4,10 @@
|
|||
#
|
||||
# Use Python datetime module to handle date and time columns.
|
||||
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from datetime import date
|
||||
from datetime import datetime
|
||||
from datetime import time
|
||||
from datetime import timedelta
|
||||
from math import modf
|
||||
from time import localtime
|
||||
|
||||
|
@ -17,18 +20,22 @@ Timestamp = datetime
|
|||
DateTimeDeltaType = timedelta
|
||||
DateTimeType = datetime
|
||||
|
||||
|
||||
# Convert UNIX ticks into a date instance.
|
||||
def DateFromTicks(ticks):
|
||||
return date(*localtime(ticks)[:3])
|
||||
|
||||
|
||||
# Convert UNIX ticks into a time instance.
|
||||
def TimeFromTicks(ticks):
|
||||
return time(*localtime(ticks)[3:6])
|
||||
|
||||
|
||||
# Convert UNIX ticks into a datetime instance.
|
||||
def TimestampFromTicks(ticks):
|
||||
return datetime(*localtime(ticks)[:6])
|
||||
|
||||
|
||||
def DateTimeOrNone(s):
|
||||
if ' ' in s:
|
||||
sep = ' '
|
||||
|
@ -39,34 +46,45 @@ def DateTimeOrNone(s):
|
|||
|
||||
try:
|
||||
d, t = s.split(sep, 1)
|
||||
return datetime(*[ int(x) for x in d.split('-')+t.split(':') ])
|
||||
except:
|
||||
return datetime(*[int(x) for x in d.split('-')+t.split(':')])
|
||||
except Exception:
|
||||
return DateOrNone(s)
|
||||
|
||||
|
||||
def TimeDeltaOrNone(s):
|
||||
try:
|
||||
h, m, s = s.split(':')
|
||||
td = timedelta(hours=int(h), minutes=int(m), seconds=int(float(s)), microseconds=int(modf(float(s))[0]*1000000))
|
||||
td = timedelta(
|
||||
hours=int(h), minutes=int(m), seconds=int(float(s)),
|
||||
microseconds=int(modf(float(s))[0]*1000000))
|
||||
if h < 0:
|
||||
return -td
|
||||
else:
|
||||
return td
|
||||
except:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def TimeOrNone(s):
|
||||
try:
|
||||
h, m, s = s.split(':')
|
||||
return time(hour=int(h), minute=int(m), second=int(float(s)), microsecond=int(modf(float(s))[0]*1000000))
|
||||
except:
|
||||
return time(
|
||||
hour=int(h), minute=int(m), second=int(float(s)),
|
||||
microsecond=int(modf(float(s))[0]*1000000))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def DateOrNone(s):
|
||||
try: return date(*[ int(x) for x in s.split('-',2)])
|
||||
except: return None
|
||||
try:
|
||||
return date(*[int(x) for x in s.split('-', 2)])
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def DateToString(d):
|
||||
return d.isoformat()
|
||||
|
||||
|
||||
def DateTimeToString(dt):
|
||||
return dt.isoformat(' ')
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
import random
|
||||
|
||||
from zk import zkocc
|
||||
from vtdb import topology
|
||||
from vtdb import vtdb_logger
|
||||
from zk import zkocc
|
||||
|
||||
|
||||
class VTConnParams(object):
|
||||
|
@ -29,9 +29,10 @@ class VTConnParams(object):
|
|||
self.password = password
|
||||
|
||||
|
||||
def get_db_params_for_tablet_conn(topo_client, keyspace_name, shard, db_type, timeout, user, password):
|
||||
def get_db_params_for_tablet_conn(
|
||||
topo_client, keyspace_name, shard, db_type, timeout, user, password):
|
||||
db_params_list = []
|
||||
db_key = "%s.%s.%s:vt" % (keyspace_name, shard, db_type)
|
||||
db_key = '%s.%s.%s:vt' % (keyspace_name, shard, db_type)
|
||||
# This will read the cached keyspace.
|
||||
keyspace_object = topology.get_keyspace(keyspace_name)
|
||||
|
||||
|
@ -44,17 +45,20 @@ def get_db_params_for_tablet_conn(topo_client, keyspace_name, shard, db_type, ti
|
|||
keyspace_name = new_keyspace
|
||||
|
||||
try:
|
||||
end_points_data = topo_client.get_end_points('local', keyspace_name, shard, db_type)
|
||||
end_points_data = topo_client.get_end_points(
|
||||
'local', keyspace_name, shard, db_type)
|
||||
except zkocc.ZkOccError as e:
|
||||
vtdb_logger.get_logger().topo_zkocc_error('do data', db_key, e)
|
||||
return []
|
||||
except Exception as e:
|
||||
vtdb_logger.get_logger().topo_exception('failed to get or parse topo data', db_key, e)
|
||||
vtdb_logger.get_logger().topo_exception(
|
||||
'failed to get or parse topo data', db_key, e)
|
||||
return []
|
||||
|
||||
host_port_list = []
|
||||
if 'Entries' not in end_points_data:
|
||||
vtdb_logger.get_logger().topo_exception('topo server returned: ' + str(end_points_data), db_key, e)
|
||||
vtdb_logger.get_logger().topo_exception(
|
||||
'topo server returned: ' + str(end_points_data), db_key, e)
|
||||
raise Exception('zkocc returned: %s' % str(end_points_data))
|
||||
for entry in end_points_data['Entries']:
|
||||
if 'vt' in entry['PortMap']:
|
||||
|
@ -63,6 +67,8 @@ def get_db_params_for_tablet_conn(topo_client, keyspace_name, shard, db_type, ti
|
|||
random.shuffle(host_port_list)
|
||||
|
||||
for host, port in host_port_list:
|
||||
vt_params = VTConnParams(keyspace_name, shard, db_type, "%s:%s" % (host, port), timeout, user, password).__dict__
|
||||
vt_params = VTConnParams(
|
||||
keyspace_name, shard, db_type, '%s:%s' %
|
||||
(host, port), timeout, user, password).__dict__
|
||||
db_params_list.append(vt_params)
|
||||
return db_params_list
|
||||
|
|
|
@ -25,7 +25,9 @@ from zk import zkocc
|
|||
|
||||
|
||||
# keeps a global version of the topology
|
||||
# This is a map of keyspace_name: (keyspace object, time when keyspace was last fetched)
|
||||
# This is a map of keyspace_name: (keyspace object, time when keyspace
|
||||
# was last fetched)
|
||||
#
|
||||
# eg - {'keyspace_name': (keyspace_object, time_of_last_fetch)}
|
||||
# keyspace object is defined at py/vtdb/keyspace.py:Keyspace
|
||||
__keyspace_map = {}
|
||||
|
@ -100,7 +102,7 @@ def read_topology(zkocc_client, read_fqdb_keys=True):
|
|||
db_keys = []
|
||||
keyspace_list = zkocc_client.get_srv_keyspace_names('local')
|
||||
# validate step
|
||||
if len(keyspace_list) == 0:
|
||||
if not keyspace_list:
|
||||
vtdb_logger.get_logger().topo_empty_keyspace_list()
|
||||
raise Exception('zkocc returned empty keyspace list')
|
||||
for keyspace_name in keyspace_list:
|
||||
|
@ -140,10 +142,12 @@ def get_host_port_by_name(topo_client, db_key):
|
|||
vtdb_logger.get_logger().topo_zkocc_error('do data', db_key, e)
|
||||
return []
|
||||
except Exception as e:
|
||||
vtdb_logger.get_logger().topo_exception('failed to get or parse topo data', db_key, e)
|
||||
vtdb_logger.get_logger().topo_exception(
|
||||
'failed to get or parse topo data', db_key, e)
|
||||
return []
|
||||
if 'Entries' not in data:
|
||||
vtdb_logger.get_logger().topo_exception('topo server returned: ' + str(data), db_key, e)
|
||||
vtdb_logger.get_logger().topo_exception(
|
||||
'topo server returned: ' + str(data), db_key, e)
|
||||
raise Exception('zkocc returned: %s' % str(data))
|
||||
|
||||
host_port_list = []
|
||||
|
|
|
@ -18,19 +18,21 @@ def register_conn_class(protocol, c):
|
|||
|
||||
|
||||
def connect(protocol, *pargs, **kargs):
|
||||
"""connect will return a dialed UpdateStreamConnection connection to
|
||||
an update stream server.
|
||||
"""Return a dialed UpdateStreamConnection to an update stream server.
|
||||
|
||||
Args:
|
||||
protocol: the registered protocol to use.
|
||||
args: passed to the registered protocol __init__ method.
|
||||
protocol: The registered protocol to use.
|
||||
*pargs: Passed to the registered protocol __init__ method.
|
||||
**kargs: Passed to the registered protocol __init__ method.
|
||||
|
||||
Returns:
|
||||
A dialed UpdateStreamConnection.
|
||||
|
||||
Raises:
|
||||
ValueError: On bad protocol.
|
||||
"""
|
||||
if not protocol in update_stream_conn_classes:
|
||||
raise Exception('Unknown update stream protocol', protocol)
|
||||
if protocol not in update_stream_conn_classes:
|
||||
raise ValueError('Unknown update stream protocol', protocol)
|
||||
conn = update_stream_conn_classes[protocol](*pargs, **kargs)
|
||||
conn.dial()
|
||||
return conn
|
||||
|
@ -59,11 +61,10 @@ class StreamEvent(object):
|
|||
|
||||
|
||||
class UpdateStreamConnection(object):
|
||||
"""UpdateStreamConnection is the interface for the update stream
|
||||
client implementations.
|
||||
All implementations must implement all these methods.
|
||||
If something goes wrong with the connection, this object will be thrown out.
|
||||
"""Te interface for the update stream client implementations.
|
||||
|
||||
All implementations must implement all these methods. If something
|
||||
goes wrong with the connection, this object will be thrown out.
|
||||
"""
|
||||
|
||||
def __init__(self, addr, timeout):
|
||||
|
|
|
@ -38,16 +38,17 @@ def reconnect(method):
|
|||
|
||||
if attempt >= self.max_attempts or self.in_txn:
|
||||
self.close()
|
||||
vtdb_logger.get_logger().vtclient_exception(self.keyspace, self.shard, self.db_type, e)
|
||||
vtdb_logger.get_logger().vtclient_exception(
|
||||
self.keyspace, self.shard, self.db_type, e)
|
||||
raise dbexceptions.FatalError(*e.args)
|
||||
if method.__name__ == 'begin':
|
||||
time.sleep(BEGIN_RECONNECT_DELAY)
|
||||
else:
|
||||
time.sleep(RECONNECT_DELAY)
|
||||
logging.info("Attempting to reconnect, %d", attempt)
|
||||
logging.info('Attempting to reconnect, %d', attempt)
|
||||
self.close()
|
||||
self.connect()
|
||||
logging.info("Successfully reconnected to %s", str(self.conn))
|
||||
logging.info('Successfully reconnected to %s', str(self.conn))
|
||||
return _run_with_reconnect
|
||||
|
||||
|
||||
|
@ -89,11 +90,12 @@ class VtOCCConnection(object):
|
|||
try:
|
||||
return self._connect()
|
||||
except dbexceptions.OperationalError as e:
|
||||
vtdb_logger.get_logger().vtclient_exception(self.keyspace, self.shard, self.db_type, e)
|
||||
vtdb_logger.get_logger().vtclient_exception(
|
||||
self.keyspace, self.shard, self.db_type, e)
|
||||
raise
|
||||
|
||||
def _connect(self):
|
||||
db_key = "%s.%s.%s" % (self.keyspace, self.shard, self.db_type)
|
||||
db_key = '%s.%s.%s' % (self.keyspace, self.shard, self.db_type)
|
||||
db_params_list = get_vt_connection_params_list(self.topo_client,
|
||||
self.keyspace,
|
||||
self.shard,
|
||||
|
@ -104,7 +106,8 @@ class VtOCCConnection(object):
|
|||
if not db_params_list:
|
||||
# no valid end-points were found, re-read the keyspace
|
||||
self.resolve_topology()
|
||||
raise dbexceptions.OperationalError("empty db params list - no db instance available for key %s" % db_key)
|
||||
raise dbexceptions.OperationalError(
|
||||
'empty db params list - no db instance available for key %s' % db_key)
|
||||
db_exception = None
|
||||
host_addr = None
|
||||
# no retries here, since there is a higher level retry with reconnect.
|
||||
|
@ -124,7 +127,7 @@ class VtOCCConnection(object):
|
|||
self.resolve_topology()
|
||||
|
||||
raise dbexceptions.OperationalError(
|
||||
'unable to create vt connection', db_key, host_addr, db_exception)
|
||||
'unable to create vt connection', db_key, host_addr, db_exception)
|
||||
|
||||
def cursor(self, cursorclass=None, **kargs):
|
||||
return (cursorclass or self.cursorclass)(self, **kargs)
|
||||
|
@ -160,12 +163,14 @@ class VtOCCConnection(object):
|
|||
sane_sql_list = []
|
||||
sane_bind_vars_list = []
|
||||
for sql, bind_variables in zip(sql_list, bind_variables_list):
|
||||
sane_sql, sane_bind_vars = dbapi.prepare_query_bind_vars(sql, bind_variables)
|
||||
sane_sql, sane_bind_vars = dbapi.prepare_query_bind_vars(
|
||||
sql, bind_variables)
|
||||
sane_sql_list.append(sane_sql)
|
||||
sane_bind_vars_list.append(sane_bind_vars)
|
||||
|
||||
try:
|
||||
result = self.conn._execute_batch(sane_sql_list, sane_bind_vars_list, as_transaction)
|
||||
result = self.conn._execute_batch(
|
||||
sane_sql_list, sane_bind_vars_list, as_transaction)
|
||||
except dbexceptions.IntegrityError as e:
|
||||
vtdb_logger.get_logger().integrity_error(e)
|
||||
raise
|
||||
|
|
|
@ -21,7 +21,8 @@ class VtdbLogger(object):
|
|||
|
||||
# topo_keyspace_fetch is called when we successfully get a SrvKeyspace object.
|
||||
def topo_keyspace_fetch(self, keyspace_name, topo_rtt):
|
||||
logging.info("Fetched keyspace %s from topo_client in %f secs", keyspace_name, topo_rtt)
|
||||
logging.info('Fetched keyspace %s from topo_client in %f secs',
|
||||
keyspace_name, topo_rtt)
|
||||
|
||||
# topo_empty_keyspace_list is called when we get an empty list of
|
||||
# keyspaces from topo server.
|
||||
|
@ -32,7 +33,7 @@ class VtdbLogger(object):
|
|||
# when reading a keyspace. This is within an exception handler.
|
||||
def topo_bad_keyspace_data(self, keyspace_name):
|
||||
logging.exception('error getting or parsing keyspace data for %s',
|
||||
keyspace_name)
|
||||
keyspace_name)
|
||||
|
||||
# topo_zkocc_error is called whenever we get a zkocc.ZkOccError
|
||||
# when trying to resolve an endpoint.
|
||||
|
@ -72,7 +73,7 @@ class VtdbLogger(object):
|
|||
logging.warning('vtgatev2_exception: %s', e)
|
||||
|
||||
def log_private_data(self, private_data):
|
||||
logging.info("Additional exception data %s", private_data)
|
||||
logging.info('Additional exception data %s', private_data)
|
||||
|
||||
|
||||
# registration mechanism for VtdbLogger
|
||||
|
|
|
@ -3,14 +3,18 @@
|
|||
# be found in the LICENSE file.
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
import re
|
||||
|
||||
from vtdb import cursor
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import keyrange_constants
|
||||
|
||||
|
||||
write_sql_pattern = re.compile('\s*(insert|update|delete)', re.IGNORECASE)
|
||||
write_sql_pattern = re.compile(r'\s*(insert|update|delete)', re.IGNORECASE)
|
||||
|
||||
|
||||
def ascii_lower(string):
|
||||
"""Lower-case, but only in the ASCII range."""
|
||||
return string.encode('utf8').lower().decode('utf8')
|
||||
|
||||
|
||||
class VTGateCursor(object):
|
||||
|
@ -28,7 +32,9 @@ class VTGateCursor(object):
|
|||
_writable = None
|
||||
routing = None
|
||||
|
||||
def __init__(self, connection, keyspace, tablet_type, keyspace_ids=None, keyranges=None, writable=False):
|
||||
def __init__(
|
||||
self, connection, keyspace, tablet_type, keyspace_ids=None,
|
||||
keyranges=None, writable=False):
|
||||
self._conn = connection
|
||||
self.keyspace = keyspace
|
||||
self.tablet_type = tablet_type
|
||||
|
@ -74,20 +80,22 @@ class VTGateCursor(object):
|
|||
return
|
||||
|
||||
write_query = bool(write_sql_pattern.match(sql))
|
||||
# NOTE: This check may also be done at high-layers but adding it here for completion.
|
||||
# NOTE: This check may also be done at high-layers but adding it
|
||||
# here for completion.
|
||||
if write_query:
|
||||
if not self.is_writable():
|
||||
raise dbexceptions.DatabaseError('DML on a non-writable cursor', sql)
|
||||
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges,
|
||||
not_in_transaction=(not self.is_writable()),
|
||||
effective_caller_id=effective_caller_id)
|
||||
self.results, self.rowcount, self.lastrowid, self.description = (
|
||||
self._conn._execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges,
|
||||
not_in_transaction=(not self.is_writable()),
|
||||
effective_caller_id=effective_caller_id))
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -102,21 +110,22 @@ class VTGateCursor(object):
|
|||
# This is by definition a scatter query, so raise exception.
|
||||
write_query = bool(write_sql_pattern.match(sql))
|
||||
if write_query:
|
||||
raise dbexceptions.DatabaseError('execute_entity_ids is not allowed for write queries')
|
||||
raise dbexceptions.DatabaseError(
|
||||
'execute_entity_ids is not allowed for write queries')
|
||||
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute_entity_ids(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
entity_keyspace_id_map,
|
||||
entity_column_name,
|
||||
not_in_transaction=(not self.is_writable()),
|
||||
effective_caller_id=effective_caller_id)
|
||||
self.results, self.rowcount, self.lastrowid, self.description = (
|
||||
self._conn._execute_entity_ids(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
entity_keyspace_id_map,
|
||||
entity_column_name,
|
||||
not_in_transaction=(not self.is_writable()),
|
||||
effective_caller_id=effective_caller_id))
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
||||
def fetchone(self):
|
||||
if self.results is None:
|
||||
raise dbexceptions.ProgrammingError('fetch called before execute')
|
||||
|
@ -159,7 +168,8 @@ class VTGateCursor(object):
|
|||
# sort the rows and then trim off the prepended sort columns
|
||||
|
||||
if sort_columns:
|
||||
sorted_rows = list(sort_row_list_by_columns(self.fetchall(), sort_columns, desc_columns))[:limit]
|
||||
sorted_rows = list(sort_row_list_by_columns(
|
||||
self.fetchall(), sort_columns, desc_columns))[:limit]
|
||||
else:
|
||||
sorted_rows = itertools.islice(self.fetchall(), limit)
|
||||
neutered_rows = [row[len(order_by_columns):] for row in sorted_rows]
|
||||
|
@ -169,6 +179,7 @@ class VTGateCursor(object):
|
|||
raise dbexceptions.NotSupportedError
|
||||
|
||||
def executemany(self, *pargs):
|
||||
_ = pargs
|
||||
raise dbexceptions.NotSupportedError
|
||||
|
||||
def nextset(self):
|
||||
|
@ -203,6 +214,7 @@ class BatchVTGateCursor(VTGateCursor):
|
|||
This only supports keyspace_ids right now since that is what
|
||||
the underlying vtgate server supports.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, tablet_type, writable=False):
|
||||
# rowset is [(results, rowcount, lastrowid, fields),]
|
||||
self.rowsets = None
|
||||
|
@ -210,7 +222,7 @@ class BatchVTGateCursor(VTGateCursor):
|
|||
self.bind_vars_list = []
|
||||
self.keyspace_list = []
|
||||
self.keyspace_ids_list = []
|
||||
VTGateCursor.__init__(self, connection, "", tablet_type, writable=writable)
|
||||
VTGateCursor.__init__(self, connection, '', tablet_type, writable=writable)
|
||||
|
||||
def execute(self, sql, bind_variables, keyspace, keyspace_ids):
|
||||
self.query_list.append(sql)
|
||||
|
@ -240,8 +252,12 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
index = None
|
||||
fetchmany_done = False
|
||||
|
||||
def __init__(self, connection, keyspace, tablet_type, keyspace_ids=None, keyranges=None, writable=False):
|
||||
VTGateCursor.__init__(self, connection, keyspace, tablet_type, keyspace_ids=keyspace_ids, keyranges=keyranges)
|
||||
def __init__(
|
||||
self, connection, keyspace, tablet_type, keyspace_ids=None,
|
||||
keyranges=None, writable=False):
|
||||
VTGateCursor.__init__(
|
||||
self, connection, keyspace, tablet_type, keyspace_ids=keyspace_ids,
|
||||
keyranges=keyranges)
|
||||
|
||||
# pass kargs here in case higher level APIs need to push more data through
|
||||
# for instance, a key value for shard mapping
|
||||
|
@ -250,7 +266,7 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
raise dbexceptions.ProgrammingError('Streaming query cannot be writable')
|
||||
|
||||
self.description = None
|
||||
x, y, z, self.description = self._conn._stream_execute(
|
||||
_, _, _, self.description = self._conn._stream_execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
|
@ -279,7 +295,7 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
if self.fetchmany_done:
|
||||
self.fetchmany_done = False
|
||||
return result
|
||||
for i in xrange(size):
|
||||
for _ in xrange(size):
|
||||
row = self.fetchone()
|
||||
if row is None:
|
||||
self.fetchmany_done = True
|
||||
|
@ -327,7 +343,8 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
|
||||
# assumes the leading columns are used for sorting
|
||||
def sort_row_list_by_columns(row_list, sort_columns=(), desc_columns=()):
|
||||
for column_index, column_name in reversed([x for x in enumerate(sort_columns)]):
|
||||
for column_index, column_name in reversed(
|
||||
[x for x in enumerate(sort_columns)]):
|
||||
og = operator.itemgetter(column_index)
|
||||
if type(row_list) != list:
|
||||
row_list = sorted(
|
||||
|
|
|
@ -36,16 +36,21 @@ def exponential_backoff_retry(
|
|||
num_retries=NUM_RETRIES,
|
||||
backoff_multiplier=BACKOFF_MULTIPLIER,
|
||||
max_delay_ms=MAX_DELAY_MS):
|
||||
"""decorator for exponential backoff retry
|
||||
"""Decorator for exponential backoff retry.
|
||||
|
||||
Log and raise exception if unsuccessful
|
||||
Do not retry while in a session
|
||||
Log and raise exception if unsuccessful.
|
||||
Do not retry while in a session.
|
||||
|
||||
retry_exceptions: tuple of exceptions to check
|
||||
initial_delay_ms: initial delay between retries in ms
|
||||
num_retries: number max number of retries
|
||||
backoff_multipler: multiplier for each retry e.g. 2 will double the retry delay
|
||||
max_delay_ms: upper bound on retry delay
|
||||
Args:
|
||||
retry_exceptions: tuple of exceptions to check.
|
||||
initial_delay_ms: initial delay between retries in ms.
|
||||
num_retries: number max number of retries.
|
||||
backoff_multiplier: multiplier for each retry e.g. 2 will double the
|
||||
retry delay.
|
||||
max_delay_ms: upper bound on retry delay.
|
||||
|
||||
Returns:
|
||||
A decorator method that returns wrapped method.
|
||||
"""
|
||||
def decorator(method):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
|
@ -62,7 +67,9 @@ def exponential_backoff_retry(
|
|||
# and tablet_type from exception.
|
||||
log_exception(e)
|
||||
raise e
|
||||
logging.error("retryable error: %s, retrying in %d ms, attempt %d of %d", e, delay, attempt, num_retries)
|
||||
logging.error(
|
||||
"retryable error: %s, retrying in %d ms, attempt %d of %d", e,
|
||||
delay, attempt, num_retries)
|
||||
time.sleep(delay/1000.0)
|
||||
delay *= backoff_multiplier
|
||||
delay = min(max_delay_ms, delay)
|
||||
|
|
|
@ -12,14 +12,13 @@ from net import gorpc
|
|||
from vtdb import dbapi
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import field_types
|
||||
from vtdb import keyrange
|
||||
from vtdb import keyspace
|
||||
from vtdb import vtdb_logger
|
||||
from vtdb import vtgate_client
|
||||
from vtdb import vtgate_cursor
|
||||
from vtdb import vtgate_utils
|
||||
|
||||
_errno_pattern = re.compile('\(errno (\d+)\)')
|
||||
_errno_pattern = re.compile(r'\(errno (\d+)\)')
|
||||
|
||||
|
||||
def handle_app_error(exc_args):
|
||||
|
@ -41,6 +40,7 @@ def handle_app_error(exc_args):
|
|||
|
||||
def convert_exception(exc, *args, **kwargs):
|
||||
"""This parses the protocol exceptions to the api interface exceptions.
|
||||
|
||||
This also logs the exception and increments the appropriate error counters.
|
||||
|
||||
Args:
|
||||
|
@ -66,48 +66,53 @@ def convert_exception(exc, *args, **kwargs):
|
|||
elif isinstance(exc, gorpc.GoRpcError):
|
||||
new_exc = dbexceptions.FatalError(new_args)
|
||||
|
||||
keyspace = kwargs.get("keyspace", None)
|
||||
tablet_type = kwargs.get("tablet_type", None)
|
||||
keyspace_name = kwargs.get('keyspace', None)
|
||||
tablet_type = kwargs.get('tablet_type', None)
|
||||
|
||||
vtgate_utils.log_exception(new_exc, keyspace=keyspace,
|
||||
vtgate_utils.log_exception(new_exc, keyspace=keyspace_name,
|
||||
tablet_type=tablet_type)
|
||||
return new_exc
|
||||
|
||||
|
||||
def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspace_ids, not_in_transaction):
|
||||
def _create_req_with_keyspace_ids(
|
||||
sql, new_binds, keyspace, tablet_type, keyspace_ids, not_in_transaction):
|
||||
# keyspace_ids are Keyspace Ids packed to byte[]
|
||||
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
req = {
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyspaceIds': keyspace_ids,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyspaceIds': keyspace_ids,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
return req
|
||||
|
||||
|
||||
def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges, not_in_transaction):
|
||||
def _create_req_with_keyranges(
|
||||
sql, new_binds, keyspace, tablet_type, keyranges, not_in_transaction):
|
||||
# keyranges are keyspace.KeyRange objects with start/end packed to byte[]
|
||||
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
req = {
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyRanges': keyranges,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyRanges': keyranges,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
return req
|
||||
|
||||
|
||||
# A simple, direct connection to the vttablet query server.
|
||||
# This is shard-unaware and only handles the most basic communication.
|
||||
# If something goes wrong, this object should be thrown away and a new one instantiated.
|
||||
class VTGateConnection(vtgate_client.VTGateClient):
|
||||
"""A simple, direct connection to the vttablet query server.
|
||||
|
||||
This is shard-unaware and only handles the most basic communication.
|
||||
If something goes wrong, this object should be thrown away and a new
|
||||
one instantiated.
|
||||
"""
|
||||
session = None
|
||||
_stream_fields = None
|
||||
_stream_conversions = None
|
||||
|
@ -118,7 +123,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
keyfile=None, certfile=None):
|
||||
self.addr = addr
|
||||
self.timeout = timeout
|
||||
self.client = bsonrpc.BsonRpcClient(addr, timeout, user, password, keyfile=keyfile, certfile=certfile)
|
||||
self.client = bsonrpc.BsonRpcClient(
|
||||
addr, timeout, user, password, keyfile=keyfile, certfile=certfile)
|
||||
self.logger_object = vtdb_logger.get_logger()
|
||||
|
||||
def __str__(self):
|
||||
|
@ -204,13 +210,18 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
exec_method = None
|
||||
req = None
|
||||
if keyspace_ids is not None:
|
||||
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
|
||||
req = _create_req_with_keyspace_ids(
|
||||
sql, bind_variables, keyspace, tablet_type, keyspace_ids,
|
||||
not_in_transaction)
|
||||
exec_method = 'VTGate.ExecuteKeyspaceIds'
|
||||
elif keyranges is not None:
|
||||
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
|
||||
req = _create_req_with_keyranges(
|
||||
sql, bind_variables, keyspace, tablet_type, keyranges,
|
||||
not_in_transaction)
|
||||
exec_method = 'VTGate.ExecuteKeyRanges'
|
||||
else:
|
||||
raise dbexceptions.ProgrammingError('_execute called without specifying keyspace_ids or keyranges')
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'_execute called without specifying keyspace_ids or keyranges')
|
||||
|
||||
self._add_caller_id(req, effective_caller_id)
|
||||
self._add_session(req)
|
||||
|
@ -301,13 +312,13 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
raise
|
||||
return results, rowcount, lastrowid, fields
|
||||
|
||||
|
||||
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
|
||||
def _execute_batch(
|
||||
self, sql_list, bind_variables_list, keyspace_list, keyspace_ids_list,
|
||||
tablet_type, as_transaction, effective_caller_id=None):
|
||||
query_list = []
|
||||
for sql, bind_vars, keyspace, keyspace_ids in zip(sql_list, bind_variables_list, keyspace_list, keyspace_ids_list):
|
||||
for sql, bind_vars, keyspace, keyspace_ids in zip(
|
||||
sql_list, bind_variables_list, keyspace_list, keyspace_ids_list):
|
||||
sql, bind_vars = dbapi.prepare_query_bind_vars(sql, bind_vars)
|
||||
query = {}
|
||||
query['Sql'] = sql
|
||||
|
@ -329,7 +340,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
response = self.client.call('VTGate.ExecuteBatchKeyspaceIds', req)
|
||||
self._update_session(response)
|
||||
if 'Error' in response.reply and response.reply['Error']:
|
||||
raise gorpc.AppError(response.reply['Error'], 'VTGate.ExecuteBatchKeyspaceIds')
|
||||
raise gorpc.AppError(
|
||||
response.reply['Error'], 'VTGate.ExecuteBatchKeyspaceIds')
|
||||
for reply in response.reply['List']:
|
||||
fields = []
|
||||
conversions = []
|
||||
|
@ -365,13 +377,18 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
exec_method = None
|
||||
req = None
|
||||
if keyspace_ids is not None:
|
||||
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
|
||||
req = _create_req_with_keyspace_ids(
|
||||
sql, bind_variables, keyspace, tablet_type, keyspace_ids,
|
||||
not_in_transaction)
|
||||
exec_method = 'VTGate.StreamExecuteKeyspaceIds'
|
||||
elif keyranges is not None:
|
||||
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
|
||||
req = _create_req_with_keyranges(
|
||||
sql, bind_variables, keyspace, tablet_type, keyranges,
|
||||
not_in_transaction)
|
||||
exec_method = 'VTGate.StreamExecuteKeyRanges'
|
||||
else:
|
||||
raise dbexceptions.ProgrammingError('_stream_execute called without specifying keyspace_ids or keyranges')
|
||||
raise dbexceptions.ProgrammingError(
|
||||
'_stream_execute called without specifying keyspace_ids or keyranges')
|
||||
|
||||
self._add_caller_id(req, effective_caller_id)
|
||||
self._add_session(req)
|
||||
|
@ -387,7 +404,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
|
||||
for field in reply['Fields']:
|
||||
self._stream_fields.append((field['Name'], field['Type']))
|
||||
self._stream_conversions.append(field_types.conversions.get(field['Type']))
|
||||
self._stream_conversions.append(
|
||||
field_types.conversions.get(field['Type']))
|
||||
except gorpc.GoRpcError as e:
|
||||
self.logger_object.log_private_data(bind_variables)
|
||||
raise convert_exception(e, str(self), sql, keyspace_ids, keyranges,
|
||||
|
@ -410,7 +428,8 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
self._stream_result_index = None
|
||||
return None
|
||||
# A session message, if any comes separately with no rows
|
||||
if 'Session' in self._stream_result.reply and self._stream_result.reply['Session']:
|
||||
if ('Session' in self._stream_result.reply and
|
||||
self._stream_result.reply['Session']):
|
||||
self.session = self._stream_result.reply['Session']
|
||||
self._stream_result = None
|
||||
continue
|
||||
|
@ -420,11 +439,14 @@ class VTGateConnection(vtgate_client.VTGateClient):
|
|||
logging.exception('gorpc low-level error')
|
||||
raise
|
||||
|
||||
row = tuple(_make_row(self._stream_result.reply['Result']['Rows'][self._stream_result_index], self._stream_conversions))
|
||||
row = tuple(_make_row(
|
||||
self._stream_result.reply['Result']['Rows'][self._stream_result_index],
|
||||
self._stream_conversions))
|
||||
|
||||
# If we are reading the last row, set us up to read more data.
|
||||
self._stream_result_index += 1
|
||||
if self._stream_result_index == len(self._stream_result.reply['Result']['Rows']):
|
||||
if (self._stream_result_index ==
|
||||
len(self._stream_result.reply['Result']['Rows'])):
|
||||
self._stream_result = None
|
||||
self._stream_result_index = 0
|
||||
|
||||
|
@ -468,7 +490,7 @@ def get_params_for_vtgate_conn(vtgate_addrs, timeout, user=None, password=None):
|
|||
random.shuffle(vtgate_addrs)
|
||||
addrs = vtgate_addrs
|
||||
else:
|
||||
raise dbexceptions.Error("Wrong type for vtgate addrs %s" % vtgate_addrs)
|
||||
raise dbexceptions.Error('Wrong type for vtgate addrs %s' % vtgate_addrs)
|
||||
|
||||
for addr in addrs:
|
||||
vt_params = dict()
|
||||
|
@ -485,7 +507,9 @@ def connect(vtgate_addrs, timeout, user=None, password=None):
|
|||
user=user, password=password)
|
||||
|
||||
if not db_params_list:
|
||||
raise dbexceptions.OperationalError("empty db params list - no db instance available for vtgate_addrs %s" % vtgate_addrs)
|
||||
raise dbexceptions.OperationalError(
|
||||
'empty db params list - no db instance available for vtgate_addrs %s' %
|
||||
vtgate_addrs)
|
||||
|
||||
db_exception = None
|
||||
host_addr = None
|
||||
|
@ -501,6 +525,6 @@ def connect(vtgate_addrs, timeout, user=None, password=None):
|
|||
logging.warning('db connection failed: %s, %s', host_addr, e)
|
||||
|
||||
raise dbexceptions.OperationalError(
|
||||
'unable to create vt connection', host_addr, db_exception)
|
||||
'unable to create vt connection', host_addr, db_exception)
|
||||
|
||||
vtgate_client.register_conn_class('gorpc', VTGateConnection)
|
||||
|
|
|
@ -4,18 +4,17 @@
|
|||
|
||||
from itertools import izip
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
from net import bsonrpc
|
||||
from net import gorpc
|
||||
from vtdb import cursorv3
|
||||
from vtdb import dbexceptions
|
||||
from vtdb import field_types
|
||||
from vtdb import vtdb_logger
|
||||
from vtdb import cursorv3
|
||||
|
||||
|
||||
_errno_pattern = re.compile('\(errno (\d+)\)')
|
||||
_errno_pattern = re.compile(r'\(errno (\d+)\)')
|
||||
|
||||
|
||||
def log_exception(method):
|
||||
|
@ -26,8 +25,8 @@ def log_exception(method):
|
|||
method based on the exception raised.
|
||||
|
||||
Args:
|
||||
exc: exception raised by calling code
|
||||
args: additional args for the exception.
|
||||
method: Method that takes exc, *args, where exc is an exception raised
|
||||
by calling code, args are additional args for the exception.
|
||||
|
||||
Returns:
|
||||
Decorated method.
|
||||
|
@ -79,16 +78,16 @@ def convert_exception(exc, *args):
|
|||
def _create_req(sql, new_binds, tablet_type, not_in_transaction):
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
req = {
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'TabletType': tablet_type,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'TabletType': tablet_type,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
return req
|
||||
|
||||
|
||||
# This utilizes the V3 API of VTGate.
|
||||
class VTGateConnection(object):
|
||||
"""This utilizes the V3 API of VTGate."""
|
||||
session = None
|
||||
_stream_fields = None
|
||||
_stream_conversions = None
|
||||
|
@ -163,7 +162,8 @@ class VTGateConnection(object):
|
|||
if 'Session' in response.reply and response.reply['Session']:
|
||||
self.session = response.reply['Session']
|
||||
|
||||
def _execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
|
||||
def _execute(
|
||||
self, sql, bind_variables, tablet_type, not_in_transaction=False):
|
||||
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
|
||||
self._add_session(req)
|
||||
|
||||
|
@ -198,8 +198,8 @@ class VTGateConnection(object):
|
|||
raise
|
||||
return results, rowcount, lastrowid, fields
|
||||
|
||||
|
||||
def _execute_batch(self, sql_list, bind_variables_list, tablet_type, as_transaction):
|
||||
def _execute_batch(
|
||||
self, sql_list, bind_variables_list, tablet_type, as_transaction):
|
||||
query_list = []
|
||||
for sql, bind_vars in zip(sql_list, bind_variables_list):
|
||||
query = {}
|
||||
|
@ -247,7 +247,8 @@ class VTGateConnection(object):
|
|||
# we return the fields for the response, and the column conversions
|
||||
# the conversions will need to be passed back to _stream_next
|
||||
# (that way we avoid using a member variable here for such a corner case)
|
||||
def _stream_execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
|
||||
def _stream_execute(
|
||||
self, sql, bind_variables, tablet_type, not_in_transaction=False):
|
||||
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
|
||||
self._add_session(req)
|
||||
|
||||
|
@ -262,7 +263,8 @@ class VTGateConnection(object):
|
|||
|
||||
for field in reply['Fields']:
|
||||
self._stream_fields.append((field['Name'], field['Type']))
|
||||
self._stream_conversions.append(field_types.conversions.get(field['Type']))
|
||||
self._stream_conversions.append(
|
||||
field_types.conversions.get(field['Type']))
|
||||
except gorpc.GoRpcError as e:
|
||||
self.logger_object.log_private_data(bind_variables)
|
||||
raise convert_exception(e, str(self), sql)
|
||||
|
@ -284,7 +286,8 @@ class VTGateConnection(object):
|
|||
self._stream_result_index = None
|
||||
return None
|
||||
# A session message, if any comes separately with no rows
|
||||
if 'Session' in self._stream_result.reply and self._stream_result.reply['Session']:
|
||||
if ('Session' in self._stream_result.reply and
|
||||
self._stream_result.reply['Session']):
|
||||
self.session = self._stream_result.reply['Session']
|
||||
self._stream_result = None
|
||||
continue
|
||||
|
@ -298,11 +301,14 @@ class VTGateConnection(object):
|
|||
logging.exception('gorpc low-level error')
|
||||
raise
|
||||
|
||||
row = tuple(_make_row(self._stream_result.reply['Result']['Rows'][self._stream_result_index], self._stream_conversions))
|
||||
row = tuple(_make_row(
|
||||
self._stream_result.reply['Result']['Rows'][self._stream_result_index],
|
||||
self._stream_conversions))
|
||||
|
||||
# If we are reading the last row, set us up to read more data.
|
||||
self._stream_result_index += 1
|
||||
if self._stream_result_index == len(self._stream_result.reply['Result']['Rows']):
|
||||
if (self._stream_result_index ==
|
||||
len(self._stream_result.reply['Result']['Rows'])):
|
||||
self._stream_result = None
|
||||
self._stream_result_index = 0
|
||||
|
||||
|
|
|
@ -229,7 +229,7 @@ def _create_where_clause_for_str_keyspace(key_range, keyspace_col_name):
|
|||
i += 1
|
||||
bind_vars[bind_name] = kr_min
|
||||
if kr_max != keyrange_constants.MAX_KEY:
|
||||
if where_clause != '':
|
||||
if where_clause:
|
||||
where_clause += ' AND '
|
||||
bind_name = '%s%d' % (keyspace_col_name, i)
|
||||
where_clause += 'hex(%s) < ' % keyspace_col_name + '%(' + bind_name + ')s'
|
||||
|
@ -262,7 +262,7 @@ def _create_where_clause_for_int_keyspace(key_range, keyspace_col_name):
|
|||
i += 1
|
||||
bind_vars[bind_name] = kr_min
|
||||
if kr_max is not None:
|
||||
if where_clause != '':
|
||||
if where_clause:
|
||||
where_clause += ' AND '
|
||||
bind_name = '%s%d' % (keyspace_col_name, i)
|
||||
where_clause += '%s < ' % keyspace_col_name + '%(' + bind_name + ')s'
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
# Implement a sensible wrapper that treats python objects as dictionaries
|
||||
# with sensible restrictions on serialization.
|
||||
"""Implement a sensible wrapper that treats python objects as dictionaries.
|
||||
|
||||
Has sensible restrictions on serialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def _default(o):
|
||||
if hasattr(o, '_serializable_attributes'):
|
||||
return dict([(k, v)
|
||||
|
@ -13,13 +16,15 @@ def _default(o):
|
|||
_default_kargs = {'default': _default,
|
||||
'sort_keys': True,
|
||||
'indent': 2,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def dump(*pargs, **kargs):
|
||||
_kargs = _default_kargs.copy()
|
||||
_kargs.update(kargs)
|
||||
return json.dump(*pargs, **_kargs)
|
||||
|
||||
|
||||
def dumps(*pargs, **kargs):
|
||||
_kargs = _default_kargs.copy()
|
||||
_kargs.update(kargs)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
# zkns - naming service.
|
||||
#
|
||||
# This uses zookeeper to resolve a list of servers to answer a
|
||||
# particular query type.
|
||||
#
|
||||
# Additionally, config information can be embedded nearby for future updates
|
||||
# to the python client.
|
||||
"""zkns - naming service.
|
||||
|
||||
This uses zookeeper to resolve a list of servers to answer a
|
||||
particular query type.
|
||||
|
||||
Additionally, config information can be embedded nearby for future updates
|
||||
to the python client.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import json
|
||||
|
@ -16,6 +17,7 @@ from zk import zkjson
|
|||
SrvEntry = collections.namedtuple('SrvEntry',
|
||||
('host', 'port', 'priority', 'weight'))
|
||||
|
||||
|
||||
class ZknsError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -28,12 +30,13 @@ class ZknsAddr(zkjson.ZkJsonObject):
|
|||
class ZknsAddrs(zkjson.ZkJsonObject):
|
||||
# NOTE: attributes match Go implementation, hence capitalization
|
||||
_serializable_attributes = ('entries',)
|
||||
|
||||
def __init__(self):
|
||||
self.entries = []
|
||||
|
||||
|
||||
def _sorted_by_srv_priority(entries):
|
||||
# Priority is ascending, weight is descending.
|
||||
"""Priority is ascending, weight is descending."""
|
||||
entries.sort(key=lambda x: (x.priority, -x.weight))
|
||||
|
||||
priority_map = collections.defaultdict(list)
|
||||
|
@ -41,7 +44,7 @@ def _sorted_by_srv_priority(entries):
|
|||
priority_map[entry.priority].append(entry)
|
||||
|
||||
shuffled_entries = []
|
||||
for priority, priority_entries in sorted(priority_map.iteritems()):
|
||||
for unused_priority, priority_entries in sorted(priority_map.iteritems()):
|
||||
if len(priority_entries) <= 1:
|
||||
shuffled_entries.extend(priority_entries)
|
||||
continue
|
||||
|
@ -62,6 +65,7 @@ def _sorted_by_srv_priority(entries):
|
|||
|
||||
return shuffled_entries
|
||||
|
||||
|
||||
def _get_addrs(zconn, zk_path):
|
||||
data = zconn.get_data(zk_path)
|
||||
addrs = ZknsAddrs()
|
||||
|
@ -72,6 +76,7 @@ def _get_addrs(zconn, zk_path):
|
|||
addrs.entries.append(addr)
|
||||
return addrs
|
||||
|
||||
|
||||
# zkns_name: /zk/cell/vt/ns/path:_port - port is optional
|
||||
def lookup_name(zconn, zkns_name):
|
||||
if ':' in zkns_name:
|
||||
|
|
|
@ -7,6 +7,7 @@ import threading
|
|||
from net import bsonrpc
|
||||
from net import gorpc
|
||||
|
||||
|
||||
class ZkOccError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -35,9 +36,12 @@ class ZkOccError(Exception):
|
|||
# pzxid long
|
||||
#
|
||||
|
||||
# A simple, direct connection to a single zkocc server. Doesn't retry.
|
||||
# You probably want to use ZkOccConnection instead.
|
||||
|
||||
class SimpleZkOccConnection(object):
|
||||
"""A simple, direct connection to a single zkocc server.
|
||||
|
||||
Doesn't retry. You probably want to use ZkOccConnection instead.
|
||||
"""
|
||||
|
||||
def __init__(self, addr, timeout, user=None, password=None):
|
||||
self.client = bsonrpc.BsonRpcClient(addr, timeout, user, password)
|
||||
|
@ -75,13 +79,18 @@ class SimpleZkOccConnection(object):
|
|||
return self._call('TopoReader.GetSrvKeyspace', cell=cell, keyspace=keyspace)
|
||||
|
||||
def get_end_points(self, cell, keyspace, shard, tablet_type):
|
||||
return self._call('TopoReader.GetEndPoints', cell=cell, keyspace=keyspace, shard=shard, tablet_type=tablet_type)
|
||||
return self._call(
|
||||
'TopoReader.GetEndPoints', cell=cell, keyspace=keyspace, shard=shard,
|
||||
tablet_type=tablet_type)
|
||||
|
||||
|
||||
# A meta-connection that can connect to multiple alternate servers, and will
|
||||
# retry a couple times. Calling dial before get/getv/children is optional,
|
||||
# and will only do anything at all if authentication is enabled.
|
||||
class ZkOccConnection(object):
|
||||
"""A meta-connection that can connect to multiple alternate servers.
|
||||
|
||||
This will retry a couple times. Calling dial before
|
||||
get/getv/children is optional, and will only do anything at all if
|
||||
authentication is enabled.
|
||||
"""
|
||||
max_attempts = 2
|
||||
max_dial_attempts = 10
|
||||
|
||||
|
@ -92,7 +101,8 @@ class ZkOccConnection(object):
|
|||
self.local_cell = local_cell
|
||||
|
||||
if bool(user) != bool(password):
|
||||
raise ValueError("You must provide either both or none of user and password.")
|
||||
raise ValueError(
|
||||
'You must provide either both or none of user and password.')
|
||||
self.user = user
|
||||
self.password = password
|
||||
|
||||
|
@ -100,7 +110,7 @@ class ZkOccConnection(object):
|
|||
self.lock = threading.Lock()
|
||||
|
||||
def _resolve_path(self, zk_path):
|
||||
# Maps a 'meta-path' to a cell specific path.
|
||||
"""Maps a 'meta-path' to a cell specific path."""
|
||||
# '/zk/local/blah' -> '/zk/vb/blah'
|
||||
parts = zk_path.split('/')
|
||||
|
||||
|
@ -123,9 +133,11 @@ class ZkOccConnection(object):
|
|||
if self.simple_conn:
|
||||
self.simple_conn.close()
|
||||
|
||||
addrs = random.sample(self.addrs, min(self.max_dial_attempts, len(self.addrs)))
|
||||
addrs = random.sample(
|
||||
self.addrs, min(self.max_dial_attempts, len(self.addrs)))
|
||||
for a in addrs:
|
||||
self.simple_conn = SimpleZkOccConnection(a, self.timeout, self.user, self.password)
|
||||
self.simple_conn = SimpleZkOccConnection(
|
||||
a, self.timeout, self.user, self.password)
|
||||
try:
|
||||
self.simple_conn.dial()
|
||||
return
|
||||
|
@ -133,7 +145,7 @@ class ZkOccConnection(object):
|
|||
pass
|
||||
|
||||
self.simple_conn = None
|
||||
raise ZkOccError("Cannot dial to any server, tried: %s" % addrs)
|
||||
raise ZkOccError('Cannot dial to any server, tried: %s' % addrs)
|
||||
|
||||
def close(self):
|
||||
if self.simple_conn:
|
||||
|
@ -151,9 +163,13 @@ class ZkOccConnection(object):
|
|||
return getattr(self.simple_conn, client_method)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
logging.warning('zkocc: %s command failed %d times: %s', client_method, attempt, e)
|
||||
logging.warning(
|
||||
'zkocc: %s command failed %d times: %s',
|
||||
client_method, attempt, e)
|
||||
if attempt >= self.max_attempts:
|
||||
raise ZkOccError('zkocc %s command failed %d times: %s' % (client_method, attempt, e))
|
||||
raise ZkOccError(
|
||||
'zkocc %s command failed %d times: %s' %
|
||||
(client_method, attempt, e))
|
||||
|
||||
# try the next server if there is one, or retry our only server
|
||||
self.dial()
|
||||
|
@ -190,10 +206,15 @@ class ZkOccConnection(object):
|
|||
def children(self, path):
|
||||
return self._call('children', self._resolve_path(path))
|
||||
|
||||
# use this class for faking out a zkocc client. The startup config values
|
||||
# can be loaded from a json file. After that, they can be mass-altered
|
||||
# to replace default values with test-specific values, for instance.
|
||||
|
||||
class FakeZkOccConnection(object):
|
||||
"""Use this class for faking out a zkocc client.
|
||||
|
||||
The startup config values can be loaded from a json file. After
|
||||
that, they can be mass-altered to replace default values with
|
||||
test-specific values, for instance.
|
||||
"""
|
||||
|
||||
def __init__(self, local_cell):
|
||||
self.data = {}
|
||||
self.local_cell = local_cell
|
||||
|
@ -215,7 +236,7 @@ class FakeZkOccConnection(object):
|
|||
self.data[key] = data.replace(before, after)
|
||||
|
||||
def _resolve_path(self, zk_path):
|
||||
# Maps a 'meta-path' to a cell specific path.
|
||||
"""Maps a 'meta-path' to a cell specific path."""
|
||||
# '/zk/local/blah' -> '/zk/vb/blah'
|
||||
parts = zk_path.split('/')
|
||||
|
||||
|
@ -238,25 +259,25 @@ class FakeZkOccConnection(object):
|
|||
|
||||
def get(self, path):
|
||||
path = self._resolve_path(path)
|
||||
if not path in self.data:
|
||||
raise ZkOccError("FakeZkOccConnection: not found: " + path)
|
||||
if path not in self.data:
|
||||
raise ZkOccError('FakeZkOccConnection: not found: ' + path)
|
||||
return {
|
||||
'Data':self.data[path],
|
||||
'Children':[]
|
||||
'Data': self.data[path],
|
||||
'Children': []
|
||||
}
|
||||
|
||||
def getv(self, paths):
|
||||
raise ZkOccError("FakeZkOccConnection: not found: " + " ".join(paths))
|
||||
raise ZkOccError('FakeZkOccConnection: not found: ' + ' '.join(paths))
|
||||
|
||||
def children(self, path):
|
||||
path = self._resolve_path(path)
|
||||
children = [os.path.basename(node) for node in self.data
|
||||
if os.path.dirname(node) == path]
|
||||
if len(children) == 0:
|
||||
raise ZkOccError("FakeZkOccConnection: not found: " + path)
|
||||
if not children:
|
||||
raise ZkOccError('FakeZkOccConnection: not found: ' + path)
|
||||
return {
|
||||
'Data':'',
|
||||
'Children':children
|
||||
'Data': '',
|
||||
'Children': children
|
||||
}
|
||||
|
||||
# New API. For this fake object, it is based on the old API.
|
||||
|
@ -271,7 +292,7 @@ class FakeZkOccConnection(object):
|
|||
try:
|
||||
data = self.get(keyspace_path)['Data']
|
||||
if not data:
|
||||
raise ZkOccError("FakeZkOccConnection: empty keyspace: " + keyspace)
|
||||
raise ZkOccError('FakeZkOccConnection: empty keyspace: ' + keyspace)
|
||||
result = json.loads(data)
|
||||
# for convenience, we store the KeyRange as hex, but we need to
|
||||
# decode it here, as BSON RPC sends it as binary.
|
||||
|
@ -289,7 +310,7 @@ class FakeZkOccConnection(object):
|
|||
try:
|
||||
data = self.get(zk_path)['Data']
|
||||
if not data:
|
||||
raise ZkOccError("FakeZkOccConnection: empty end point: " + zk_path)
|
||||
raise ZkOccError('FakeZkOccConnection: empty end point: ' + zk_path)
|
||||
return json.loads(data)
|
||||
except Exception as e:
|
||||
raise ZkOccError('FakeZkOccConnection: invalid end point', zk_path, e)
|
||||
|
|
Загрузка…
Ссылка в новой задаче