diff --git a/py/vtdb/vtgate_cursor.py b/py/vtdb/vtgate_cursor.py index 843482d558..045eb61fcf 100644 --- a/py/vtdb/vtgate_cursor.py +++ b/py/vtdb/vtgate_cursor.py @@ -31,6 +31,11 @@ class VTGateCursor(object): connection = None description = None index = None + keyspace = None + tablet_type = None + keyspace_ids = None + keyranges = None + writeable = None def __init__(self, connection, keyspace, tablet_type, keyspace_ids=None, keyranges=None, writeable=False): self.connection = connection @@ -75,7 +80,7 @@ class VTGateCursor(object): return write_query = bool(write_sql_pattern.match(sql)) - # This check may also be done at high-layers but adding it here for completion. + # NOTE: This check may also be done at high-layers but adding it here for completion. if write_query: if not self.is_writable(): raise dbexceptions.DatabaseError('DML on a non-writable cursor', sql) @@ -88,7 +93,7 @@ class VTGateCursor(object): if keyrange is None or keyrange != keyrange_constants.NON_PARTIAL_KEYRANGE: raise dbexceptions.ProgrammingError('Keyrange not correct for non-sharded keyspace') - # FIXME(shrutip): this could potentially be done on vtgate server. + # FIXME(shrutip): migrate this to vtgate server. It is better done there. sql += self._binlog_hint(keyspace_ids[0]) self.results, self.rowcount, self.lastrowid, self.description = self.connection._execute(sql, @@ -96,8 +101,7 @@ class VTGateCursor(object): self.keyspace, self.tablet_type, keyspace_ids=self.keyspace_ids, - keyranges=self.keyranges, - **kargs) + keyranges=self.keyranges) self.index = 0 return self.rowcount @@ -107,6 +111,11 @@ class VTGateCursor(object): self.description = None self.lastrowid = None + # This is by definition a scatter query, so raise exception. + write_query = bool(write_sql_pattern.match(sql)) + if write_query: + raise dbexceptions.DatabaseError('execute_entity_ids is not allowed for write queries') + self.results, self.rowcount, self.lastrowid, self.description = self.connection._execute_entity_ids(sql, bind_variables, self.keyspace, @@ -230,8 +239,7 @@ class StreamVTGateCursor(VTGateCursor): self.keyspace, self.tablet_type, keyspace_ids=self.keyspace_ids, - keyranges=self.keyranges, - **kargs) + keyranges=self.keyranges) self.index = 0 return 0 diff --git a/py/vtdb/vtgatev2.py b/py/vtdb/vtgatev2.py index d276243875..aea1093245 100644 --- a/py/vtdb/vtgatev2.py +++ b/py/vtdb/vtgatev2.py @@ -8,7 +8,6 @@ import re from net import bsonrpc from net import gorpc -from vtdb import cursor from vtdb import dbexceptions from vtdb import field_types @@ -76,10 +75,8 @@ def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges) # A simple, direct connection to the vttablet query server. # This is shard-unaware and only handles the most basic communication. # If something goes wrong, this object should be thrown away and a new one instantiated. -class VtgateConnection(object): +class VTGateConnection(object): session = None - tablet_type = None - cursorclass = cursor.TabletCursor _stream_fields = None _stream_conversions = None _stream_result = None @@ -91,7 +88,7 @@ class VtgateConnection(object): self.client = bsonrpc.BsonRpcClient(addr, timeout, user, password, encrypted=encrypted, keyfile=keyfile, certfile=certfile) def __str__(self): - return '' % self.addr + return '' % self.addr def dial(self): try: @@ -132,15 +129,6 @@ class VtgateConnection(object): except gorpc.GoRpcError as e: raise convert_exception(e, str(self)) - def cursor(self, cursorclass=None, **kargs): - if cursorclass is not None: - # cursorclass can only be overwritten by a compatible cursor - if cursorclass != cursor.StreamCursor: - raise DatabaseException('invalid cursor type for VtgateConnection', - cursorclass) - return cursorclass(self, **kargs) - return self.cursorclass(self, **kargs) - def _add_session(self, req): if self.session: req['Session'] = self.session @@ -402,7 +390,7 @@ def connect(vtgate_addrs, timeout, encrypted=False, user=None, password=None): try: db_params = params.copy() host_addr = db_params['addr'] - conn = VtgateConnection(**db_params) + conn = VTGateConnection(**db_params) conn.dial() return conn except Exception as e: