зеркало из https://github.com/mozilla/PyHive.git
Add Hive backend, refactor common code, fix bugs
This commit is contained in:
Родитель
e4c75c9af8
Коммит
e49f70fe93
|
@ -28,5 +28,5 @@ print select([func.count('*')], from_obj=user).scalar()
|
||||||
Requirements
|
Requirements
|
||||||
============
|
============
|
||||||
- Presto DBAPI: Just a Presto install
|
- Presto DBAPI: Just a Presto install
|
||||||
- Hive DBAPI: Thrift-generated `TCLIService` package
|
- Hive DBAPI: HiveServer2 daemon, `TCLIService`, `thrift`, `sasl`, `thrift_sasl`
|
||||||
- SQLAlchemy integration: `sqlalchemy` version 0.5.8
|
- SQLAlchemy integration: `sqlalchemy` version 0.5
|
||||||
|
|
|
@ -0,0 +1,216 @@
|
||||||
|
"""
|
||||||
|
Package private common utilities. Do not use directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pyhive import exc
|
||||||
|
import abc
|
||||||
|
import collections
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class DBAPICursor(object):
    """Base class for some common DBAPI logic.

    Subclasses must provide ``description`` and ``_fetch_more()``; ``executemany()`` additionally
    relies on the subclass defining ``execute(operation, parameters)``. Fetched rows are buffered
    in ``self._data`` and handed out by the ``fetch*`` methods.
    """
    __metaclass__ = abc.ABCMeta

    # Cursor life cycle: no query issued yet -> results still arriving -> all results received.
    _STATE_NONE = 0
    _STATE_RUNNING = 1
    _STATE_FINISHED = 2

    # PEP 249 default for ``arraysize``. Kept as a class attribute so ``fetchmany()`` works even
    # when the attribute was never assigned on the instance.
    _arraysize = 1

    def __init__(self, poll_interval=1):
        """
        :param poll_interval: number of seconds ``_fetch_while`` sleeps between two fetches
        """
        self._poll_interval = poll_interval
        self._reset_state()

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query"""
        # State to return as part of DBAPI
        self._rownumber = 0

        # Internal helper state
        self._state = self._STATE_NONE
        self._data = collections.deque()
        self._columns = None

    def _fetch_while(self, fn):
        """Call ``_fetch_more()`` repeatedly for as long as ``fn()`` returns a true value.

        The sleep only happens when another round is actually needed, so the final fetch returns
        without delay.
        """
        while fn():
            self._fetch_more()
            if fn():
                time.sleep(self._poll_interval)

    @abc.abstractproperty
    def description(self):
        raise NotImplementedError

    def close(self):
        """By default, do nothing"""
        pass

    @abc.abstractmethod
    def _fetch_more(self):
        """Get more results, append it to ``self._data``, and update ``self._state``."""
        raise NotImplementedError

    @property
    def rowcount(self):
        """By default, return -1 to indicate that this is not supported."""
        return -1

    def executemany(self, operation, seq_of_parameters):
        """Prepare a database operation (query or command) and then execute it against all parameter
        sequences or mappings found in the sequence seq_of_parameters.

        Only the final result set is retained.

        Return values are not defined.
        """
        # Materialise first so that any iterable (not only sliceable sequences) is accepted.
        all_parameters = list(seq_of_parameters)
        for parameters in all_parameters[:-1]:
            self.execute(operation, parameters)
            # Drain the intermediate query completely before starting the next one.
            while self._state != self._STATE_FINISHED:
                self._fetch_more()
        if all_parameters:
            self.execute(operation, all_parameters[-1])

    def fetchone(self):
        """Fetch the next row of a query result set, returning a single sequence, or None when no
        more data is available.

        An Error (or subclass) exception is raised if the previous call to execute() did not
        produce any result set or no call was issued yet.
        """
        if self._state == self._STATE_NONE:
            raise exc.ProgrammingError('No query yet')

        # Sleep until we're done or we have some data to return
        self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)

        if not self._data:
            return None
        self._rownumber += 1
        return self._data.popleft()

    def fetchmany(self, size=None):
        """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
        list of tuples). An empty sequence is returned when no more rows are available.

        The number of rows to fetch per call is specified by the parameter. If it is not given, the
        cursor's arraysize determines the number of rows to be fetched. The method should try to
        fetch as many rows as indicated by the size parameter. If this is not possible due to the
        specified number of rows not being available, fewer rows may be returned.

        An Error (or subclass) exception is raised if the previous call to .execute*() did not
        produce any result set or no call was issued yet.
        """
        if size is None:
            size = self.arraysize
        result = []
        # A counting ``while`` avoids ``xrange``, which only exists on Python 2.
        while len(result) < size:
            one = self.fetchone()
            if one is None:
                break
            result.append(one)
        return result

    def fetchall(self):
        """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
        (e.g. a list of tuples).

        An Error (or subclass) exception is raised if the previous call to .execute*() did not
        produce any result set or no call was issued yet.
        """
        result = []
        while True:
            one = self.fetchone()
            if one is None:
                break
            result.append(one)
        return result

    @property
    def arraysize(self):
        """This read/write attribute specifies the number of rows to fetch at a time with
        ``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time.
        """
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._arraysize = value

    def setinputsizes(self, sizes):
        """Does nothing by default"""
        pass

    def setoutputsize(self, size, column=None):
        """Does nothing by default"""
        pass

    #
    # Optional DB API Extensions
    #

    @property
    def rownumber(self):
        """This read-only attribute should provide the current 0-based index of the cursor in the
        result set.

        The index can be seen as index of the cursor in a sequence (the result set). The next fetch
        operation will fetch the row indexed by .rownumber in that sequence.
        """
        return self._rownumber

    def next(self):
        """Return the next row from the currently executing SQL statement using the same semantics
        as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted.
        """
        one = self.fetchone()
        if one is None:
            raise StopIteration
        return one

    # Python 3 spelling of the iterator protocol; ``next`` stays for Python 2 callers.
    __next__ = next

    def __iter__(self):
        """Return self to make cursors compatible to the iteration protocol."""
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class DBAPITypeObject(object):
    """Type object that compares equal to every type code it was constructed with."""
    # Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints

    def __init__(self, *values):
        """
        :param values: the type codes this object should compare equal to
        """
        self.values = values

    def __eq__(self, other):
        # Same membership rule as ``__cmp__`` below. Python 3 never calls ``__cmp__``, so without
        # this ``==`` would fall back to identity comparison there.
        return other in self.values

    def __ne__(self, other):
        return other not in self.values

    # Defining ``__eq__`` would otherwise make instances unhashable on Python 3; keep the
    # identity-based hash these objects always had.
    __hash__ = object.__hash__

    def __cmp__(self, other):
        # Python 2 ordering protocol, kept unchanged for ``<`` / ``>`` and ``cmp()``.
        if other in self.values:
            return 0
        if other < self.values:
            return 1
        else:
            return -1
|
||||||
|
|
||||||
|
|
||||||
|
class ParamEscaper(object):
    """Convert query parameters into values that are safe to ``%``-interpolate into SQL text."""

    # ``long`` and ``basestring`` only exist on Python 2; fall back to the Python 3 equivalents.
    try:
        _NUMBER_TYPES = (int, long, float)  # noqa: F821
        _STRING_TYPES = (basestring,)  # noqa: F821
    except NameError:
        _NUMBER_TYPES = (int, float)
        _STRING_TYPES = (str,)

    def escape_args(self, parameters):
        """Escape every value of a parameter mapping or sequence.

        :param parameters: dict, list or tuple of parameter values
        :returns: a dict with the same keys for a dict, otherwise a tuple, holding escaped values
        :raises ProgrammingError: if ``parameters`` is neither a dict, a list nor a tuple
        """
        if isinstance(parameters, dict):
            # ``items()`` works on both Python 2 and 3 (``iteritems()`` is Python 2 only).
            return {k: self.escape_item(v) for k, v in parameters.items()}
        elif isinstance(parameters, (list, tuple)):
            return tuple(self.escape_item(x) for x in parameters)
        else:
            raise exc.ProgrammingError("Unsupported param format: {}".format(parameters))

    def escape_number(self, item):
        """Numbers are returned unchanged."""
        return item

    def escape_string(self, item):
        """Return ``item`` as a single-quoted literal with embedded single quotes doubled."""
        # This is good enough when backslashes are literal, newlines are just followed, and the way
        # to escape a single quote is to put two single quotes.
        # (i.e. only special character is single quote)
        #
        # Concatenation instead of ``"'{}'".format(...)``: on Python 2 formatting a non-ASCII
        # unicode value into a byte-string template raises UnicodeEncodeError.
        return "'" + item.replace("'", "''") + "'"

    def escape_item(self, item):
        """Escape a single parameter value according to its type.

        :raises ProgrammingError: if ``item`` is neither a number nor a string
        """
        if isinstance(item, self._NUMBER_TYPES):
            return self.escape_number(item)
        elif isinstance(item, self._STRING_TYPES):
            return self.escape_string(item)
        else:
            raise exc.ProgrammingError("Unsupported object {}".format(item))
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""
|
||||||
|
Package private DB-API (PEP 249) exception hierarchy. Do not use directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = [
    'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError',
    'ProgrammingError', 'DataError', 'NotSupportedError',
]

# ``StandardError`` was removed in Python 3; ``Exception`` takes its place there. On Python 2 the
# original base class is kept so existing ``except StandardError`` handlers keep working.
try:
    _BaseError = StandardError  # noqa: F821
except NameError:
    _BaseError = Exception


class Error(_BaseError):
    """Exception that is the base class of all other error exceptions.

    You can use this to catch all errors with one single except statement.
    """
    pass


class Warning(_BaseError):
    """Exception raised for important warnings like data truncations while inserting, etc."""
    pass


class InterfaceError(Error):
    """Exception raised for errors that are related to the database interface rather than the
    database itself.
    """
    pass


class DatabaseError(Error):
    """Exception raised for errors that are related to the database."""
    pass


class InternalError(DatabaseError):
    """Exception raised when the database encounters an internal error, e.g. the cursor is not valid
    anymore, the transaction is out of sync, etc.
    """
    pass


class OperationalError(DatabaseError):
    """Exception raised for errors that are related to the database's operation and not necessarily
    under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name
    is not found, a transaction could not be processed, a memory allocation error occurred during
    processing, etc.
    """
    pass


class ProgrammingError(DatabaseError):
    """Exception raised for programming errors, e.g. table not found or already exists, syntax error
    in the SQL statement, wrong number of parameters specified, etc.
    """
    pass


class DataError(DatabaseError):
    """Exception raised for errors that are due to problems with the processed data like division by
    zero, numeric value out of range, etc.
    """
    pass


class NotSupportedError(DatabaseError):
    """Exception raised in case a method or database API was used which is not supported by the
    database, e.g. requesting a .rollback() on a connection that does not support transaction or
    has transactions turned off.
    """
    pass
|
|
@ -0,0 +1,245 @@
|
||||||
|
"""DB API implementation backed by HiveServer2 (Thrift API)
|
||||||
|
|
||||||
|
See http://www.python.org/dev/peps/pep-0249/
|
||||||
|
|
||||||
|
Many docstrings in this file are based on the PEP, which is in the public domain.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from TCLIService import TCLIService
|
||||||
|
from TCLIService import constants
|
||||||
|
from TCLIService import ttypes
|
||||||
|
from pyhive import common
|
||||||
|
from pyhive.common import DBAPITypeObject
|
||||||
|
# Make all exceptions visible in this module per DBAPI
|
||||||
|
from pyhive.exc import *
|
||||||
|
import getpass
|
||||||
|
import logging
|
||||||
|
import sasl
|
||||||
|
import sys
|
||||||
|
import thrift.protocol.TBinaryProtocol
|
||||||
|
import thrift.transport.TSocket
|
||||||
|
import thrift_sasl
|
||||||
|
|
||||||
|
# Module-level attributes required by PEP 249.

# Supported DB-API level.
apilevel = '2.0'

# Threads may share the module and connections.
threadsafety = 2

# Python extended format codes, e.g. ...WHERE name=%(name)s
paramstyle = 'pyformat'

# Logger named after this module.
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HiveParamEscaper(common.ParamEscaper):
    """Parameter escaper that backslash-escapes special characters in string literals."""

    def escape_string(self, item):
        """Return ``item`` as a single-quoted literal with special characters backslash-escaped."""
        # backslashes and single quotes need to be escaped
        # TODO verify against parser
        # The backslash itself is replaced first so that the backslashes introduced by the later
        # replacements are not escaped a second time.
        escaped = (
            item
            .replace('\\', '\\\\')
            .replace("'", "\\'")
            .replace('\r', '\\r')
            .replace('\n', '\\n')
            .replace('\t', '\\t')
        )
        # Concatenation instead of ``"'{}'".format(...)``: on Python 2 formatting a non-ASCII
        # unicode value into a byte-string template raises UnicodeEncodeError.
        return "'" + escaped + "'"


# Shared escaper instance used when interpolating query parameters in this module.
_escaper = HiveParamEscaper()
|
||||||
|
|
||||||
|
|
||||||
|
def connect(*args, **kwargs):
    """Constructor for creating a connection to the database. See class Connection for arguments.

    Positional arguments are forwarded as well, so ``connect('host')`` behaves like
    ``Connection('host')``.

    Returns a Connection object.
    """
    return Connection(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Connection(object):
    """Wraps a Thrift session"""

    def __init__(self, host, port=10000, username=None, configuration=None):
        """Open a SASL PLAIN transport to ``host:port`` and start a session on it.

        :param host: host name passed to the Thrift socket
        :param port: port passed to the Thrift socket
        :param username: SASL user name; defaults to the result of ``getpass.getuser()``
        :param configuration: optional dict sent as the ``configuration`` of the open-session
            request; defaults to an empty dict
        """
        socket = thrift.transport.TSocket.TSocket(host, port)
        username = username or getpass.getuser()
        configuration = configuration or {}

        def sasl_factory():
            sasl_client = sasl.Client()
            sasl_client.setAttr('username', username)
            # Password doesn't matter in PLAIN mode, just needs to be nonempty.
            sasl_client.setAttr('password', 'x')
            sasl_client.init()
            return sasl_client

        # PLAIN corresponds to hive.server2.authentication=NONE in hive-site.xml
        self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, 'PLAIN', socket)
        protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
        self._client = TCLIService.Client(protocol)

        try:
            self._transport.open()
            open_session_req = ttypes.TOpenSessionReq(
                client_protocol=ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1,
                configuration=configuration,
            )
            response = self._client.OpenSession(open_session_req)
            _check_status(response)
            assert(response.sessionHandle is not None), "Expected a session from OpenSession"
            self._sessionHandle = response.sessionHandle
            assert(response.serverProtocolVersion == ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1), \
                "Unable to handle protocol version {}".format(response.serverProtocolVersion)
        except:  # noqa: E722 -- deliberately catches everything: close the transport, then re-raise
            self._transport.close()
            raise

    def close(self):
        """Close the underlying session and Thrift transport"""
        req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
        try:
            response = self._client.CloseSession(req)
        finally:
            # Release the transport even when the CloseSession call itself fails.
            self._transport.close()
        _check_status(response)

    def commit(self):
        """Hive does not support transactions"""
        pass

    def cursor(self):
        """Return a new Cursor object using the connection."""
        return Cursor(self)

    @property
    def client(self):
        """The Thrift ``TCLIService`` client bound to this connection's transport."""
        return self._client

    @property
    def sessionHandle(self):
        """The session handle returned by ``OpenSession``."""
        return self._sessionHandle
|
||||||
|
|
||||||
|
|
||||||
|
class Cursor(common.DBAPICursor):
    """These objects represent a database cursor, which is used to manage the context of a fetch
    operation.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
    visible by other cursors or connections.
    """

    def __init__(self, connection):
        """
        :param connection: the ``Connection`` whose client and session handle this cursor uses
        """
        super(Cursor, self).__init__()
        self._connection = connection

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query"""
        super(Cursor, self)._reset_state()
        self._description = None
        self._operationHandle = None

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences contains information describing one result column:

        - name
        - type_code
        - display_size (None in current implementation)
        - internal_size (None in current implementation)
        - precision (None in current implementation)
        - scale (None in current implementation)
        - null_ok (always True in current implementation)

        The type_code can be interpreted by comparing it to the Type Objects specified in the
        section below.
        """
        if self._operationHandle is None:
            return None
        if self._description is None:
            # Fetched lazily on first access and cached until the next ``_reset_state()``.
            req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
            response = self._connection.client.GetResultSetMetadata(req)
            _check_status(response)
            columns = response.schema.columns
            self._description = []
            for col in columns:
                primary_type_entry = col.typeDesc.types[0]
                if primary_type_entry.primitiveEntry is None:
                    # All fancy stuff maps to string
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
                else:
                    type_id = primary_type_entry.primitiveEntry.type
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
                self._description.append((col.columnName, type_code, None, None, None, None, True))
        return self._description

    def close(self):
        """Close the operation handle, if any. Calling it again is a no-op."""
        handle = self._operationHandle
        if handle is None:
            return
        # Forget the handle before the request so that a repeated (or failed) close never sends
        # the same handle a second time.
        self._operationHandle = None
        request = ttypes.TCloseOperationReq(handle)
        response = self._connection.client.CloseOperation(request)
        _check_status(response)

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        :param operation: SQL text, optionally containing ``%`` format codes
        :param parameters: optional dict, list or tuple that is escaped and ``%``-interpolated
            into ``operation``
        :raises ProgrammingError: if the previous query is still running

        Return values are not defined.
        """
        if self._state == self._STATE_RUNNING:
            raise ProgrammingError("Already running a query")

        # Prepare statement
        if parameters is None:
            sql = operation
        else:
            sql = operation % _escaper.escape_args(parameters)

        # Close the previous operation before its handle is dropped by _reset_state();
        # otherwise the handle is lost without ever being closed.
        self.close()
        self._reset_state()

        self._state = self._STATE_RUNNING
        _logger.debug("Query: %s", sql)

        req = ttypes.TExecuteStatementReq(self._connection.sessionHandle, sql)
        try:
            response = self._connection.client.ExecuteStatement(req)
            _check_status(response)
        except Exception:
            # Do not stay in RUNNING without an operation handle: that would make every later
            # execute() fail with "Already running a query".
            self._reset_state()
            raise
        self._operationHandle = response.operationHandle

    def _fetch_more(self):
        """Send another TFetchResultsReq and update state"""
        assert(self._state == self._STATE_RUNNING)
        req = ttypes.TFetchResultsReq(
            operationHandle=self._operationHandle,
            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
            maxRows=1000,
        )
        response = self._connection.client.FetchResults(req)
        _check_status(response)
        # response.hasMoreRows seems to always be False, so we instead check the number of rows:
        # an empty batch marks the end of the result set.
        if not response.results.rows:
            self._state = self._STATE_FINISHED
        for row in response.results.rows:
            self._data.append([_unwrap_col_val(val) for val in row.colVals])
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Type Objects and Constructors
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# Publish one module-level type object per primitive Hive type, named after the type
# (looked up through ``ttypes.TTypeId._VALUES_TO_NAMES``).
for type_id in constants.PRIMITIVE_TYPES:
    name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
    # DBAPITypeObject takes the accepted values as *args. Passing ``[name]`` stored the list
    # itself as the only value, so the object never compared equal to the type-code string
    # reported in ``Cursor.description``.
    setattr(sys.modules[__name__], name, DBAPITypeObject(name))
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Private utilities
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_col_val(val):
    """Return the raw value from a TColumnValue instance.

    Walks the fields listed in ``TColumnValue.thrift_spec`` (skipping falsy entries) and returns
    the ``.value`` of the first field of ``val`` that is truthy.

    :raises DataError: if none of the fields is set
    """
    field_specs = (entry for entry in ttypes.TColumnValue.thrift_spec if entry)
    for _fid, _ftype, field_name, _fspec, _fdefault in field_specs:
        wrapper = getattr(val, field_name)
        if not wrapper:
            continue
        return wrapper.value
    raise DataError("Got empty column value {}".format(val))
|
||||||
|
|
||||||
|
|
||||||
|
def _check_status(response):
    """Raise an OperationalError carrying ``response`` unless its status code is SUCCESS_STATUS."""
    succeeded = response.status.statusCode == ttypes.TStatusCode.SUCCESS_STATUS
    if succeeded:
        return
    raise OperationalError(response)
|
264
pyhive/presto.py
264
pyhive/presto.py
|
@ -5,12 +5,13 @@ See http://www.python.org/dev/peps/pep-0249/
|
||||||
Many docstrings in this file are based on the PEP, which is in the public domain.
|
Many docstrings in this file are based on the PEP, which is in the public domain.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import collections
|
from pyhive import common
|
||||||
import exceptions
|
from pyhive.common import DBAPITypeObject
|
||||||
|
# Make all exceptions visible in this module per DBAPI
|
||||||
|
from pyhive.exc import *
|
||||||
import getpass
|
import getpass
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
import time
|
|
||||||
import urlparse
|
import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,6 +21,7 @@ threadsafety = 2 # Threads may share the module and connections.
|
||||||
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
|
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
_escaper = common.ParamEscaper()
|
||||||
|
|
||||||
|
|
||||||
def connect(**kwargs):
|
def connect(**kwargs):
|
||||||
|
@ -53,16 +55,13 @@ class Connection(object):
|
||||||
return Cursor(**self._params)
|
return Cursor(**self._params)
|
||||||
|
|
||||||
|
|
||||||
class Cursor(object):
|
class Cursor(common.DBAPICursor):
|
||||||
"""These objects represent a database cursor, which is used to manage the context of a fetch
|
"""These objects represent a database cursor, which is used to manage the context of a fetch
|
||||||
operation.
|
operation.
|
||||||
|
|
||||||
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
|
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
|
||||||
visible by other cursors or connections.
|
visible by other cursors or connections.
|
||||||
"""
|
"""
|
||||||
_STATE_NONE = 0
|
|
||||||
_STATE_RUNNING = 1
|
|
||||||
_STATE_FINISHED = 2
|
|
||||||
|
|
||||||
def __init__(self, host, port='8080', user=None, catalog='hive', schema='default',
|
def __init__(self, host, port='8080', user=None, catalog='hive', schema='default',
|
||||||
poll_interval=1, source='pyhive'):
|
poll_interval=1, source='pyhive'):
|
||||||
|
@ -76,6 +75,7 @@ class Cursor(object):
|
||||||
update, defaults to a second
|
update, defaults to a second
|
||||||
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
|
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
|
||||||
"""
|
"""
|
||||||
|
super(Cursor, self).__init__(poll_interval)
|
||||||
# Config
|
# Config
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
@ -90,21 +90,10 @@ class Cursor(object):
|
||||||
|
|
||||||
def _reset_state(self):
|
def _reset_state(self):
|
||||||
"""Reset state about the previous query in preparation for running another query"""
|
"""Reset state about the previous query in preparation for running another query"""
|
||||||
# State to return as part of DBAPI
|
super(Cursor, self)._reset_state()
|
||||||
self._rownumber = 0
|
|
||||||
|
|
||||||
# Internal helper state
|
|
||||||
self._state = self._STATE_NONE
|
|
||||||
self._nextUri = None
|
self._nextUri = None
|
||||||
self._data = collections.deque()
|
|
||||||
self._columns = None
|
self._columns = None
|
||||||
|
|
||||||
def _fetch_while(self, fn):
|
|
||||||
while fn():
|
|
||||||
self._fetch_more()
|
|
||||||
if fn():
|
|
||||||
time.sleep(self._poll_interval)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self):
|
def description(self):
|
||||||
"""This read-only attribute is a sequence of 7-item sequences.
|
"""This read-only attribute is a sequence of 7-item sequences.
|
||||||
|
@ -135,15 +124,6 @@ class Cursor(object):
|
||||||
for col in self._columns
|
for col in self._columns
|
||||||
]
|
]
|
||||||
|
|
||||||
@property
|
|
||||||
def rowcount(self):
|
|
||||||
"""Presto does not support rowcount"""
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Presto does not have anything to close"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def execute(self, operation, parameters=None):
|
def execute(self, operation, parameters=None):
|
||||||
"""Prepare and execute a database operation (query or command).
|
"""Prepare and execute a database operation (query or command).
|
||||||
|
|
||||||
|
@ -163,7 +143,7 @@ class Cursor(object):
|
||||||
if parameters is None:
|
if parameters is None:
|
||||||
sql = operation
|
sql = operation
|
||||||
else:
|
else:
|
||||||
sql = operation % _escape_args(parameters)
|
sql = operation % _escaper.escape_args(parameters)
|
||||||
|
|
||||||
self._reset_state()
|
self._reset_state()
|
||||||
|
|
||||||
|
@ -176,7 +156,7 @@ class Cursor(object):
|
||||||
self._process_response(response)
|
self._process_response(response)
|
||||||
|
|
||||||
def _fetch_more(self):
|
def _fetch_more(self):
|
||||||
"""Fetch the next URI and udpate state"""
|
"""Fetch the next URI and update state"""
|
||||||
self._process_response(requests.get(self._nextUri))
|
self._process_response(requests.get(self._nextUri))
|
||||||
|
|
||||||
def _process_response(self, response):
|
def _process_response(self, response):
|
||||||
|
@ -199,238 +179,14 @@ class Cursor(object):
|
||||||
assert not self._nextUri, 'Should not have nextUri if failed'
|
assert not self._nextUri, 'Should not have nextUri if failed'
|
||||||
raise DatabaseError(response_json['error'])
|
raise DatabaseError(response_json['error'])
|
||||||
|
|
||||||
def executemany(self, operation, seq_of_parameters):
|
|
||||||
"""Prepare a database operation (query or command) and then execute it against all parameter
|
|
||||||
sequences or mappings found in the sequence seq_of_parameters.
|
|
||||||
|
|
||||||
Only the final result set is retained.
|
|
||||||
|
|
||||||
Return values are not defined.
|
|
||||||
"""
|
|
||||||
for parameters in seq_of_parameters[:-1]:
|
|
||||||
self.execute(operation, parameters)
|
|
||||||
while self._state != self._STATE_FINISHED:
|
|
||||||
self._fetch_more()
|
|
||||||
self.execute(operation, seq_of_parameters[-1])
|
|
||||||
|
|
||||||
def fetchone(self):
|
|
||||||
"""Fetch the next row of a query result set, returning a single sequence, or None when no
|
|
||||||
more data is available.
|
|
||||||
|
|
||||||
An Error (or subclass) exception is raised if the previous call to execute() did not
|
|
||||||
produce any result set or no call was issued yet.
|
|
||||||
"""
|
|
||||||
if self._state == self._STATE_NONE:
|
|
||||||
raise ProgrammingError('No query yet')
|
|
||||||
# Note: all Presto statements produce a result set
|
|
||||||
# The CREATE TABLE statement produces a single bigint called 'rows'
|
|
||||||
|
|
||||||
# Sleep until we're done or we have some data to return
|
|
||||||
self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)
|
|
||||||
|
|
||||||
if not self._data:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
self._rownumber += 1
|
|
||||||
return self._data.popleft()
|
|
||||||
|
|
||||||
def fetchmany(self, size=None):
|
|
||||||
"""Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
|
|
||||||
list of tuples). An empty sequence is returned when no more rows are available.
|
|
||||||
|
|
||||||
The number of rows to fetch per call is specified by the parameter. If it is not given, the
|
|
||||||
cursor's arraysize determines the number of rows to be fetched. The method should try to
|
|
||||||
fetch as many rows as indicated by the size parameter. If this is not possible due to the
|
|
||||||
specified number of rows not being available, fewer rows may be returned.
|
|
||||||
|
|
||||||
An Error (or subclass) exception is raised if the previous call to .execute*() did not
|
|
||||||
produce any result set or no call was issued yet.
|
|
||||||
"""
|
|
||||||
if size is None:
|
|
||||||
size = self.arraysize
|
|
||||||
result = []
|
|
||||||
for _ in xrange(size):
|
|
||||||
one = self.fetchone()
|
|
||||||
if one is None:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
result.append(one)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def fetchall(self):
|
|
||||||
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
|
|
||||||
(e.g. a list of tuples).
|
|
||||||
|
|
||||||
An Error (or subclass) exception is raised if the previous call to .execute*() did not
|
|
||||||
produce any result set or no call was issued yet.
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
while True:
|
|
||||||
one = self.fetchone()
|
|
||||||
if one is None:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
result.append(one)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@property
|
|
||||||
def arraysize(self):
|
|
||||||
"""This read/write attribute specifies the number of rows to fetch at a time with
|
|
||||||
``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time.
|
|
||||||
|
|
||||||
In our current implementation this parameter has no effect on actual fetching.
|
|
||||||
"""
|
|
||||||
return self._arraysize
|
|
||||||
|
|
||||||
@arraysize.setter
|
|
||||||
def arraysize(self, value):
|
|
||||||
self._arraysize = value
|
|
||||||
|
|
||||||
def setinputsizes(self, sizes):
|
|
||||||
"""Does nothing for Presto"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setoutputsize(self, size, column=None):
|
|
||||||
"""Does nothing for Presto"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
#
|
|
||||||
# Optional DB API Extensions
|
|
||||||
#
|
|
||||||
|
|
||||||
@property
|
|
||||||
def rownumber(self):
|
|
||||||
"""This read-only attribute should provide the current 0-based index of the cursor in the
|
|
||||||
result set.
|
|
||||||
|
|
||||||
The index can be seen as index of the cursor in a sequence (the result set). The next fetch
|
|
||||||
operation will fetch the row indexed by .rownumber in that sequence.
|
|
||||||
"""
|
|
||||||
return self._rownumber
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
"""Return the next row from the currently executing SQL statement using the same semantics
|
|
||||||
as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted.
|
|
||||||
"""
|
|
||||||
one = self.fetchone()
|
|
||||||
if one is None:
|
|
||||||
raise StopIteration
|
|
||||||
else:
|
|
||||||
return one
|
|
||||||
|
|
||||||
def __iter__(self):
    """Support the iteration protocol: the cursor serves as its own iterator."""
    return self
|
|
||||||
|
|
||||||
#
|
|
||||||
# Exceptions
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
class Error(exceptions.StandardError):
    """Root of the error exception hierarchy.

    Every other error exception derives from this class, so one ``except Error`` clause
    catches them all.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class Warning(exceptions.StandardError):
    """Raised for important warnings, e.g. data truncation while inserting."""
|
|
||||||
|
|
||||||
|
|
||||||
class InterfaceError(Error):
    """Raised for errors that concern the database interface rather than the database itself."""
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseError(Error):
    """Raised for errors that concern the database."""
|
|
||||||
|
|
||||||
|
|
||||||
class InternalError(DatabaseError):
    """Raised when the database hits an internal error.

    Examples: the cursor is no longer valid, or the transaction is out of sync.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class OperationalError(DatabaseError):
    """Raised for errors in the database's operation that the programmer may not control.

    Examples: an unexpected disconnect, a data source name that is not found, a transaction
    that could not be processed, or a memory allocation error during processing.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class ProgrammingError(DatabaseError):
    """Raised for programming errors.

    Examples: a table that is not found or already exists, a syntax error in the SQL
    statement, or a wrong number of parameters.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class DataError(DatabaseError):
    """Raised for problems with the processed data.

    Examples: division by zero, or a numeric value out of range.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class NotSupportedError(DatabaseError):
    """Raised when a method or database API is used that the database does not support.

    Example: calling ``.rollback()`` on a connection that does not support transactions or
    has them turned off.
    """
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Type Objects and Constructors
|
# Type Objects and Constructors
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
class DBAPITypeObject(object):
    """Type object that compares equal to each of the type names it was built from.

    Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints

    :param values: the type names; either separate positional arguments or one
        list/tuple/set holding them
    """

    def __init__(self, *values):
        # A single list/tuple/set argument is treated as the collection of names. Stored
        # verbatim, ``DBAPITypeObject(['bigint'])`` would hold ``(['bigint'],)`` and never
        # match the string ``'bigint'``.
        if len(values) == 1 and isinstance(values[0], (list, tuple, set, frozenset)):
            values = tuple(values[0])
        self.values = values

    # Python 3 ignores ``__cmp__``, so equality is spelled out explicitly. On Python 2 these
    # take precedence for ``==``/``!=`` and give the same answer ``__cmp__`` would.
    def __eq__(self, other):
        return other in self.values

    def __ne__(self, other):
        return other not in self.values

    # Defining ``__eq__`` would make instances unhashable on Python 3. Keep the identity hash
    # so they stay hashable; no hash can agree with membership-based equality anyway.
    __hash__ = object.__hash__

    def __cmp__(self, other):
        if other in self.values:
            return 0
        if other < self.values:
            return 1
        return -1
|
|
||||||
|
|
||||||
# See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java
# Each type name is passed as its own positional argument, matching the
# DBAPITypeObject(*values) signature.
FIXED_INT_64 = DBAPITypeObject('bigint')
VARIABLE_BINARY = DBAPITypeObject('varchar')
DOUBLE = DBAPITypeObject('double')
BOOLEAN = DBAPITypeObject('boolean')
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# Private utilities
|
|
||||||
#
|
|
||||||
def _escape_args(parameters):
    """Escape every query parameter with ``_escape_item``.

    :param parameters: a dict of named parameters, or a list/tuple of positional parameters
    :returns: a dict with the same keys, or a tuple of the same length, holding the escaped
        values
    :raises ProgrammingError: if ``parameters`` is neither a dict nor a list/tuple
    """
    if isinstance(parameters, dict):
        # ``items()`` exists on Python 2 and 3 alike; ``iteritems()`` is Python 2 only
        return {k: _escape_item(v) for k, v in parameters.items()}
    elif isinstance(parameters, (list, tuple)):
        return tuple(_escape_item(x) for x in parameters)
    else:
        raise ProgrammingError("Unsupported param format: {}".format(parameters))
|
|
||||||
|
|
||||||
|
|
||||||
try:
    # Python 2: ``long`` and ``basestring`` are builtins
    _NUMBER_TYPES = (int, long, float)
    _STRING_TYPES = (basestring,)
except NameError:
    # Python 3 folded ``long`` into ``int`` and ``basestring`` into ``str``
    _NUMBER_TYPES = (int, float)
    _STRING_TYPES = (str,)


def _escape_item(item):
    """Turn one query parameter into a value that is safe to interpolate into SQL text.

    :param item: a number or a string
    :returns: ``item`` itself for numbers; for strings, the text wrapped in single quotes
        with every embedded single quote doubled
    :raises ProgrammingError: for any other type
    """
    if isinstance(item, _NUMBER_TYPES):
        return item
    elif isinstance(item, _STRING_TYPES):
        # TODO is this good enough?
        # Concatenation keeps the input's string type. On Python 2,
        # ``"'{}'".format(u'\xe9')`` raises UnicodeEncodeError because the byte-string
        # template forces an ASCII encode of the unicode argument.
        return "'" + item.replace("'", "''") + "'"
    else:
        raise ProgrammingError("Unsupported object {}".format(item))
|
|
||||||
|
|
|
@ -0,0 +1,150 @@
|
||||||
|
"""
|
||||||
|
Shared DBAPI test cases
|
||||||
|
"""
|
||||||
|
from pyhive import exc
|
||||||
|
import abc
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
def with_cursor(fn):
    """Pass a cursor to the given function and handle cleanup.

    The cursor is taken from ``self.connect()``. It is inserted as the first argument after
    ``self``. The cursor and then the connection are closed when the call finishes, whether
    it returns or raises.

    :param fn: a method with the signature ``fn(self, cursor, *args, **kwargs)``
    :returns: a wrapper callable as ``wrapper(self, *args, **kwargs)`` that returns whatever
        ``fn`` returns
    """
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                # Propagate the result instead of silently dropping it
                return fn(self, cursor, *args, **kwargs)
    return wrapped_fn
|
||||||
|
|
||||||
|
|
||||||
|
class DBAPITestCase(unittest.TestCase):
    """Backend-independent DB-API test cases.

    Concrete subclasses supply ``connect()``. The queries rely on tables named ``one_row``
    and ``many_rows``.
    """
    __metaclass__ = abc.ABCMeta
    # NOTE(review): presumably a test-collector flag that keeps this abstract base from being
    # run directly; the concrete subclasses set it back to True -- confirm.
    __test__ = False

    @abc.abstractmethod
    def connect(self):
        """Return a new connection to the backend under test."""
        raise NotImplementedError

    @with_cursor
    def test_fetchone(self, cursor):
        # rownumber starts at 0 and advances with each fetched row
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.rownumber, 0)
        first_row = cursor.fetchone()
        self.assertEqual(first_row, [1])
        self.assertEqual(cursor.rownumber, 1)
        # the only row has been consumed
        self.assertIsNone(cursor.fetchone())

    @with_cursor
    def test_fetchall(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [[1]])
        cursor.execute('SELECT a FROM many_rows ORDER BY a')
        expected_rows = [[value] for value in xrange(10000)]
        self.assertEqual(cursor.fetchall(), expected_rows)

    @with_cursor
    def test_iterator(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        rows = list(cursor)
        self.assertEqual(rows, [[1]])
        # an exhausted cursor keeps signalling the end of iteration
        self.assertRaises(StopIteration, cursor.next)

    @with_cursor
    def test_description_initial(self, cursor):
        # nothing has been executed yet
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_description_failed(self, cursor):
        # a failing statement must not leave a description behind
        try:
            cursor.execute('blah_blah')
        except exc.DatabaseError:
            pass
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_bad_query(self, cursor):
        def execute_and_fetch():
            cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
            cursor.fetchone()
        self.assertRaises(exc.DatabaseError, execute_and_fetch)

    @with_cursor
    def test_concurrent_execution(self, cursor):
        # a second execute while the first statement is still pending is rejected
        cursor.execute('SELECT * FROM one_row')
        self.assertRaises(exc.ProgrammingError, cursor.execute, 'SELECT * FROM one_row')

    @with_cursor
    def test_executemany(self, cursor):
        # only the result of the last parameter set is expected to remain fetchable
        for length in (1, 2):
            parameter_sets = [{'x': i} for i in xrange(1, length + 1)]
            cursor.executemany('SELECT %(x)d FROM one_row', parameter_sets)
            self.assertEqual(cursor.fetchall(), [[length]])

    @with_cursor
    def test_executemany_none(self, cursor):
        # with no parameter sets the statement is never run
        cursor.executemany('should_never_get_used', [])
        self.assertIsNone(cursor.description)
        self.assertRaises(exc.ProgrammingError, cursor.fetchone)

    @with_cursor
    def test_fetchone_no_data(self, cursor):
        self.assertRaises(exc.ProgrammingError, cursor.fetchone)

    @with_cursor
    def test_fetchmany(self, cursor):
        cursor.execute('SELECT * FROM many_rows LIMIT 15')
        self.assertEqual(cursor.fetchmany(0), [])
        first_batch = cursor.fetchmany(10)
        self.assertEqual(len(first_batch), 10)
        # only 5 of the 15 rows are left
        remainder = cursor.fetchmany(10)
        self.assertEqual(len(remainder), 5)

    @with_cursor
    def test_arraysize(self, cursor):
        # fetchmany() without a size falls back to arraysize
        cursor.arraysize = 5
        cursor.execute('SELECT * FROM many_rows LIMIT 20')
        batch = cursor.fetchmany()
        self.assertEqual(len(batch), 5)

    @with_cursor
    def test_polling_loop(self, cursor):
        """Try to trigger the polling logic in fetchone()"""
        cursor._poll_interval = 0
        cursor.execute('SELECT COUNT(*) FROM many_rows')
        self.assertEqual(cursor.fetchone(), [10000])

    @with_cursor
    def test_no_params(self, cursor):
        # without parameters the placeholder-looking text is sent verbatim
        cursor.execute("SELECT '%(x)s' FROM one_row")
        self.assertEqual(cursor.fetchall(), [['%(x)s']])

    def test_escape(self):
        """Verify that funny characters can be escaped as strings and SELECTed back"""
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
        self.run_escape_case(bad_str)

    @with_cursor
    def run_escape_case(self, cursor, bad_str):
        # same round trip with positional and then with named parameters
        cases = (
            ('SELECT %d, %s FROM one_row', (1, bad_str)),
            ('SELECT %(a)d, %(b)s FROM one_row', {'a': 1, 'b': bad_str}),
        )
        for query, parameters in cases:
            cursor.execute(query, parameters)
            self.assertEqual(cursor.fetchall(), [[1, bad_str]])

    @with_cursor
    def test_invalid_params(self, cursor):
        self.assertRaises(exc.ProgrammingError, cursor.execute, '', 'hi')
        self.assertRaises(exc.ProgrammingError, cursor.execute, '', [{}])

    def test_open_close(self):
        # a connection on its own
        with contextlib.closing(self.connect()):
            pass
        # a connection together with a cursor
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()):
                pass
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Hive integration tests.
|
||||||
|
|
||||||
|
These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
|
||||||
|
They also require the tables created by make_test_tables.sh.
|
||||||
|
"""
|
||||||
|
from TCLIService import ttypes
|
||||||
|
from pyhive import exc
|
||||||
|
from pyhive import hive
|
||||||
|
from pyhive.tests.dbapi_test_case import DBAPITestCase
|
||||||
|
import mock
|
||||||
|
import contextlib
|
||||||
|
from pyhive.tests.dbapi_test_case import with_cursor
|
||||||
|
|
||||||
|
# Host passed to hive.connect() by the tests below.
_HOST = 'localhost'
|
||||||
|
|
||||||
|
|
||||||
|
class TestHive(DBAPITestCase):
    """Runs the shared DB-API test cases against Hive, plus Hive-specific checks."""
    __test__ = True

    def connect(self):
        """Return a new Hive connection to ``_HOST`` as user ``hadoop``."""
        return hive.connect(host=_HOST, username='hadoop')

    @with_cursor
    def test_description(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        desc = [('number_of_rows', 'INT_TYPE', None, None, None, None, True)]
        self.assertEqual(cursor.description, desc)
        # NOTE(review): asserted twice on purpose, presumably to check that reading
        # description a second time yields the same value -- confirm.
        self.assertEqual(cursor.description, desc)

    @with_cursor
    def test_complex(self, cursor):
        # map and array columns are described as, and come back as, strings
        cursor.execute('SELECT * FROM one_row_complex')
        self.assertEqual(cursor.description, [
            ('a', 'STRING_TYPE', None, None, None, None, True),
            ('b', 'STRING_TYPE', None, None, None, None, True),
        ])
        self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])

    def test_noops(self):
        """The DB-API specification requires that certain actions exist, even though they might not
        be applicable."""
        # Wohoo inflating coverage stats!
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                self.assertEqual(cursor.rowcount, -1)
                cursor.setinputsizes([])
                cursor.setoutputsize(1, 'blah')
                connection.commit()

    @mock.patch('TCLIService.TCLIService.Client.OpenSession')
    def test_open_failed(self, open_session):
        # A server answering with protocol V1 is expected to make connect() fail
        open_session.return_value.serverProtocolVersion = \
            ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
        self.assertRaises(exc.OperationalError, self.connect)

    def test_escape(self):
        # Hive thrift translates newlines into multiple rows. WTF.
        # Hence the same string as in the base class, minus the \n and \r characters.
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
        self.run_escape_case(bad_str)

    @with_cursor
    def test_newlines(self, cursor):
        """Verify that newlines are passed through in a way that doesn't fail parsing"""
        # Hive thrift translates newlines into multiple rows. WTF.
        # The cursor comes from with_cursor so that it and its connection get closed; they
        # used to be created inline and never released.
        cursor.execute(
            'SELECT %s FROM one_row',
            (' \r\n \r \n ',)
        )
        self.assertEqual(cursor.fetchall(), [[' '], [' '], [' '], [' ']])
|
|
@ -1,137 +1,50 @@
|
||||||
"""Presto integration tests.
|
"""Presto integration tests.
|
||||||
|
|
||||||
These rely on having a Presto+Hadoop cluster set up. They also require a table called one_row.
|
These rely on having a Presto+Hadoop cluster set up.
|
||||||
|
They also require the tables created by make_test_tables.sh.
|
||||||
"""
|
"""
|
||||||
|
from pyhive import exc
|
||||||
from pyhive import presto
|
from pyhive import presto
|
||||||
|
from pyhive.tests.dbapi_test_case import DBAPITestCase
|
||||||
|
from pyhive.tests.dbapi_test_case import with_cursor
|
||||||
import mock
|
import mock
|
||||||
import unittest
|
|
||||||
|
|
||||||
_HOST = 'localhost'
|
_HOST = 'localhost'
|
||||||
_ONE_ROW_TABLE_NAME = 'one_row'
|
|
||||||
_BIG_TABLE_NAME = 'user'
|
|
||||||
|
|
||||||
|
|
||||||
class TestPresto(DBAPITestCase):
    """Runs the shared DB-API test cases against Presto, plus Presto-specific checks."""
    __test__ = True

    def connect(self):
        """Return a new Presto connection to ``_HOST``."""
        return presto.connect(host=_HOST)

    @with_cursor
    def test_description(self, cursor):
        # with_cursor closes the cursor and its connection once the test is done
        cursor.execute('SELECT 1 AS foobar FROM one_row')
        self.assertEqual(cursor.description, [('foobar', 'bigint', None, None, None, None, True)])

    @with_cursor
    def test_complex(self, cursor):
        # map and array columns are described as, and come back as, strings
        cursor.execute('SELECT * FROM one_row_complex')
        self.assertEqual(cursor.description, [
            ('a', 'varchar', None, None, None, None, True),
            ('b', 'varchar', None, None, None, None, True),
        ])
        self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])

    def test_noops(self):
        """The DB-API specification requires that certain actions exist, even though they might not
        be applicable."""
        # Wohoo inflating coverage stats!
        connection = self.connect()
        try:
            cursor = connection.cursor()
            try:
                self.assertEqual(cursor.rowcount, -1)
                cursor.setinputsizes([])
                cursor.setoutputsize(1, 'blah')
            finally:
                cursor.close()
            connection.commit()
        finally:
            # Release the connection even when an assertion above fails
            connection.close()

    @mock.patch('requests.post')
    @with_cursor
    def test_non_200(self, cursor, post):
        # mock.patch is the outer decorator, so the mock arrives after the cursor that
        # with_cursor inserts.
        post.return_value.status_code = 404
        self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables'))
|
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
#!/bin/bash -eux
# Recreates the Hive table many_rows and loads the integers 0..9999 into column a of
# partition b='blah'.
# Hive must be on the path for this script to work.

hive -e 'DROP TABLE IF EXISTS many_rows'
hive -e "
CREATE TABLE many_rows (
    a INT
) PARTITIONED BY (
    b STRING
) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' STORED AS TEXTFILE"

# Use a unique temp file instead of a fixed, predictable name in /tmp
temp_file=$(mktemp "${TMPDIR:-/tmp}/pyhive_test_data_many_rows.XXXXXX")
# Remove the temp file on every exit path, including an abort caused by -e
trap 'rm -f "$temp_file"' EXIT
seq 0 9999 > "$temp_file"
hive -e "LOAD DATA LOCAL INPATH '$temp_file' INTO TABLE many_rows PARTITION (b='blah')"
|
|
@ -0,0 +1,4 @@
|
||||||
|
#!/bin/bash -eux
# Recreates the Hive table one_row so that it holds a single row.
hive -e 'DROP TABLE IF EXISTS one_row'
hive -e 'CREATE TABLE one_row (number_of_rows INT)'
# The table was just created and is empty, so COUNT(*) + 1 over it yields one row with the
# value 1; that row becomes the table's only content.
hive -e 'INSERT OVERWRITE TABLE one_row SELECT COUNT(*) + 1 FROM one_row'
|
|
@ -0,0 +1,4 @@
|
||||||
|
#!/bin/bash -eux
# Recreates the Hive table one_row_complex with a map column (a) and an array column (b).
hive -e 'DROP TABLE IF EXISTS one_row_complex'
hive -e 'CREATE TABLE one_row_complex (a map<INT, STRING>, b array<INT>)'
# The literals are selected FROM one_row, so one_row has to exist before this runs.
# NOTE(review): presumably one_row holds exactly one row, giving one row here -- confirm.
hive -e "INSERT OVERWRITE TABLE one_row_complex SELECT map(1, 'a', 2, 'b'), array(1, 2, 3) FROM one_row"
|
|
@ -0,0 +1,7 @@
|
||||||
|
#!/bin/bash -eux
# Hive must be on the path for this script to work.
# WARNING: drops and recreates tables called one_row, one_row_complex, and many_rows.

# Directory holding this script and its siblings. Quoted so that paths containing spaces
# are not split into several words.
script_dir="$(dirname "$0")"

# make_one_row.sh runs first: make_one_row_complex.sh selects FROM one_row.
"$script_dir/make_one_row.sh"
"$script_dir/make_one_row_complex.sh"
"$script_dir/make_many_rows.sh"
|
Загрузка…
Ссылка в новой задаче