зеркало из https://github.com/github/vitess-gh.git
478 строки
16 KiB
Python
478 строки
16 KiB
Python
"""Module containing the base class for database classes and decorator for db method.
|
|
|
|
The base class DBObjectBase is the base class for all other database base classes.
|
|
It has methods for common database operations like select, insert, update and delete.
|
|
This module also contains the definition for ShardRouting which is used for determining
|
|
the routing of a query during cursor creation.
|
|
The module also has the db_class_method decorator and db_wrapper method which are
|
|
used for cursor creation and calling the database method.
|
|
"""
|
|
import functools
|
|
import struct
|
|
|
|
from vtdb import database_context
|
|
from vtdb import dbexceptions
|
|
from vtdb import db_validator
|
|
from vtdb import keyrange
|
|
from vtdb import keyrange_constants
|
|
from vtdb import shard_constants
|
|
from vtdb import sql_builder
|
|
from vtdb import vtgate_cursor
|
|
|
|
class __EmptyBindVariables(frozenset):
    """Dedicated frozenset subclass whose single instance is always empty."""


# Shared empty instance, handed to cursor.execute() by queries that take no
# bind variables.
EmptyBindVariables = __EmptyBindVariables()
|
|
|
|
|
|
class ShardRouting(object):
    """Holder for the routing attributes of a VTGate query.

    Attributes:
      keyspace: keyspace where the table resides.
      sharding_key: sharding key of the table.
      keyrange: keyrange for the query.
      entity_column_name: the name of the lookup based entity used for routing.
      entity_id_sharding_key_map: this map is used for in clause queries.
      shard_name: this is used to route queries for custom sharded keyspaces.
    """

    def __init__(self, keyspace):
        """Create a routing object for `keyspace` with all routing fields unset.

        Args:
          keyspace: keyspace of the table.
        """
        self.keyspace = keyspace

        # The next three fields mostly serve routing of queries against
        # range-sharded keyspaces.
        self.sharding_key = None
        self.entity_column_name = None
        self.entity_id_sharding_key_map = None

        self.keyrange = None
        self.shard_name = None
|
|
|
|
|
|
def _is_iterable_container(x):
|
|
return hasattr(x, '__iter__')
|
|
|
|
|
|
# Leading keywords that mark a statement as DML (data-modifying).
INSERT_KW = "insert"
UPDATE_KW = "update"
DELETE_KW = "delete"


def is_dml(sql):
    """Report whether a SQL statement is an insert, update or delete.

    Only the first whitespace-delimited token is inspected, and the
    comparison is case-insensitive. Leading whitespace is ignored and any
    whitespace character (space, tab, newline) terminates the keyword, so
    statements such as "  INSERT ..." or "update\\n  t ..." are recognized.

    Args:
      sql: SQL statement text.

    Returns:
      True if the first keyword is insert, update or delete; False
      otherwise, including for an empty or all-whitespace statement.
    """
    # split(None, 1) skips leading whitespace and splits on any whitespace
    # run; only the first token is needed.
    tokens = sql.split(None, 1)
    if not tokens:
        return False
    return tokens[0].lower() in (INSERT_KW, UPDATE_KW, DELETE_KW)
|
|
|
|
|
|
def create_cursor_from_params(vtgate_conn, tablet_type, is_dml, table_class):
    """Build a cursor for `table_class` from explicitly supplied parameters.

    Mostly used to obtain a lookup cursor during create_shard_routing, when
    no real cursor exists yet.

    Args:
      vtgate_conn: connection to vtgate server.
      tablet_type: tablet type for the cursor.
      is_dml: indicates writable cursor or not.
      table_class: table for which the cursor is being created.

    Returns:
      Whatever table_class.create_vtgate_cursor returns for these arguments.
    """
    return table_class.create_vtgate_cursor(vtgate_conn, tablet_type, is_dml)
|
|
|
|
|
|
def create_cursor_from_old_cursor(old_cursor, table_class):
    """Build a cursor for `table_class` reusing settings of an existing cursor.

    Mostly used to obtain a lookup cursor while running db operations on
    other database tables.

    Args:
      old_cursor: existing cursor; its connection (`_conn`), `tablet_type`
        and writability (`is_writable()`) are carried over.
      table_class: table for which the cursor is being created.

    Returns:
      Whatever table_class.create_vtgate_cursor returns for those settings.
    """
    conn = old_cursor._conn
    tablet_type = old_cursor.tablet_type
    writable = old_cursor.is_writable()
    return table_class.create_vtgate_cursor(conn, tablet_type, writable)
|
|
|
|
|
|
def create_stream_cursor_from_cursor(original_cursor):
    """Derive a streaming cursor from a regular cursor.

    The connection, keyspace, tablet type, keyspace ids and keyranges are
    copied from `original_cursor`.

    Args:
      original_cursor: Cursor of VTGateCursor type.

    Returns:
      A StreamVTGateCursor that is not writable.

    Raises:
      dbexceptions.ProgrammingError: original_cursor is not a VTGateCursor.
    """
    if isinstance(original_cursor, vtgate_cursor.VTGateCursor):
        return vtgate_cursor.StreamVTGateCursor(
            original_cursor._conn,
            original_cursor.keyspace,
            original_cursor.tablet_type,
            keyspace_ids=original_cursor.keyspace_ids,
            keyranges=original_cursor.keyranges,
            writable=False)

    raise dbexceptions.ProgrammingError(
        "Original cursor should be of VTGateCursor type.")
|
|
|
|
|
|
def create_batch_cursor_from_cursor(original_cursor, writable=False):
    """Derive a batch cursor from a regular cursor.

    The connection, keyspace, tablet type and keyspace ids are copied from
    `original_cursor`.

    Args:
      original_cursor: Cursor of VTGateCursor type.
      writable: whether the batch cursor is created writable.

    Returns:
      A BatchVTGateCursor sharing the attributes listed above.

    Raises:
      dbexceptions.ProgrammingError: original_cursor is not a VTGateCursor.
    """
    if isinstance(original_cursor, vtgate_cursor.VTGateCursor):
        return vtgate_cursor.BatchVTGateCursor(
            original_cursor._conn,
            original_cursor.keyspace,
            original_cursor.tablet_type,
            keyspace_ids=original_cursor.keyspace_ids,
            writable=writable)

    raise dbexceptions.ProgrammingError(
        "Original cursor should be of VTGateCursor type.")
|
|
|
|
|
|
def db_wrapper(method, **decorator_kwargs):
    """Decorator that creates the table's cursor and calls the db method.

    The returned wrapper is called as
    wrapper(table_class, cursor_method, *args, **kwargs). It builds the
    cursor with cursor_method(table_class) and then calls
    method(table_class, cursor, *args, **kwargs).

    Args:
      method: Method to decorate.
      decorator_kwargs: Keyword args for db_wrapper. Only 'write_method' is
        read; write_db_class_method sets it to True to request DML
        verification.

    Returns:
      Decorated method.

    Raises (from the returned wrapper):
      dbexceptions.ProgrammingError: table_class does not derive from
        DBObjectBase, or - for write methods - the cursor is not writable
        or the table class is a mysql view.
    """
    @functools.wraps(method)
    def _db_wrapper(*pargs, **kwargs):
        table_class = pargs[0]
        # The write flag is given at decoration time, so it lives in
        # decorator_kwargs. The call-time lookup is kept only as a fallback
        # for callers that relied on the previous behavior.
        write_method = decorator_kwargs.get(
            "write_method", kwargs.get("write_method", False))
        if not issubclass(table_class, DBObjectBase):
            raise dbexceptions.ProgrammingError(
                "table class '%s' is not inherited from DBObjectBase"
                % table_class)

        # Create the cursor using cursor_method.
        cursor_method = pargs[1]
        cursor = cursor_method(table_class)

        # DML verification.
        if write_method:
            if not cursor.is_writable():
                raise dbexceptions.ProgrammingError(
                    "Executing dmls on a non-writable cursor is not allowed.")
            if table_class.is_mysql_view:
                raise dbexceptions.ProgrammingError(
                    "writes disabled on view", table_class)

        # Unpacking an empty tail is a no-op, so one call covers both the
        # "extra positional args" and "no extra args" cases.
        return method(table_class, cursor, *pargs[2:], **kwargs)
    return _db_wrapper
|
|
|
|
|
|
def write_db_class_method(*pargs, **kwargs):
    """Variant of db_class_method for DML methods.

    Forwards everything to db_class_method with write_method forced to True.
    """
    forwarded = dict(kwargs, write_method=True)
    return db_class_method(*pargs, **forwarded)
|
|
|
|
|
|
def db_class_method(*pargs, **kwargs):
    """Wrap the arguments with db_wrapper and expose the result as a classmethod."""
    wrapped = db_wrapper(*pargs, **kwargs)
    return classmethod(wrapped)
|
|
|
|
|
|
def execute_batch_read(cursor, query_list, bind_vars_list):
    """Run a list of select queries as one batch.

    Queries and bind variables are paired positionally.

    Args:
      cursor: original cursor - a read-only BatchVTGateCursor is derived
        from it.
      query_list: list of queries.
      bind_vars_list: list of bind variables, one entry per query.

    Returns:
      Result of the form [[q1row1, q1row2,...], [q2row1, ...],..], where
      every row is a sql_builder.DBRow.

    Raises:
      dbexceptions.ProgrammingError: cursor is not a VTGateCursor, or a dml
        is issued to the read batch cursor.
    """
    if not isinstance(cursor, vtgate_cursor.VTGateCursor):
        raise dbexceptions.ProgrammingError(
            "cursor is not of the type VTGateCursor.")

    batch_cursor = create_batch_cursor_from_cursor(cursor)
    for query, bind_vars in zip(query_list, bind_vars_list):
        if is_dml(query):
            raise dbexceptions.ProgrammingError(
                "Dml %s for read batch cursor." % query)
        batch_cursor.execute(query, bind_vars)
    batch_cursor.flush()

    # Every rowset is a tuple (results, rowcount, lastrowid, fields).
    batched_rows = []
    for rowset in batch_cursor.rowsets:
        raw_rows = rowset[0]
        field_names = [field[0] for field in rowset[3]]
        batched_rows.append(
            [sql_builder.DBRow(field_names, raw_row) for raw_row in raw_rows])
    return batched_rows
|
|
|
|
|
|
def execute_batch_write(cursor, query_list, bind_vars_list):
    """Method for executing dml queries in batch.

    Queries and bind variables are paired positionally.

    Args:
      cursor: original cursor - a writable BatchVTGateCursor is derived
        from it.
      query_list: list of dml queries.
      bind_vars_list: list of bind variables, one entry per query.

    Returns:
      Result of the form [{'rowcount':rowcount, 'lastrowid':lastrowid}, ...]
      since for dmls those two values are valuable.

    Raises:
      dbexceptions.ProgrammingError: cursor is not a VTGateCursor, the batch
        cursor does not target exactly one keyspace_id, or a non-dml is
        issued to the writable batch cursor.
    """
    if not isinstance(cursor, vtgate_cursor.VTGateCursor):
        raise dbexceptions.ProgrammingError(
            "cursor is not of the type VTGateCursor.")

    batch_cursor = create_batch_cursor_from_cursor(cursor, writable=True)
    # A writable batch is restricted to a single keyspace_id.
    if batch_cursor.is_writable() and len(batch_cursor.keyspace_ids) != 1:
        raise dbexceptions.ProgrammingError(
            "writable batch execute can only execute on one keyspace_id.")

    for q, bv in zip(query_list, bind_vars_list):
        if not is_dml(q):
            raise dbexceptions.ProgrammingError("query %s is not a dml" % q)
        batch_cursor.execute(q, bv)

    batch_cursor.flush()

    # Every rowset is a tuple (results, rowcount, lastrowid, fields); only
    # rowcount and lastrowid are reported for dmls.
    return [{'rowcount': rowset[1], 'lastrowid': rowset[2]}
            for rowset in batch_cursor.rowsets]
|
|
|
|
|
|
class InvalidUtf8DbWrite(dbexceptions.Error):
    """Error signalling an attempted write of invalid utf-8 to the DB.

    Attributes:
      table_name: name of the table the write targeted.
      columns: the offending columns.
      template_args: (table_name, columns), the values substituted into
        `template` to build the message.
    """
    template = ("Attempt to write invalid utf-8 strings to the DB table '%s' in "
                "columns %s")

    def __init__(self, table_name, columns):
        """Record the table and columns and build the error message."""
        self.table_name = table_name
        self.columns = columns
        self.template_args = (table_name, columns)
        message = self.template % self.template_args
        super(InvalidUtf8DbWrite, self).__init__(message)
|
|
|
|
|
|
class DBObjectBase(object):
    """Base class for db classes.

    This abstracts sharding information and provides helper methods
    for common database access operations.

    Subclasses configure themselves through the class attributes below and
    must override create_shard_routing and create_vtgate_cursor, which raise
    NotImplementedError here.

    Methods decorated with db_class_method / write_db_class_method are
    classmethods. Callers pass, in place of `cursor`, a callable that takes
    the table class and returns a cursor; the method body then receives the
    cursor that callable produced.
    """
    # Not read inside this class; presumably consumed by subclasses when
    # they build routing information - confirm against the subclasses.
    keyspace = None
    sharding = None
    # Table name used when building every query.
    table_name = None
    # Column aggregated by get_min and get_max.
    id_column_name = None
    # Checked by db_wrapper's DML verification to reject writes on views.
    is_mysql_view = False
    # Columns validated as utf-8 before insert and update queries are built.
    utf8_columns = None
    # Default column list for selects and inserts. Declared here so the
    # "is None" guards below report a ProgrammingError rather than failing
    # with AttributeError when a subclass forgets to define it.
    columns_list = None

    @classmethod
    def create_shard_routing(class_, *pargs, **kwargs):
        """This method is used to create ShardRouting object which is
        used for determining routing attributes for the vtgate cursor.

        Must be overridden by subclasses.

        Returns:
          ShardRouting object.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml,
                             **cursor_kwargs):
        """This creates the VTGateCursor object which is used to make
        all the rpc calls to VTGate.

        Must be overridden by subclasses.

        Args:
          vtgate_conn: connection to vtgate.
          tablet_type: tablet type to connect to.
          is_dml: Makes the cursor writable, enforces appropriate constraints.
          cursor_kwargs: additional keyword arguments; their meaning is
            defined by the overriding subclass.

        Returns:
          VTGateCursor for the query.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def _validate_column_value_pairs_for_write(class_, **column_values):
        """Reject writes that put invalid utf-8 into the class's utf8_columns.

        Args:
          column_values: column name -> value about to be written.

        Raises:
          InvalidUtf8DbWrite: db_validator reported invalid columns.
        """
        invalid_columns = db_validator.invalid_utf8_columns(
            class_.utf8_columns or [], column_values)
        if invalid_columns:
            raise InvalidUtf8DbWrite(class_.table_name, invalid_columns)

    @db_class_method
    def select_by_columns(class_, cursor, where_column_value_pairs,
                          columns_list=None, order_by=None, group_by=None,
                          limit=None):
        """Run a select and return all matching rows.

        Args:
          cursor: cursor to execute on.
          where_column_value_pairs: where-clause column/value pairs.
          columns_list: columns to select; defaults to class_.columns_list.
          order_by: passed through to the query builder.
          group_by: passed through to the query builder.
          limit: passed through to the query builder.

        Returns:
          List of sql_builder.DBRow, one per fetched row.
        """
        if columns_list is None:
            columns_list = class_.columns_list
        query, bind_vars = class_.create_select_query(
            where_column_value_pairs,
            columns_list=columns_list,
            order_by=order_by,
            group_by=group_by,
            limit=limit)

        cursor.execute(query, bind_vars)
        rows = cursor.fetchall()
        return [sql_builder.DBRow(columns_list, row) for row in rows]

    @classmethod
    def create_insert_query(class_, **bind_vars):
        """Validate the values and build the insert query for this table.

        Args:
          bind_vars: column name -> value to insert.

        Returns:
          Whatever sql_builder.insert_query returns; callers in this class
          unpack it as (query, bind_vars).

        Raises:
          InvalidUtf8DbWrite: a utf8 column holds invalid utf-8.
        """
        class_._validate_column_value_pairs_for_write(**bind_vars)
        return sql_builder.insert_query(class_.table_name,
                                        class_.columns_list,
                                        **bind_vars)

    @classmethod
    def create_update_query(class_, where_column_value_pairs,
                            update_column_value_pairs):
        """Validate the new values and build the update query for this table.

        Args:
          where_column_value_pairs: where-clause column/value pairs.
          update_column_value_pairs: column/value pairs to set; must be
            acceptable to dict().

        Returns:
          Whatever sql_builder.update_columns_query returns; callers in this
          class unpack it as (query, bind_vars).

        Raises:
          InvalidUtf8DbWrite: a utf8 column would receive invalid utf-8.
        """
        class_._validate_column_value_pairs_for_write(
            **dict(update_column_value_pairs))
        return sql_builder.update_columns_query(
            class_.table_name, where_column_value_pairs,
            update_column_value_pairs=update_column_value_pairs)

    @classmethod
    def create_delete_query(class_, where_column_value_pairs, limit=None):
        """Build the delete query for this table.

        Args:
          where_column_value_pairs: where-clause column/value pairs.
          limit: passed through to the query builder.

        Returns:
          Whatever sql_builder.delete_by_columns_query returns; callers in
          this class unpack it as (query, bind_vars).
        """
        return sql_builder.delete_by_columns_query(class_.table_name,
                                                   where_column_value_pairs,
                                                   limit=limit)

    @classmethod
    def create_select_query(class_, where_column_value_pairs,
                            columns_list=None, order_by=None, group_by=None,
                            limit=None):
        """Build the select query for this table.

        Args:
          where_column_value_pairs: where-clause column/value pairs.
          columns_list: columns to select; defaults to class_.columns_list.
          order_by: passed through to the query builder.
          group_by: passed through to the query builder.
          limit: passed through to the query builder.

        Returns:
          (query, bind_vars) as produced by
          sql_builder.select_by_columns_query.

        Raises:
          dbexceptions.ProgrammingError: the class defines no columns_list.
        """
        if class_.columns_list is None:
            raise dbexceptions.ProgrammingError(
                "DB class should define columns_list")

        if columns_list is None:
            columns_list = class_.columns_list

        query, bind_vars = sql_builder.select_by_columns_query(
            columns_list,
            class_.table_name,
            where_column_value_pairs,
            order_by=order_by,
            group_by=group_by,
            limit=limit)
        return query, bind_vars

    @write_db_class_method
    def insert(class_, cursor, **bind_vars):
        """Insert one row.

        Args:
          cursor: cursor to execute on.
          bind_vars: column name -> value to insert.

        Returns:
          cursor.lastrowid after the insert.

        Raises:
          dbexceptions.ProgrammingError: the class defines no columns_list.
          InvalidUtf8DbWrite: a utf8 column holds invalid utf-8.
        """
        if class_.columns_list is None:
            raise dbexceptions.ProgrammingError(
                "DB class should define columns_list")

        query, bind_vars = class_.create_insert_query(**bind_vars)
        cursor.execute(query, bind_vars)
        return cursor.lastrowid

    @write_db_class_method
    def update_columns(class_, cursor, where_column_value_pairs,
                       update_column_value_pairs):
        """Update the rows matching the where clause.

        Args:
          cursor: cursor to execute on.
          where_column_value_pairs: where-clause column/value pairs.
          update_column_value_pairs: column/value pairs to set.

        Returns:
          The return value of cursor.execute.

        Raises:
          InvalidUtf8DbWrite: a utf8 column would receive invalid utf-8.
        """
        query, bind_vars = class_.create_update_query(
            where_column_value_pairs,
            update_column_value_pairs=update_column_value_pairs)

        return cursor.execute(query, bind_vars)

    @write_db_class_method
    def delete_by_columns(class_, cursor, where_column_value_pairs,
                          limit=None):
        """Delete the rows matching the where clause.

        Args:
          cursor: cursor to execute on.
          where_column_value_pairs: where-clause column/value pairs; must
            not be empty.
          limit: passed through to the query builder.

        Returns:
          cursor.rowcount after the delete (never 0).

        Raises:
          dbexceptions.ProgrammingError: no where clause was given.
          dbexceptions.DatabaseError: no row was deleted.
        """
        if not where_column_value_pairs:
            raise dbexceptions.ProgrammingError(
                "deleting the whole table is not allowed")

        query, bind_vars = class_.create_delete_query(
            where_column_value_pairs, limit=limit)
        cursor.execute(query, bind_vars)
        if cursor.rowcount == 0:
            raise dbexceptions.DatabaseError("DB Row not found")
        return cursor.rowcount

    @db_class_method
    def select_by_columns_streaming(class_, cursor, where_column_value_pairs,
                                    columns_list=None, order_by=None,
                                    group_by=None, limit=None,
                                    fetch_size=100):
        """Run a select and stream the matching rows.

        Args:
          cursor: regular cursor; a streaming cursor is derived from it.
          where_column_value_pairs: where-clause column/value pairs.
          columns_list: columns to select; defaults to class_.columns_list.
          order_by: passed through to the query builder.
          group_by: passed through to the query builder.
          limit: passed through to the query builder.
          fetch_size: number of rows requested per fetchmany call.

        Returns:
          Generator yielding sql_builder.DBRow objects.
        """
        query, bind_vars = class_.create_select_query(
            where_column_value_pairs,
            columns_list=columns_list,
            order_by=order_by,
            group_by=group_by,
            limit=limit)

        # Forward columns_list so rows are labelled with the columns that
        # were actually selected.
        return class_._stream_fetch(cursor, query, bind_vars, fetch_size,
                                    columns_list=columns_list)

    @classmethod
    def _stream_fetch(class_, cursor, query, bind_vars, fetch_size=100,
                      columns_list=None):
        """Execute `query` on a streaming cursor and yield its rows.

        Being a generator, nothing runs until the first row is requested.

        Args:
          cursor: regular cursor; a streaming cursor is derived from it.
          query: query to execute.
          bind_vars: bind variables for the query.
          fetch_size: number of rows requested per fetchmany call.
          columns_list: column names used to build each DBRow; defaults to
            class_.columns_list.

        Yields:
          sql_builder.DBRow for every fetched row.
        """
        if columns_list is None:
            columns_list = class_.columns_list
        stream_cursor = create_stream_cursor_from_cursor(cursor)
        # try/finally closes the stream cursor even when the consumer stops
        # iterating early or an error interrupts the stream.
        try:
            stream_cursor.execute(query, bind_vars)
            while True:
                rows = stream_cursor.fetchmany(size=fetch_size)

                # NOTE: fetchmany returns an empty list when there are no
                # more items. But an empty generator is still "true", so we
                # have to count if we actually returned anything.
                yielded = 0
                for row in rows:
                    yielded += 1
                    yield sql_builder.DBRow(columns_list, row)
                if yielded == 0:
                    break
        finally:
            stream_cursor.close()

    @db_class_method
    def get_count(class_, cursor, column_value_pairs=None, **columns):
        """Count the rows matching the given column values.

        Args:
          cursor: cursor to execute on.
          column_value_pairs: column/value pairs to filter on; when empty or
            None the keyword arguments are used instead.
          columns: column name -> value filter, used only when
            column_value_pairs is not given.

        Returns:
          The result of cursor.fetch_aggregate_function(sum).
        """
        if not column_value_pairs:
            # sorted() gives a deterministic order and, unlike
            # items().sort(), also works where items() is not a list.
            column_value_pairs = sorted(columns.items())
        query, bind_vars = sql_builder.build_count_query(class_.table_name,
                                                         column_value_pairs)
        cursor.execute(query, bind_vars)
        return cursor.fetch_aggregate_function(sum)

    @db_class_method
    def get_min(class_, cursor):
        """Return the minimum of the class's id column.

        Args:
          cursor: cursor to execute on.

        Returns:
          The result of cursor.fetch_aggregate_function(min).

        Raises:
          dbexceptions.ProgrammingError: id_column_name is not set.
        """
        if class_.id_column_name is None:
            raise dbexceptions.ProgrammingError("id_column_name not set.")

        query = sql_builder.build_aggregate_query(class_.table_name,
                                                  class_.id_column_name)

        cursor.execute(query, EmptyBindVariables)
        return cursor.fetch_aggregate_function(min)

    @db_class_method
    def get_max(class_, cursor):
        """Return the maximum of the class's id column.

        Args:
          cursor: cursor to execute on.

        Returns:
          The result of cursor.fetch_aggregate_function(max).

        Raises:
          dbexceptions.ProgrammingError: id_column_name is not set.
        """
        if class_.id_column_name is None:
            raise dbexceptions.ProgrammingError("id_column_name not set.")

        query = sql_builder.build_aggregate_query(class_.table_name,
                                                  class_.id_column_name,
                                                  sort_func='max')
        cursor.execute(query, EmptyBindVariables)
        return cursor.fetch_aggregate_function(max)
|