From e49f70fe9350205c2cfe1183d19f4275d725e813 Mon Sep 17 00:00:00 2001 From: Jing Wang Date: Fri, 31 Jan 2014 23:13:31 -0800 Subject: [PATCH] Add Hive backend, refactor common code, fix bugs --- README.md | 4 +- pyhive/common.py | 216 ++++++++++++++++++++++++++ pyhive/exc.py | 70 +++++++++ pyhive/hive.py | 245 +++++++++++++++++++++++++++++ pyhive/presto.py | 264 ++------------------------------ pyhive/tests/dbapi_test_case.py | 150 ++++++++++++++++++ pyhive/tests/test_hive.py | 69 +++++++++ pyhive/tests/test_presto.py | 131 +++------------- scripts/make_many_rows.sh | 14 ++ scripts/make_one_row.sh | 4 + scripts/make_one_row_complex.sh | 4 + scripts/make_test_tables.sh | 7 + 12 files changed, 813 insertions(+), 365 deletions(-) create mode 100644 pyhive/common.py create mode 100644 pyhive/exc.py create mode 100644 pyhive/hive.py create mode 100644 pyhive/tests/dbapi_test_case.py create mode 100644 pyhive/tests/test_hive.py create mode 100755 scripts/make_many_rows.sh create mode 100755 scripts/make_one_row.sh create mode 100755 scripts/make_one_row_complex.sh create mode 100755 scripts/make_test_tables.sh diff --git a/README.md b/README.md index 2149ac2..7a89937 100644 --- a/README.md +++ b/README.md @@ -28,5 +28,5 @@ print select([func.count('*')], from_obj=user).scalar() Requirements ============ - Presto DBAPI: Just a Presto install - - Hive DBAPI: Thrift-generated `TCLIService` package - - SQLAlchemy integration: `sqlalchemy` version 0.5.8 + - Hive DBAPI: HiveServer2 daemon, `TCLIService`, `thrift`, `sasl`, `thrift_sasl` + - SQLAlchemy integration: `sqlalchemy` version 0.5 diff --git a/pyhive/common.py b/pyhive/common.py new file mode 100644 index 0000000..927365e --- /dev/null +++ b/pyhive/common.py @@ -0,0 +1,216 @@ +""" +Package private common utilities. Do not use directly. +""" + +from pyhive import exc +import abc +import collections +import time + + +class DBAPICursor(object): + """Base class for some common DBAPI logic""" + __metaclass__ = abc.ABCMeta + + _STATE_NONE = 0 + _STATE_RUNNING = 1 + _STATE_FINISHED = 2 + + def __init__(self, poll_interval=1): + self._poll_interval = poll_interval + self._reset_state() + + def _reset_state(self): + """Reset state about the previous query in preparation for running another query""" + # State to return as part of DBAPI + self._rownumber = 0 + + # Internal helper state + self._state = self._STATE_NONE + self._data = collections.deque() + self._columns = None + + def _fetch_while(self, fn): + while fn(): + self._fetch_more() + if fn(): + time.sleep(self._poll_interval) + + @abc.abstractproperty + def description(self): + raise NotImplementedError + + def close(self): + """By default, do nothing""" + pass + + @abc.abstractmethod + def _fetch_more(self): + """Get more results, append it to ``self._data``, and update ``self._state``.""" + raise NotImplementedError + + @property + def rowcount(self): + """By default, return -1 to indicate that this is not supported.""" + return -1 + + def executemany(self, operation, seq_of_parameters): + """Prepare a database operation (query or command) and then execute it against all parameter + sequences or mappings found in the sequence seq_of_parameters. + + Only the final result set is retained. + + Return values are not defined. 
+ """ + for parameters in seq_of_parameters[:-1]: + self.execute(operation, parameters) + while self._state != self._STATE_FINISHED: + self._fetch_more() + if seq_of_parameters: + self.execute(operation, seq_of_parameters[-1]) + + def fetchone(self): + """Fetch the next row of a query result set, returning a single sequence, or None when no + more data is available. + + An Error (or subclass) exception is raised if the previous call to execute() did not + produce any result set or no call was issued yet. + """ + if self._state == self._STATE_NONE: + raise exc.ProgrammingError('No query yet') + + # Sleep until we're done or we have some data to return + self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED) + + if not self._data: + return None + else: + self._rownumber += 1 + return self._data.popleft() + + def fetchmany(self, size=None): + """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a + list of tuples). An empty sequence is returned when no more rows are available. + + The number of rows to fetch per call is specified by the parameter. If it is not given, the + cursor's arraysize determines the number of rows to be fetched. The method should try to + fetch as many rows as indicated by the size parameter. If this is not possible due to the + specified number of rows not being available, fewer rows may be returned. + + An Error (or subclass) exception is raised if the previous call to .execute*() did not + produce any result set or no call was issued yet. + """ + if size is None: + size = self.arraysize + result = [] + for _ in xrange(size): + one = self.fetchone() + if one is None: + break + else: + result.append(one) + return result + + def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences + (e.g. a list of tuples). + + An Error (or subclass) exception is raised if the previous call to .execute*() did not + produce any result set or no call was issued yet. + """ + result = [] + while True: + one = self.fetchone() + if one is None: + break + else: + result.append(one) + return result + + @property + def arraysize(self): + """This read/write attribute specifies the number of rows to fetch at a time with + ``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time. + """ + return self._arraysize + + @arraysize.setter + def arraysize(self, value): + self._arraysize = value + + def setinputsizes(self, sizes): + """Does nothing by default""" + pass + + def setoutputsize(self, size, column=None): + """Does nothing by default""" + pass + + # + # Optional DB API Extensions + # + + @property + def rownumber(self): + """This read-only attribute should provide the current 0-based index of the cursor in the + result set. + + The index can be seen as index of the cursor in a sequence (the result set). The next fetch + operation will fetch the row indexed by .rownumber in that sequence. + """ + return self._rownumber + + def next(self): + """Return the next row from the currently executing SQL statement using the same semantics + as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted. 
+ """ + one = self.fetchone() + if one is None: + raise StopIteration + else: + return one + + def __iter__(self): + """Return self to make cursors compatible to the iteration protocol.""" + return self + + +class DBAPITypeObject(object): + # Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints + def __init__(self, *values): + self.values = values + + def __cmp__(self, other): + if other in self.values: + return 0 + if other < self.values: + return 1 + else: + return -1 + + +class ParamEscaper(object): + def escape_args(self, parameters): + if isinstance(parameters, dict): + return {k: self.escape_item(v) for k, v in parameters.iteritems()} + elif isinstance(parameters, (list, tuple)): + return tuple(self.escape_item(x) for x in parameters) + else: + raise exc.ProgrammingError("Unsupported param format: {}".format(parameters)) + + def escape_number(self, item): + return item + + def escape_string(self, item): + # This is good enough when backslashes are literal, newlines are just followed, and the way + # to escape a single quote is to put two single quotes. + # (i.e. only special character is single quote) + return "'{}'".format(item.replace("'", "''")) + + def escape_item(self, item): + if isinstance(item, (int, long, float)): + return self.escape_number(item) + elif isinstance(item, basestring): + return self.escape_string(item) + else: + raise exc.ProgrammingError("Unsupported object {}".format(item)) diff --git a/pyhive/exc.py b/pyhive/exc.py new file mode 100644 index 0000000..01ec361 --- /dev/null +++ b/pyhive/exc.py @@ -0,0 +1,70 @@ +""" +Package private common utilities. Do not use directly. +""" + +__all__ = [ + 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError', + 'ProgrammingError', 'DataError', 'NotSupportedError', +] + + +class Error(StandardError): + """Exception that is the base class of all other error exceptions. + + You can use this to catch all errors with one single except statement. + """ + pass + + +class Warning(StandardError): + """Exception raised for important warnings like data truncations while inserting, etc.""" + pass + + +class InterfaceError(Error): + """Exception raised for errors that are related to the database interface rather than the + database itself. + """ + pass + + +class DatabaseError(Error): + """Exception raised for errors that are related to the database.""" + pass + + +class InternalError(DatabaseError): + """Exception raised when the database encounters an internal error, e.g. the cursor is not valid + anymore, the transaction is out of sync, etc.""" + pass + + +class OperationalError(DatabaseError): + """Exception raised for errors that are related to the database's operation and not necessarily + under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name + is not found, a transaction could not be processed, a memory allocation error occurred during + processing, etc. + """ + pass + + +class ProgrammingError(DatabaseError): + """Exception raised for programming errors, e.g. table not found or already exists, syntax error + in the SQL statement, wrong number of parameters specified, etc. + """ + pass + + +class DataError(DatabaseError): + """Exception raised for errors that are due to problems with the processed data like division by + zero, numeric value out of range, etc. + """ + pass + + +class NotSupportedError(DatabaseError): + """Exception raised in case a method or database API was used which is not supported by the + database, e.g. 
requesting a .rollback() on a connection that does not support transaction or + has transactions turned off. + """ + pass diff --git a/pyhive/hive.py b/pyhive/hive.py new file mode 100644 index 0000000..bdc2f05 --- /dev/null +++ b/pyhive/hive.py @@ -0,0 +1,245 @@ +"""DB API implementation backed by HiveServer2 (Thrift API) + +See http://www.python.org/dev/peps/pep-0249/ + +Many docstrings in this file are based on the PEP, which is in the public domain. +""" + +from TCLIService import TCLIService +from TCLIService import constants +from TCLIService import ttypes +from pyhive import common +from pyhive.common import DBAPITypeObject +# Make all exceptions visible in this module per DBAPI +from pyhive.exc import * +import getpass +import logging +import sasl +import sys +import thrift.protocol.TBinaryProtocol +import thrift.transport.TSocket +import thrift_sasl + +# PEP 249 module globals +apilevel = '2.0' +threadsafety = 2 # Threads may share the module and connections. +paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s + +_logger = logging.getLogger(__name__) + + +class HiveParamEscaper(common.ParamEscaper): + def escape_string(self, item): + # backslashes and single quotes need to be escaped + # TODO verify against parser + return "'{}'".format(item + .replace('\\', '\\\\') + .replace("'", "\\'") + .replace('\r', '\\r') + .replace('\n', '\\n') + .replace('\t', '\\t') + ) + +_escaper = HiveParamEscaper() + + +def connect(**kwargs): + """Constructor for creating a connection to the database. See class Connection for arguments. + + Returns a Connection object. + """ + return Connection(**kwargs) + + +class Connection(object): + """Wraps a Thrift session""" + + def __init__(self, host, port=10000, username=None, configuration=None): + socket = thrift.transport.TSocket.TSocket(host, port) + username = username or getpass.getuser() + configuration = configuration or {} + + def sasl_factory(): + sasl_client = sasl.Client() + sasl_client.setAttr('username', username) + # Password doesn't matter in PLAIN mode, just needs to be nonempty. 
+ sasl_client.setAttr('password', 'x') + sasl_client.init() + return sasl_client + + # PLAIN corresponds to hive.server2.authentication=NONE in hive-site.xml + self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, 'PLAIN', socket) + protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport) + self._client = TCLIService.Client(protocol) + + try: + self._transport.open() + open_session_req = ttypes.TOpenSessionReq( + client_protocol=ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1, + configuration=configuration, + ) + response = self._client.OpenSession(open_session_req) + _check_status(response) + assert(response.sessionHandle is not None), "Expected a session from OpenSession" + self._sessionHandle = response.sessionHandle + assert(response.serverProtocolVersion == ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1), \ + "Unable to handle protocol version {}".format(response.serverProtocolVersion) + except: + self._transport.close() + raise + + def close(self): + """Close the underlying session and Thrift transport""" + req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle) + response = self._client.CloseSession(req) + self._transport.close() + _check_status(response) + + def commit(self): + """Hive does not support transactions""" + pass + + def cursor(self): + """Return a new Cursor object using the connection.""" + return Cursor(self) + + @property + def client(self): + return self._client + + @property + def sessionHandle(self): + return self._sessionHandle + + +class Cursor(common.DBAPICursor): + """These objects represent a database cursor, which is used to manage the context of a fetch + operation. + + Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately + visible by other cursors or connections. + """ + + def __init__(self, connection): + super(Cursor, self).__init__() + self._connection = connection + + def _reset_state(self): + """Reset state about the previous query in preparation for running another query""" + super(Cursor, self)._reset_state() + self._description = None + self._operationHandle = None + + @property + def description(self): + """This read-only attribute is a sequence of 7-item sequences. + + Each of these sequences contains information describing one result column: + + - name + - type_code + - display_size (None in current implementation) + - internal_size (None in current implementation) + - precision (None in current implementation) + - scale (None in current implementation) + - null_ok (always True in current implementation) + + The type_code can be interpreted by comparing it to the Type Objects specified in the + section below. 
+ """ + if self._operationHandle is None: + return None + if self._description is None: + req = ttypes.TGetResultSetMetadataReq(self._operationHandle) + response = self._connection.client.GetResultSetMetadata(req) + _check_status(response) + columns = response.schema.columns + self._description = [] + for col in columns: + primary_type_entry = col.typeDesc.types[0] + if primary_type_entry.primitiveEntry is None: + # All fancy stuff maps to string + type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE] + else: + type_id = primary_type_entry.primitiveEntry.type + type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id] + self._description.append((col.columnName, type_code, None, None, None, None, True)) + return self._description + + def close(self): + """Close the operation handle""" + if self._operationHandle is not None: + request = ttypes.TCloseOperationReq(self._operationHandle) + response = self._connection.client.CloseOperation(request) + _check_status(response) + + def execute(self, operation, parameters=None): + """Prepare and execute a database operation (query or command). + + Return values are not defined. + """ + if self._state == self._STATE_RUNNING: + raise ProgrammingError("Already running a query") + + # Prepare statement + if parameters is None: + sql = operation + else: + sql = operation % _escaper.escape_args(parameters) + + self._reset_state() + + self._state = self._STATE_RUNNING + _logger.debug("Query: %s", sql) + + req = ttypes.TExecuteStatementReq(self._connection.sessionHandle, sql) + response = self._connection.client.ExecuteStatement(req) + _check_status(response) + self._operationHandle = response.operationHandle + + def _fetch_more(self): + """Send another TFetchResultsReq and update state""" + assert(self._state == self._STATE_RUNNING) + req = ttypes.TFetchResultsReq( + operationHandle=self._operationHandle, + orientation=ttypes.TFetchOrientation.FETCH_NEXT, + maxRows=1000, + ) + response = self._connection.client.FetchResults(req) + _check_status(response) + # response.hasMoreRows seems to always be False, so we instead check the number of rows + #if not response.hasMoreRows: + if not response.results.rows: + self._state = self._STATE_FINISHED + for row in response.results.rows: + self._data.append([_unwrap_col_val(val) for val in row.colVals]) + + +# +# Type Objects and Constructors +# + + +for type_id in constants.PRIMITIVE_TYPES: + name = ttypes.TTypeId._VALUES_TO_NAMES[type_id] + setattr(sys.modules[__name__], name, DBAPITypeObject([name])) + + +# +# Private utilities +# + + +def _unwrap_col_val(val): + """Return the raw value from a TColumnValue instance.""" + for _, _, attr, _, _ in filter(None, ttypes.TColumnValue.thrift_spec): + val_obj = getattr(val, attr) + if val_obj: + return val_obj.value + raise DataError("Got empty column value {}".format(val)) + + +def _check_status(response): + """Raise an OperationalError if the status is not success""" + if response.status.statusCode != ttypes.TStatusCode.SUCCESS_STATUS: + raise OperationalError(response) diff --git a/pyhive/presto.py b/pyhive/presto.py index 804efc7..b41b287 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -5,12 +5,13 @@ See http://www.python.org/dev/peps/pep-0249/ Many docstrings in this file are based on the PEP, which is in the public domain. 
""" -import collections -import exceptions +from pyhive import common +from pyhive.common import DBAPITypeObject +# Make all exceptions visible in this module per DBAPI +from pyhive.exc import * import getpass import logging import requests -import time import urlparse @@ -20,6 +21,7 @@ threadsafety = 2 # Threads may share the module and connections. paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s _logger = logging.getLogger(__name__) +_escaper = common.ParamEscaper() def connect(**kwargs): @@ -53,16 +55,13 @@ class Connection(object): return Cursor(**self._params) -class Cursor(object): +class Cursor(common.DBAPICursor): """These objects represent a database cursor, which is used to manage the context of a fetch operation. Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately visible by other cursors or connections. """ - _STATE_NONE = 0 - _STATE_RUNNING = 1 - _STATE_FINISHED = 2 def __init__(self, host, port='8080', user=None, catalog='hive', schema='default', poll_interval=1, source='pyhive'): @@ -76,6 +75,7 @@ class Cursor(object): update, defaults to a second :param source: string -- arbitrary identifier (shows up in the Presto monitoring page) """ + super(Cursor, self).__init__(poll_interval) # Config self._host = host self._port = port @@ -90,21 +90,10 @@ class Cursor(object): def _reset_state(self): """Reset state about the previous query in preparation for running another query""" - # State to return as part of DBAPI - self._rownumber = 0 - - # Internal helper state - self._state = self._STATE_NONE + super(Cursor, self)._reset_state() self._nextUri = None - self._data = collections.deque() self._columns = None - def _fetch_while(self, fn): - while fn(): - self._fetch_more() - if fn(): - time.sleep(self._poll_interval) - @property def description(self): """This read-only attribute is a sequence of 7-item sequences. @@ -135,15 +124,6 @@ class Cursor(object): for col in self._columns ] - @property - def rowcount(self): - """Presto does not support rowcount""" - return -1 - - def close(self): - """Presto does not have anything to close""" - pass - def execute(self, operation, parameters=None): """Prepare and execute a database operation (query or command). @@ -163,7 +143,7 @@ class Cursor(object): if parameters is None: sql = operation else: - sql = operation % _escape_args(parameters) + sql = operation % _escaper.escape_args(parameters) self._reset_state() @@ -176,7 +156,7 @@ class Cursor(object): self._process_response(response) def _fetch_more(self): - """Fetch the next URI and udpate state""" + """Fetch the next URI and update state""" self._process_response(requests.get(self._nextUri)) def _process_response(self, response): @@ -199,238 +179,14 @@ class Cursor(object): assert not self._nextUri, 'Should not have nextUri if failed' raise DatabaseError(response_json['error']) - def executemany(self, operation, seq_of_parameters): - """Prepare a database operation (query or command) and then execute it against all parameter - sequences or mappings found in the sequence seq_of_parameters. - - Only the final result set is retained. - - Return values are not defined. - """ - for parameters in seq_of_parameters[:-1]: - self.execute(operation, parameters) - while self._state != self._STATE_FINISHED: - self._fetch_more() - self.execute(operation, seq_of_parameters[-1]) - - def fetchone(self): - """Fetch the next row of a query result set, returning a single sequence, or None when no - more data is available. 
- - An Error (or subclass) exception is raised if the previous call to execute() did not - produce any result set or no call was issued yet. - """ - if self._state == self._STATE_NONE: - raise ProgrammingError('No query yet') - # Note: all Presto statements produce a result set - # The CREATE TABLE statement produces a single bigint called 'rows' - - # Sleep until we're done or we have some data to return - self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED) - - if not self._data: - return None - else: - self._rownumber += 1 - return self._data.popleft() - - def fetchmany(self, size=None): - """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a - list of tuples). An empty sequence is returned when no more rows are available. - - The number of rows to fetch per call is specified by the parameter. If it is not given, the - cursor's arraysize determines the number of rows to be fetched. The method should try to - fetch as many rows as indicated by the size parameter. If this is not possible due to the - specified number of rows not being available, fewer rows may be returned. - - An Error (or subclass) exception is raised if the previous call to .execute*() did not - produce any result set or no call was issued yet. - """ - if size is None: - size = self.arraysize - result = [] - for _ in xrange(size): - one = self.fetchone() - if one is None: - break - else: - result.append(one) - return result - - def fetchall(self): - """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences - (e.g. a list of tuples). - - An Error (or subclass) exception is raised if the previous call to .execute*() did not - produce any result set or no call was issued yet. - """ - result = [] - while True: - one = self.fetchone() - if one is None: - break - else: - result.append(one) - return result - - @property - def arraysize(self): - """This read/write attribute specifies the number of rows to fetch at a time with - ``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time. - - In our current implementation this parameter has no effect on actual fetching. - """ - return self._arraysize - - @arraysize.setter - def arraysize(self, value): - self._arraysize = value - - def setinputsizes(self, sizes): - """Does nothing for Presto""" - pass - - def setoutputsize(self, size, column=None): - """Does nothing for Presto""" - pass - - # - # Optional DB API Extensions - # - - @property - def rownumber(self): - """This read-only attribute should provide the current 0-based index of the cursor in the - result set. - - The index can be seen as index of the cursor in a sequence (the result set). The next fetch - operation will fetch the row indexed by .rownumber in that sequence. - """ - return self._rownumber - - def next(self): - """Return the next row from the currently executing SQL statement using the same semantics - as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted. - """ - one = self.fetchone() - if one is None: - raise StopIteration - else: - return one - - def __iter__(self): - """Return self to make cursors compatible to the iteration protocol.""" - return self - -# -# Exceptions -# - - -class Error(exceptions.StandardError): - """Exception that is the base class of all other error exceptions. - - You can use this to catch all errors with one single except statement. 
- """ - pass - - -class Warning(exceptions.StandardError): - """Exception raised for important warnings like data truncations while inserting, etc.""" - pass - - -class InterfaceError(Error): - """Exception raised for errors that are related to the database interface rather than the - database itself. - """ - pass - - -class DatabaseError(Error): - """Exception raised for errors that are related to the database.""" - pass - - -class InternalError(DatabaseError): - """Exception raised when the database encounters an internal error, e.g. the cursor is not valid - anymore, the transaction is out of sync, etc.""" - pass - - -class OperationalError(DatabaseError): - """Exception raised for errors that are related to the database's operation and not necessarily - under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name - is not found, a transaction could not be processed, a memory allocation error occurred during - processing, etc. - """ - pass - - -class ProgrammingError(DatabaseError): - """Exception raised for programming errors, e.g. table not found or already exists, syntax error - in the SQL statement, wrong number of parameters specified, etc. - """ - pass - - -class DataError(DatabaseError): - """Exception raised for errors that are due to problems with the processed data like division by - zero, numeric value out of range, etc. - """ - pass - - -class NotSupportedError(DatabaseError): - """Exception raised in case a method or database API was used which is not supported by the - database, e.g. requesting a .rollback() on a connection that does not support transaction or - has transactions turned off. - """ - pass - # # Type Objects and Constructors # -class DBAPITypeObject(object): - # Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints - def __init__(self, *values): - self.values = values - - def __cmp__(self, other): - if other in self.values: - return 0 - if other < self.values: - return 1 - else: - return -1 - # See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java FIXED_INT_64 = DBAPITypeObject(['bigint']) VARIABLE_BINARY = DBAPITypeObject(['varchar']) DOUBLE = DBAPITypeObject(['double']) BOOLEAN = DBAPITypeObject(['boolean']) - - -# -# Private utilities -# -def _escape_args(parameters): - if isinstance(parameters, dict): - return {k: _escape_item(v) for k, v in parameters.iteritems()} - elif isinstance(parameters, (list, tuple)): - return tuple(_escape_item(x) for x in parameters) - else: - raise ProgrammingError("Unsupported param format: {}".format(parameters)) - - -def _escape_item(item): - if isinstance(item, (int, long, float)): - return item - elif isinstance(item, basestring): - # TODO is this good enough? - return "'{}'".format(item.replace("'", "''")) - else: - raise ProgrammingError("Unsupported object {}".format(item)) diff --git a/pyhive/tests/dbapi_test_case.py b/pyhive/tests/dbapi_test_case.py new file mode 100644 index 0000000..8f7d189 --- /dev/null +++ b/pyhive/tests/dbapi_test_case.py @@ -0,0 +1,150 @@ +""" +Shared DBAPI test cases +""" +from pyhive import exc +import abc +import contextlib +import functools +import unittest + + +def with_cursor(fn): + """Pass a cursor to the given function and handle cleanup. + + The cursor is taken from ``self.connect()``. 
+ """ + @functools.wraps(fn) + def wrapped_fn(self, *args, **kwargs): + with contextlib.closing(self.connect()) as connection: + with contextlib.closing(connection.cursor()) as cursor: + fn(self, cursor, *args, **kwargs) + return wrapped_fn + + +class DBAPITestCase(unittest.TestCase): + __metaclass__ = abc.ABCMeta + __test__ = False + + @abc.abstractmethod + def connect(self): + raise NotImplementedError + + @with_cursor + def test_fetchone(self, cursor): + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.rownumber, 0) + self.assertEqual(cursor.fetchone(), [1]) + self.assertEqual(cursor.rownumber, 1) + self.assertIsNone(cursor.fetchone()) + + @with_cursor + def test_fetchall(self, cursor): + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [[1]]) + cursor.execute('SELECT a FROM many_rows ORDER BY a') + self.assertEqual(cursor.fetchall(), [[i] for i in xrange(10000)]) + + @with_cursor + def test_iterator(self, cursor): + cursor.execute('SELECT * FROM one_row') + self.assertEqual(list(cursor), [[1]]) + self.assertRaises(StopIteration, cursor.next) + + @with_cursor + def test_description_initial(self, cursor): + self.assertIsNone(cursor.description) + + @with_cursor + def test_description_failed(self, cursor): + try: + cursor.execute('blah_blah') + except exc.DatabaseError: + pass + self.assertIsNone(cursor.description) + + @with_cursor + def test_bad_query(self, cursor): + def run(): + cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist') + cursor.fetchone() + self.assertRaises(exc.DatabaseError, run) + + @with_cursor + def test_concurrent_execution(self, cursor): + cursor.execute('SELECT * FROM one_row') + self.assertRaises(exc.ProgrammingError, + lambda: cursor.execute('SELECT * FROM one_row')) + + @with_cursor + def test_executemany(self, cursor): + for length in 1, 2: + cursor.executemany( + 'SELECT %(x)d FROM one_row', + [{'x': i} for i in xrange(1, length + 1)] + ) + self.assertEqual(cursor.fetchall(), [[length]]) + + @with_cursor + def test_executemany_none(self, cursor): + cursor.executemany('should_never_get_used', []) + self.assertIsNone(cursor.description) + self.assertRaises(exc.ProgrammingError, cursor.fetchone) + + @with_cursor + def test_fetchone_no_data(self, cursor): + self.assertRaises(exc.ProgrammingError, cursor.fetchone) + + @with_cursor + def test_fetchmany(self, cursor): + cursor.execute('SELECT * FROM many_rows LIMIT 15') + self.assertEqual(cursor.fetchmany(0), []) + self.assertEqual(len(cursor.fetchmany(10)), 10) + self.assertEqual(len(cursor.fetchmany(10)), 5) + + @with_cursor + def test_arraysize(self, cursor): + cursor.arraysize = 5 + cursor.execute('SELECT * FROM many_rows LIMIT 20') + self.assertEqual(len(cursor.fetchmany()), 5) + + @with_cursor + def test_polling_loop(self, cursor): + """Try to trigger the polling logic in fetchone()""" + cursor._poll_interval = 0 + cursor.execute('SELECT COUNT(*) FROM many_rows') + self.assertEqual(cursor.fetchone(), [10000]) + + @with_cursor + def test_no_params(self, cursor): + cursor.execute("SELECT '%(x)s' FROM one_row") + self.assertEqual(cursor.fetchall(), [['%(x)s']]) + + def test_escape(self): + """Verify that funny characters can be escaped as strings and SELECTed back""" + bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t ''' + self.run_escape_case(bad_str) + + @with_cursor + def run_escape_case(self, cursor, bad_str): + cursor.execute( + 'SELECT %d, %s FROM one_row', + (1, bad_str) + ) + self.assertEqual(cursor.fetchall(), [[1, bad_str]]) + 
cursor.execute( + 'SELECT %(a)d, %(b)s FROM one_row', + {'a': 1, 'b': bad_str} + ) + self.assertEqual(cursor.fetchall(), [[1, bad_str]]) + + @with_cursor + def test_invalid_params(self, cursor): + self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', 'hi')) + self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', [{}])) + + def test_open_close(self): + with contextlib.closing(self.connect()): + pass + with contextlib.closing(self.connect()) as connection: + with contextlib.closing(connection.cursor()): + pass diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py new file mode 100644 index 0000000..5b9df1b --- /dev/null +++ b/pyhive/tests/test_hive.py @@ -0,0 +1,69 @@ +"""Hive integration tests. + +These rely on having a Hive+Hadoop cluster set up with HiveServer2 running. +They also require a tables created by make_test_tables.sh. +""" +from TCLIService import ttypes +from pyhive import exc +from pyhive import hive +from pyhive.tests.dbapi_test_case import DBAPITestCase +import mock +import contextlib +from pyhive.tests.dbapi_test_case import with_cursor + +_HOST = 'localhost' + + +class TestHive(DBAPITestCase): + __test__ = True + + def connect(self): + return hive.connect(host=_HOST, username='hadoop') + + @with_cursor + def test_description(self, cursor): + cursor.execute('SELECT * FROM one_row') + desc = [('number_of_rows', 'INT_TYPE', None, None, None, None, True)] + self.assertEqual(cursor.description, desc) + self.assertEqual(cursor.description, desc) + + @with_cursor + def test_complex(self, cursor): + cursor.execute('SELECT * FROM one_row_complex') + self.assertEqual(cursor.description, [ + ('a', 'STRING_TYPE', None, None, None, None, True), + ('b', 'STRING_TYPE', None, None, None, None, True), + ]) + self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']]) + + def test_noops(self): + """The DB-API specification requires that certain actions exist, even though they might not + be applicable.""" + # Wohoo inflating coverage stats! + with contextlib.closing(self.connect()) as connection: + with contextlib.closing(connection.cursor()) as cursor: + self.assertEqual(cursor.rowcount, -1) + cursor.setinputsizes([]) + cursor.setoutputsize(1, 'blah') + connection.commit() + + @mock.patch('TCLIService.TCLIService.Client.OpenSession') + def test_open_failed(self, open_session): + open_session.return_value.serverProtocolVersion = \ + ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1 + self.assertRaises(exc.OperationalError, self.connect) + + def test_escape(self): + # Hive thrift translates newlines into multiple rows. WTF. + bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t ''' + self.run_escape_case(bad_str) + + def test_newlines(self): + """Verify that newlines are passed through in a way that doesn't fail parsing""" + # Hive thrift translates newlines into multiple rows. WTF. + cursor = self.connect().cursor() + cursor.execute( + 'SELECT %s FROM one_row', + (' \r\n \r \n ',) + ) + self.assertEqual(cursor.fetchall(), [[' '], [' '], [' '], [' ']]) diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index d3a5102..3f4ff01 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -1,137 +1,50 @@ """Presto integration tests. -These rely on having a Presto+Hadoop cluster set up. They also require a table called one_row. +These rely on having a Presto+Hadoop cluster set up. +They also require a tables created by make_test_tables.sh. 
""" +from pyhive import exc from pyhive import presto +from pyhive.tests.dbapi_test_case import DBAPITestCase +from pyhive.tests.dbapi_test_case import with_cursor import mock -import unittest _HOST = 'localhost' -_ONE_ROW_TABLE_NAME = 'one_row' -_BIG_TABLE_NAME = 'user' -class Testpresto(unittest.TestCase): - def test_fetchone(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select 1 from {}'.format(_ONE_ROW_TABLE_NAME)) - self.assertEqual(cursor.rownumber, 0) - self.assertEqual(cursor.fetchone(), [1]) - self.assertEqual(cursor.rownumber, 1) - self.assertIsNone(cursor.fetchone()) +class TestPresto(DBAPITestCase): + __test__ = True - def test_fetchall(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME)) - self.assertEqual(cursor.fetchall(), [[1]]) - - def test_iterator(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME)) - self.assertEqual(list(cursor), [[1]]) - self.assertRaises(StopIteration, lambda: cursor.next()) + def connect(self): + return presto.connect(host=_HOST) def test_description(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select 1 as foobar from {}'.format(_ONE_ROW_TABLE_NAME)) + cursor = self.connect().cursor() + cursor.execute('SELECT 1 AS foobar FROM one_row') self.assertEqual(cursor.description, [('foobar', 'bigint', None, None, None, None, True)]) - def test_description_initial(self): - cursor = presto.connect(host=_HOST).cursor() - self.assertIsNone(cursor.description) - - def test_description_failed(self): - cursor = presto.connect(host=_HOST).cursor() - try: - cursor.execute('blah_blah') - except presto.DatabaseError: - pass - self.assertIsNone(cursor.description) - - def test_bad_query(self): - cursor = presto.connect(host=_HOST).cursor() - - def run(): - cursor.execute('select does_not_exist from {}'.format(_ONE_ROW_TABLE_NAME)) - cursor.fetchone() - self.assertRaises(presto.DatabaseError, run) - - def test_concurrent_execution(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME)) - self.assertRaises(presto.ProgrammingError, - lambda: cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))) - - def test_executemany(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.executemany( - 'select %(x)d from {}'.format(_ONE_ROW_TABLE_NAME), - [{'x': 1}, {'x': 2}, {'x': 3}] - ) - self.assertEqual(cursor.fetchall(), [[3]]) - - def test_fetchone_no_data(self): - cursor = presto.connect(host=_HOST).cursor() - self.assertRaises(presto.ProgrammingError, lambda: cursor.fetchone()) - - def test_fetchmany(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select * from {} limit 15'.format(_BIG_TABLE_NAME)) - self.assertEqual(cursor.fetchmany(0), []) - self.assertEqual(len(cursor.fetchmany(10)), 10) - self.assertEqual(len(cursor.fetchmany(10)), 5) - - def test_arraysize(self): - cursor = presto.connect(host=_HOST).cursor() - cursor.arraysize = 5 - cursor.execute('select * from {} limit 20'.format(_BIG_TABLE_NAME)) - self.assertEqual(len(cursor.fetchmany()), 5) - - def test_slow_query(self): - """Trigger the polling logic in fetchone()""" - cursor = presto.connect(host=_HOST).cursor() - cursor.execute('select count(*) from {}'.format(_BIG_TABLE_NAME)) - self.assertIsNotNone(cursor.fetchone()) - - def test_no_params(self): - cursor = presto.connect(host=_HOST).cursor() - 
cursor.execute("select '%(x)s' from {}".format(_ONE_ROW_TABLE_NAME))
-        self.assertEqual(cursor.fetchall(), [['%(x)s']])
+    @with_cursor
+    def test_complex(self, cursor):
+        cursor.execute('SELECT * FROM one_row_complex')
+        self.assertEqual(cursor.description, [
+            ('a', 'varchar', None, None, None, None, True),
+            ('b', 'varchar', None, None, None, None, True),
+        ])
+        self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])
 
     def test_noops(self):
         """The DB-API specification requires that certain actions exist, even though they might not
         be applicable."""
         # Wohoo inflating coverage stats!
-        connection = presto.connect(host=_HOST)
+        connection = self.connect()
         cursor = connection.cursor()
         self.assertEqual(cursor.rowcount, -1)
         cursor.setinputsizes([])
         cursor.setoutputsize(1, 'blah')
-        cursor.close()
         connection.commit()
-        connection.close()
-
-    def test_escape(self):
-        cursor = presto.connect(host=_HOST).cursor()
-        cursor.execute(
-            "select %d, %s from {}".format(_ONE_ROW_TABLE_NAME),
-            (1, "';")
-        )
-        self.assertEqual(cursor.fetchall(), [[1, "';"]])
-        cursor.execute(
-            "select %(a)d, %(b)s from {}".format(_ONE_ROW_TABLE_NAME),
-            {'a': 1, 'b': "';"}
-        )
-        self.assertEqual(cursor.fetchall(), [[1, "';"]])
-
-    def test_invalid_params(self):
-        cursor = presto.connect(host=_HOST).cursor()
-        self.assertRaises(presto.ProgrammingError, lambda: cursor.execute('', 'hi'))
-        self.assertRaises(presto.ProgrammingError, lambda: cursor.execute('', [{}]))
 
     @mock.patch('requests.post')
     def test_non_200(self, post):
-        cursor = presto.connect(host=_HOST).cursor()
+        cursor = self.connect().cursor()
         post.return_value.status_code = 404
-        self.assertRaises(presto.OperationalError,
-                          lambda: cursor.execute('show tables'))
+        self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables'))
diff --git a/scripts/make_many_rows.sh b/scripts/make_many_rows.sh
new file mode 100755
index 0000000..0a72e86
--- /dev/null
+++ b/scripts/make_many_rows.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+hive -e 'DROP TABLE IF EXISTS many_rows'
+hive -e "
+CREATE TABLE many_rows (
+    a INT
+) PARTITIONED BY (
+    b STRING
+) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' STORED AS TEXTFILE"
+
+temp_file=/tmp/pyhive_test_data_many_rows.tsv
+seq 0 9999 > $temp_file
+hive -e "LOAD DATA LOCAL INPATH '$temp_file' INTO TABLE many_rows PARTITION (b='blah')"
+rm -f $temp_file
diff --git a/scripts/make_one_row.sh b/scripts/make_one_row.sh
new file mode 100755
index 0000000..9de1702
--- /dev/null
+++ b/scripts/make_one_row.sh
@@ -0,0 +1,4 @@
+#!/bin/bash -eux
+hive -e 'DROP TABLE IF EXISTS one_row'
+hive -e 'CREATE TABLE one_row (number_of_rows INT)'
+hive -e 'INSERT OVERWRITE TABLE one_row SELECT COUNT(*) + 1 FROM one_row'
diff --git a/scripts/make_one_row_complex.sh b/scripts/make_one_row_complex.sh
new file mode 100755
index 0000000..74ea42c
--- /dev/null
+++ b/scripts/make_one_row_complex.sh
@@ -0,0 +1,4 @@
+#!/bin/bash -eux
+hive -e 'DROP TABLE IF EXISTS one_row_complex'
+hive -e 'CREATE TABLE one_row_complex (a map<int, string>, b array<int>)'
+hive -e "INSERT OVERWRITE TABLE one_row_complex SELECT map(1, 'a', 2, 'b'), array(1, 2, 3) FROM one_row"
diff --git a/scripts/make_test_tables.sh b/scripts/make_test_tables.sh
new file mode 100755
index 0000000..c6997a2
--- /dev/null
+++ b/scripts/make_test_tables.sh
@@ -0,0 +1,7 @@
+#!/bin/bash -eux
+# Hive must be on the path for this script to work.
+# WARNING: drops and recreates tables called one_row, one_row_complex, and many_rows.
+ +$(dirname $0)/make_one_row.sh +$(dirname $0)/make_one_row_complex.sh +$(dirname $0)/make_many_rows.sh
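
Usage sketch for the DBAPI modules this patch adds and refactors. This is a minimal example, not part of the patch itself; it assumes a HiveServer2 on localhost:10000 and a Presto coordinator on localhost:8080, with the test tables already created by scripts/make_test_tables.sh. The host and the 'hadoop' username mirror the integration tests and are placeholders for real settings.

    from pyhive import hive
    from pyhive import presto

    # Hive DBAPI: connect over Thrift/SASL, run a query, inspect metadata and rows.
    connection = hive.connect(host='localhost', username='hadoop')
    cursor = connection.cursor()
    cursor.execute('SELECT * FROM one_row')
    print cursor.description  # e.g. [('number_of_rows', 'INT_TYPE', None, None, None, None, True)]
    print cursor.fetchall()   # [[1]]

    # paramstyle is 'pyformat'; parameters are escaped by the module's ParamEscaper.
    cursor.execute('SELECT %(x)d FROM one_row', {'x': 2})
    print cursor.fetchone()   # [2]
    cursor.close()
    connection.close()

    # The Presto cursor exposes the same DBAPI surface, but talks HTTP to the coordinator.
    cursor = presto.connect(host='localhost').cursor()
    cursor.execute('SELECT COUNT(*) FROM many_rows')
    print cursor.fetchone()   # [10000]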