зеркало из https://github.com/mozilla/PyHive.git
Add Hive backend, refactor common code, fix bugs
This commit is contained in:
Родитель
e4c75c9af8
Коммит
e49f70fe93
|
@ -28,5 +28,5 @@ print select([func.count('*')], from_obj=user).scalar()
|
|||
Requirements
|
||||
============
|
||||
- Presto DBAPI: Just a Presto install
|
||||
- Hive DBAPI: Thrift-generated `TCLIService` package
|
||||
- SQLAlchemy integration: `sqlalchemy` version 0.5.8
|
||||
- Hive DBAPI: HiveServer2 daemon, `TCLIService`, `thrift`, `sasl`, `thrift_sasl`
|
||||
- SQLAlchemy integration: `sqlalchemy` version 0.5
|
||||
|
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
Package private common utilities. Do not use directly.
|
||||
"""
|
||||
|
||||
from pyhive import exc
|
||||
import abc
|
||||
import collections
|
||||
import time
|
||||
|
||||
|
||||
class DBAPICursor(object):
|
||||
"""Base class for some common DBAPI logic"""
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
_STATE_NONE = 0
|
||||
_STATE_RUNNING = 1
|
||||
_STATE_FINISHED = 2
|
||||
|
||||
def __init__(self, poll_interval=1):
|
||||
self._poll_interval = poll_interval
|
||||
self._reset_state()
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset state about the previous query in preparation for running another query"""
|
||||
# State to return as part of DBAPI
|
||||
self._rownumber = 0
|
||||
|
||||
# Internal helper state
|
||||
self._state = self._STATE_NONE
|
||||
self._data = collections.deque()
|
||||
self._columns = None
|
||||
|
||||
def _fetch_while(self, fn):
|
||||
while fn():
|
||||
self._fetch_more()
|
||||
if fn():
|
||||
time.sleep(self._poll_interval)
|
||||
|
||||
@abc.abstractproperty
|
||||
def description(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""By default, do nothing"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _fetch_more(self):
|
||||
"""Get more results, append it to ``self._data``, and update ``self._state``."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
"""By default, return -1 to indicate that this is not supported."""
|
||||
return -1
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
"""Prepare a database operation (query or command) and then execute it against all parameter
|
||||
sequences or mappings found in the sequence seq_of_parameters.
|
||||
|
||||
Only the final result set is retained.
|
||||
|
||||
Return values are not defined.
|
||||
"""
|
||||
for parameters in seq_of_parameters[:-1]:
|
||||
self.execute(operation, parameters)
|
||||
while self._state != self._STATE_FINISHED:
|
||||
self._fetch_more()
|
||||
if seq_of_parameters:
|
||||
self.execute(operation, seq_of_parameters[-1])
|
||||
|
||||
def fetchone(self):
|
||||
"""Fetch the next row of a query result set, returning a single sequence, or None when no
|
||||
more data is available.
|
||||
|
||||
An Error (or subclass) exception is raised if the previous call to execute() did not
|
||||
produce any result set or no call was issued yet.
|
||||
"""
|
||||
if self._state == self._STATE_NONE:
|
||||
raise exc.ProgrammingError('No query yet')
|
||||
|
||||
# Sleep until we're done or we have some data to return
|
||||
self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)
|
||||
|
||||
if not self._data:
|
||||
return None
|
||||
else:
|
||||
self._rownumber += 1
|
||||
return self._data.popleft()
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
"""Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
|
||||
list of tuples). An empty sequence is returned when no more rows are available.
|
||||
|
||||
The number of rows to fetch per call is specified by the parameter. If it is not given, the
|
||||
cursor's arraysize determines the number of rows to be fetched. The method should try to
|
||||
fetch as many rows as indicated by the size parameter. If this is not possible due to the
|
||||
specified number of rows not being available, fewer rows may be returned.
|
||||
|
||||
An Error (or subclass) exception is raised if the previous call to .execute*() did not
|
||||
produce any result set or no call was issued yet.
|
||||
"""
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
result = []
|
||||
for _ in xrange(size):
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
break
|
||||
else:
|
||||
result.append(one)
|
||||
return result
|
||||
|
||||
def fetchall(self):
|
||||
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
|
||||
(e.g. a list of tuples).
|
||||
|
||||
An Error (or subclass) exception is raised if the previous call to .execute*() did not
|
||||
produce any result set or no call was issued yet.
|
||||
"""
|
||||
result = []
|
||||
while True:
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
break
|
||||
else:
|
||||
result.append(one)
|
||||
return result
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
"""This read/write attribute specifies the number of rows to fetch at a time with
|
||||
``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time.
|
||||
"""
|
||||
return self._arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._arraysize = value
|
||||
|
||||
def setinputsizes(self, sizes):
|
||||
"""Does nothing by default"""
|
||||
pass
|
||||
|
||||
def setoutputsize(self, size, column=None):
|
||||
"""Does nothing by default"""
|
||||
pass
|
||||
|
||||
#
|
||||
# Optional DB API Extensions
|
||||
#
|
||||
|
||||
@property
|
||||
def rownumber(self):
|
||||
"""This read-only attribute should provide the current 0-based index of the cursor in the
|
||||
result set.
|
||||
|
||||
The index can be seen as index of the cursor in a sequence (the result set). The next fetch
|
||||
operation will fetch the row indexed by .rownumber in that sequence.
|
||||
"""
|
||||
return self._rownumber
|
||||
|
||||
def next(self):
|
||||
"""Return the next row from the currently executing SQL statement using the same semantics
|
||||
as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted.
|
||||
"""
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
raise StopIteration
|
||||
else:
|
||||
return one
|
||||
|
||||
def __iter__(self):
|
||||
"""Return self to make cursors compatible to the iteration protocol."""
|
||||
return self
|
||||
|
||||
|
||||
class DBAPITypeObject(object):
|
||||
# Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints
|
||||
def __init__(self, *values):
|
||||
self.values = values
|
||||
|
||||
def __cmp__(self, other):
|
||||
if other in self.values:
|
||||
return 0
|
||||
if other < self.values:
|
||||
return 1
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
class ParamEscaper(object):
|
||||
def escape_args(self, parameters):
|
||||
if isinstance(parameters, dict):
|
||||
return {k: self.escape_item(v) for k, v in parameters.iteritems()}
|
||||
elif isinstance(parameters, (list, tuple)):
|
||||
return tuple(self.escape_item(x) for x in parameters)
|
||||
else:
|
||||
raise exc.ProgrammingError("Unsupported param format: {}".format(parameters))
|
||||
|
||||
def escape_number(self, item):
|
||||
return item
|
||||
|
||||
def escape_string(self, item):
|
||||
# This is good enough when backslashes are literal, newlines are just followed, and the way
|
||||
# to escape a single quote is to put two single quotes.
|
||||
# (i.e. only special character is single quote)
|
||||
return "'{}'".format(item.replace("'", "''"))
|
||||
|
||||
def escape_item(self, item):
|
||||
if isinstance(item, (int, long, float)):
|
||||
return self.escape_number(item)
|
||||
elif isinstance(item, basestring):
|
||||
return self.escape_string(item)
|
||||
else:
|
||||
raise exc.ProgrammingError("Unsupported object {}".format(item))
|
|
@ -0,0 +1,70 @@
|
|||
"""
|
||||
Package private common utilities. Do not use directly.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError',
|
||||
'ProgrammingError', 'DataError', 'NotSupportedError',
|
||||
]
|
||||
|
||||
|
||||
class Error(StandardError):
|
||||
"""Exception that is the base class of all other error exceptions.
|
||||
|
||||
You can use this to catch all errors with one single except statement.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Warning(StandardError):
|
||||
"""Exception raised for important warnings like data truncations while inserting, etc."""
|
||||
pass
|
||||
|
||||
|
||||
class InterfaceError(Error):
|
||||
"""Exception raised for errors that are related to the database interface rather than the
|
||||
database itself.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(Error):
|
||||
"""Exception raised for errors that are related to the database."""
|
||||
pass
|
||||
|
||||
|
||||
class InternalError(DatabaseError):
|
||||
"""Exception raised when the database encounters an internal error, e.g. the cursor is not valid
|
||||
anymore, the transaction is out of sync, etc."""
|
||||
pass
|
||||
|
||||
|
||||
class OperationalError(DatabaseError):
|
||||
"""Exception raised for errors that are related to the database's operation and not necessarily
|
||||
under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name
|
||||
is not found, a transaction could not be processed, a memory allocation error occurred during
|
||||
processing, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ProgrammingError(DatabaseError):
|
||||
"""Exception raised for programming errors, e.g. table not found or already exists, syntax error
|
||||
in the SQL statement, wrong number of parameters specified, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DataError(DatabaseError):
|
||||
"""Exception raised for errors that are due to problems with the processed data like division by
|
||||
zero, numeric value out of range, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NotSupportedError(DatabaseError):
|
||||
"""Exception raised in case a method or database API was used which is not supported by the
|
||||
database, e.g. requesting a .rollback() on a connection that does not support transaction or
|
||||
has transactions turned off.
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,245 @@
|
|||
"""DB API implementation backed by HiveServer2 (Thrift API)
|
||||
|
||||
See http://www.python.org/dev/peps/pep-0249/
|
||||
|
||||
Many docstrings in this file are based on the PEP, which is in the public domain.
|
||||
"""
|
||||
|
||||
from TCLIService import TCLIService
|
||||
from TCLIService import constants
|
||||
from TCLIService import ttypes
|
||||
from pyhive import common
|
||||
from pyhive.common import DBAPITypeObject
|
||||
# Make all exceptions visible in this module per DBAPI
|
||||
from pyhive.exc import *
|
||||
import getpass
|
||||
import logging
|
||||
import sasl
|
||||
import sys
|
||||
import thrift.protocol.TBinaryProtocol
|
||||
import thrift.transport.TSocket
|
||||
import thrift_sasl
|
||||
|
||||
# PEP 249 module globals
|
||||
apilevel = '2.0'
|
||||
threadsafety = 2 # Threads may share the module and connections.
|
||||
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HiveParamEscaper(common.ParamEscaper):
|
||||
def escape_string(self, item):
|
||||
# backslashes and single quotes need to be escaped
|
||||
# TODO verify against parser
|
||||
return "'{}'".format(item
|
||||
.replace('\\', '\\\\')
|
||||
.replace("'", "\\'")
|
||||
.replace('\r', '\\r')
|
||||
.replace('\n', '\\n')
|
||||
.replace('\t', '\\t')
|
||||
)
|
||||
|
||||
_escaper = HiveParamEscaper()
|
||||
|
||||
|
||||
def connect(**kwargs):
|
||||
"""Constructor for creating a connection to the database. See class Connection for arguments.
|
||||
|
||||
Returns a Connection object.
|
||||
"""
|
||||
return Connection(**kwargs)
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""Wraps a Thrift session"""
|
||||
|
||||
def __init__(self, host, port=10000, username=None, configuration=None):
|
||||
socket = thrift.transport.TSocket.TSocket(host, port)
|
||||
username = username or getpass.getuser()
|
||||
configuration = configuration or {}
|
||||
|
||||
def sasl_factory():
|
||||
sasl_client = sasl.Client()
|
||||
sasl_client.setAttr('username', username)
|
||||
# Password doesn't matter in PLAIN mode, just needs to be nonempty.
|
||||
sasl_client.setAttr('password', 'x')
|
||||
sasl_client.init()
|
||||
return sasl_client
|
||||
|
||||
# PLAIN corresponds to hive.server2.authentication=NONE in hive-site.xml
|
||||
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, 'PLAIN', socket)
|
||||
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
|
||||
self._client = TCLIService.Client(protocol)
|
||||
|
||||
try:
|
||||
self._transport.open()
|
||||
open_session_req = ttypes.TOpenSessionReq(
|
||||
client_protocol=ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1,
|
||||
configuration=configuration,
|
||||
)
|
||||
response = self._client.OpenSession(open_session_req)
|
||||
_check_status(response)
|
||||
assert(response.sessionHandle is not None), "Expected a session from OpenSession"
|
||||
self._sessionHandle = response.sessionHandle
|
||||
assert(response.serverProtocolVersion == ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1), \
|
||||
"Unable to handle protocol version {}".format(response.serverProtocolVersion)
|
||||
except:
|
||||
self._transport.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Close the underlying session and Thrift transport"""
|
||||
req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
|
||||
response = self._client.CloseSession(req)
|
||||
self._transport.close()
|
||||
_check_status(response)
|
||||
|
||||
def commit(self):
|
||||
"""Hive does not support transactions"""
|
||||
pass
|
||||
|
||||
def cursor(self):
|
||||
"""Return a new Cursor object using the connection."""
|
||||
return Cursor(self)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sessionHandle(self):
|
||||
return self._sessionHandle
|
||||
|
||||
|
||||
class Cursor(common.DBAPICursor):
|
||||
"""These objects represent a database cursor, which is used to manage the context of a fetch
|
||||
operation.
|
||||
|
||||
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
|
||||
visible by other cursors or connections.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
super(Cursor, self).__init__()
|
||||
self._connection = connection
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset state about the previous query in preparation for running another query"""
|
||||
super(Cursor, self)._reset_state()
|
||||
self._description = None
|
||||
self._operationHandle = None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
"""This read-only attribute is a sequence of 7-item sequences.
|
||||
|
||||
Each of these sequences contains information describing one result column:
|
||||
|
||||
- name
|
||||
- type_code
|
||||
- display_size (None in current implementation)
|
||||
- internal_size (None in current implementation)
|
||||
- precision (None in current implementation)
|
||||
- scale (None in current implementation)
|
||||
- null_ok (always True in current implementation)
|
||||
|
||||
The type_code can be interpreted by comparing it to the Type Objects specified in the
|
||||
section below.
|
||||
"""
|
||||
if self._operationHandle is None:
|
||||
return None
|
||||
if self._description is None:
|
||||
req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
|
||||
response = self._connection.client.GetResultSetMetadata(req)
|
||||
_check_status(response)
|
||||
columns = response.schema.columns
|
||||
self._description = []
|
||||
for col in columns:
|
||||
primary_type_entry = col.typeDesc.types[0]
|
||||
if primary_type_entry.primitiveEntry is None:
|
||||
# All fancy stuff maps to string
|
||||
type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
|
||||
else:
|
||||
type_id = primary_type_entry.primitiveEntry.type
|
||||
type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
|
||||
self._description.append((col.columnName, type_code, None, None, None, None, True))
|
||||
return self._description
|
||||
|
||||
def close(self):
|
||||
"""Close the operation handle"""
|
||||
if self._operationHandle is not None:
|
||||
request = ttypes.TCloseOperationReq(self._operationHandle)
|
||||
response = self._connection.client.CloseOperation(request)
|
||||
_check_status(response)
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
"""Prepare and execute a database operation (query or command).
|
||||
|
||||
Return values are not defined.
|
||||
"""
|
||||
if self._state == self._STATE_RUNNING:
|
||||
raise ProgrammingError("Already running a query")
|
||||
|
||||
# Prepare statement
|
||||
if parameters is None:
|
||||
sql = operation
|
||||
else:
|
||||
sql = operation % _escaper.escape_args(parameters)
|
||||
|
||||
self._reset_state()
|
||||
|
||||
self._state = self._STATE_RUNNING
|
||||
_logger.debug("Query: %s", sql)
|
||||
|
||||
req = ttypes.TExecuteStatementReq(self._connection.sessionHandle, sql)
|
||||
response = self._connection.client.ExecuteStatement(req)
|
||||
_check_status(response)
|
||||
self._operationHandle = response.operationHandle
|
||||
|
||||
def _fetch_more(self):
|
||||
"""Send another TFetchResultsReq and update state"""
|
||||
assert(self._state == self._STATE_RUNNING)
|
||||
req = ttypes.TFetchResultsReq(
|
||||
operationHandle=self._operationHandle,
|
||||
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
|
||||
maxRows=1000,
|
||||
)
|
||||
response = self._connection.client.FetchResults(req)
|
||||
_check_status(response)
|
||||
# response.hasMoreRows seems to always be False, so we instead check the number of rows
|
||||
#if not response.hasMoreRows:
|
||||
if not response.results.rows:
|
||||
self._state = self._STATE_FINISHED
|
||||
for row in response.results.rows:
|
||||
self._data.append([_unwrap_col_val(val) for val in row.colVals])
|
||||
|
||||
|
||||
#
|
||||
# Type Objects and Constructors
|
||||
#
|
||||
|
||||
|
||||
for type_id in constants.PRIMITIVE_TYPES:
|
||||
name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
|
||||
setattr(sys.modules[__name__], name, DBAPITypeObject([name]))
|
||||
|
||||
|
||||
#
|
||||
# Private utilities
|
||||
#
|
||||
|
||||
|
||||
def _unwrap_col_val(val):
|
||||
"""Return the raw value from a TColumnValue instance."""
|
||||
for _, _, attr, _, _ in filter(None, ttypes.TColumnValue.thrift_spec):
|
||||
val_obj = getattr(val, attr)
|
||||
if val_obj:
|
||||
return val_obj.value
|
||||
raise DataError("Got empty column value {}".format(val))
|
||||
|
||||
|
||||
def _check_status(response):
|
||||
"""Raise an OperationalError if the status is not success"""
|
||||
if response.status.statusCode != ttypes.TStatusCode.SUCCESS_STATUS:
|
||||
raise OperationalError(response)
|
264
pyhive/presto.py
264
pyhive/presto.py
|
@ -5,12 +5,13 @@ See http://www.python.org/dev/peps/pep-0249/
|
|||
Many docstrings in this file are based on the PEP, which is in the public domain.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import exceptions
|
||||
from pyhive import common
|
||||
from pyhive.common import DBAPITypeObject
|
||||
# Make all exceptions visible in this module per DBAPI
|
||||
from pyhive.exc import *
|
||||
import getpass
|
||||
import logging
|
||||
import requests
|
||||
import time
|
||||
import urlparse
|
||||
|
||||
|
||||
|
@ -20,6 +21,7 @@ threadsafety = 2 # Threads may share the module and connections.
|
|||
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_escaper = common.ParamEscaper()
|
||||
|
||||
|
||||
def connect(**kwargs):
|
||||
|
@ -53,16 +55,13 @@ class Connection(object):
|
|||
return Cursor(**self._params)
|
||||
|
||||
|
||||
class Cursor(object):
|
||||
class Cursor(common.DBAPICursor):
|
||||
"""These objects represent a database cursor, which is used to manage the context of a fetch
|
||||
operation.
|
||||
|
||||
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
|
||||
visible by other cursors or connections.
|
||||
"""
|
||||
_STATE_NONE = 0
|
||||
_STATE_RUNNING = 1
|
||||
_STATE_FINISHED = 2
|
||||
|
||||
def __init__(self, host, port='8080', user=None, catalog='hive', schema='default',
|
||||
poll_interval=1, source='pyhive'):
|
||||
|
@ -76,6 +75,7 @@ class Cursor(object):
|
|||
update, defaults to a second
|
||||
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
|
||||
"""
|
||||
super(Cursor, self).__init__(poll_interval)
|
||||
# Config
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
@ -90,21 +90,10 @@ class Cursor(object):
|
|||
|
||||
def _reset_state(self):
|
||||
"""Reset state about the previous query in preparation for running another query"""
|
||||
# State to return as part of DBAPI
|
||||
self._rownumber = 0
|
||||
|
||||
# Internal helper state
|
||||
self._state = self._STATE_NONE
|
||||
super(Cursor, self)._reset_state()
|
||||
self._nextUri = None
|
||||
self._data = collections.deque()
|
||||
self._columns = None
|
||||
|
||||
def _fetch_while(self, fn):
|
||||
while fn():
|
||||
self._fetch_more()
|
||||
if fn():
|
||||
time.sleep(self._poll_interval)
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
"""This read-only attribute is a sequence of 7-item sequences.
|
||||
|
@ -135,15 +124,6 @@ class Cursor(object):
|
|||
for col in self._columns
|
||||
]
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
"""Presto does not support rowcount"""
|
||||
return -1
|
||||
|
||||
def close(self):
|
||||
"""Presto does not have anything to close"""
|
||||
pass
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
"""Prepare and execute a database operation (query or command).
|
||||
|
||||
|
@ -163,7 +143,7 @@ class Cursor(object):
|
|||
if parameters is None:
|
||||
sql = operation
|
||||
else:
|
||||
sql = operation % _escape_args(parameters)
|
||||
sql = operation % _escaper.escape_args(parameters)
|
||||
|
||||
self._reset_state()
|
||||
|
||||
|
@ -176,7 +156,7 @@ class Cursor(object):
|
|||
self._process_response(response)
|
||||
|
||||
def _fetch_more(self):
|
||||
"""Fetch the next URI and udpate state"""
|
||||
"""Fetch the next URI and update state"""
|
||||
self._process_response(requests.get(self._nextUri))
|
||||
|
||||
def _process_response(self, response):
|
||||
|
@ -199,238 +179,14 @@ class Cursor(object):
|
|||
assert not self._nextUri, 'Should not have nextUri if failed'
|
||||
raise DatabaseError(response_json['error'])
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
"""Prepare a database operation (query or command) and then execute it against all parameter
|
||||
sequences or mappings found in the sequence seq_of_parameters.
|
||||
|
||||
Only the final result set is retained.
|
||||
|
||||
Return values are not defined.
|
||||
"""
|
||||
for parameters in seq_of_parameters[:-1]:
|
||||
self.execute(operation, parameters)
|
||||
while self._state != self._STATE_FINISHED:
|
||||
self._fetch_more()
|
||||
self.execute(operation, seq_of_parameters[-1])
|
||||
|
||||
def fetchone(self):
|
||||
"""Fetch the next row of a query result set, returning a single sequence, or None when no
|
||||
more data is available.
|
||||
|
||||
An Error (or subclass) exception is raised if the previous call to execute() did not
|
||||
produce any result set or no call was issued yet.
|
||||
"""
|
||||
if self._state == self._STATE_NONE:
|
||||
raise ProgrammingError('No query yet')
|
||||
# Note: all Presto statements produce a result set
|
||||
# The CREATE TABLE statement produces a single bigint called 'rows'
|
||||
|
||||
# Sleep until we're done or we have some data to return
|
||||
self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)
|
||||
|
||||
if not self._data:
|
||||
return None
|
||||
else:
|
||||
self._rownumber += 1
|
||||
return self._data.popleft()
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
"""Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
|
||||
list of tuples). An empty sequence is returned when no more rows are available.
|
||||
|
||||
The number of rows to fetch per call is specified by the parameter. If it is not given, the
|
||||
cursor's arraysize determines the number of rows to be fetched. The method should try to
|
||||
fetch as many rows as indicated by the size parameter. If this is not possible due to the
|
||||
specified number of rows not being available, fewer rows may be returned.
|
||||
|
||||
An Error (or subclass) exception is raised if the previous call to .execute*() did not
|
||||
produce any result set or no call was issued yet.
|
||||
"""
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
result = []
|
||||
for _ in xrange(size):
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
break
|
||||
else:
|
||||
result.append(one)
|
||||
return result
|
||||
|
||||
def fetchall(self):
|
||||
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
|
||||
(e.g. a list of tuples).
|
||||
|
||||
An Error (or subclass) exception is raised if the previous call to .execute*() did not
|
||||
produce any result set or no call was issued yet.
|
||||
"""
|
||||
result = []
|
||||
while True:
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
break
|
||||
else:
|
||||
result.append(one)
|
||||
return result
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
"""This read/write attribute specifies the number of rows to fetch at a time with
|
||||
``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time.
|
||||
|
||||
In our current implementation this parameter has no effect on actual fetching.
|
||||
"""
|
||||
return self._arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._arraysize = value
|
||||
|
||||
def setinputsizes(self, sizes):
|
||||
"""Does nothing for Presto"""
|
||||
pass
|
||||
|
||||
def setoutputsize(self, size, column=None):
|
||||
"""Does nothing for Presto"""
|
||||
pass
|
||||
|
||||
#
|
||||
# Optional DB API Extensions
|
||||
#
|
||||
|
||||
@property
|
||||
def rownumber(self):
|
||||
"""This read-only attribute should provide the current 0-based index of the cursor in the
|
||||
result set.
|
||||
|
||||
The index can be seen as index of the cursor in a sequence (the result set). The next fetch
|
||||
operation will fetch the row indexed by .rownumber in that sequence.
|
||||
"""
|
||||
return self._rownumber
|
||||
|
||||
def next(self):
|
||||
"""Return the next row from the currently executing SQL statement using the same semantics
|
||||
as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted.
|
||||
"""
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
raise StopIteration
|
||||
else:
|
||||
return one
|
||||
|
||||
def __iter__(self):
|
||||
"""Return self to make cursors compatible to the iteration protocol."""
|
||||
return self
|
||||
|
||||
#
|
||||
# Exceptions
|
||||
#
|
||||
|
||||
|
||||
class Error(exceptions.StandardError):
|
||||
"""Exception that is the base class of all other error exceptions.
|
||||
|
||||
You can use this to catch all errors with one single except statement.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Warning(exceptions.StandardError):
|
||||
"""Exception raised for important warnings like data truncations while inserting, etc."""
|
||||
pass
|
||||
|
||||
|
||||
class InterfaceError(Error):
|
||||
"""Exception raised for errors that are related to the database interface rather than the
|
||||
database itself.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(Error):
|
||||
"""Exception raised for errors that are related to the database."""
|
||||
pass
|
||||
|
||||
|
||||
class InternalError(DatabaseError):
|
||||
"""Exception raised when the database encounters an internal error, e.g. the cursor is not valid
|
||||
anymore, the transaction is out of sync, etc."""
|
||||
pass
|
||||
|
||||
|
||||
class OperationalError(DatabaseError):
|
||||
"""Exception raised for errors that are related to the database's operation and not necessarily
|
||||
under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name
|
||||
is not found, a transaction could not be processed, a memory allocation error occurred during
|
||||
processing, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ProgrammingError(DatabaseError):
|
||||
"""Exception raised for programming errors, e.g. table not found or already exists, syntax error
|
||||
in the SQL statement, wrong number of parameters specified, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DataError(DatabaseError):
|
||||
"""Exception raised for errors that are due to problems with the processed data like division by
|
||||
zero, numeric value out of range, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NotSupportedError(DatabaseError):
|
||||
"""Exception raised in case a method or database API was used which is not supported by the
|
||||
database, e.g. requesting a .rollback() on a connection that does not support transaction or
|
||||
has transactions turned off.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# Type Objects and Constructors
|
||||
#
|
||||
|
||||
|
||||
class DBAPITypeObject(object):
|
||||
# Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints
|
||||
def __init__(self, *values):
|
||||
self.values = values
|
||||
|
||||
def __cmp__(self, other):
|
||||
if other in self.values:
|
||||
return 0
|
||||
if other < self.values:
|
||||
return 1
|
||||
else:
|
||||
return -1
|
||||
|
||||
# See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java
|
||||
FIXED_INT_64 = DBAPITypeObject(['bigint'])
|
||||
VARIABLE_BINARY = DBAPITypeObject(['varchar'])
|
||||
DOUBLE = DBAPITypeObject(['double'])
|
||||
BOOLEAN = DBAPITypeObject(['boolean'])
|
||||
|
||||
|
||||
#
|
||||
# Private utilities
|
||||
#
|
||||
def _escape_args(parameters):
|
||||
if isinstance(parameters, dict):
|
||||
return {k: _escape_item(v) for k, v in parameters.iteritems()}
|
||||
elif isinstance(parameters, (list, tuple)):
|
||||
return tuple(_escape_item(x) for x in parameters)
|
||||
else:
|
||||
raise ProgrammingError("Unsupported param format: {}".format(parameters))
|
||||
|
||||
|
||||
def _escape_item(item):
|
||||
if isinstance(item, (int, long, float)):
|
||||
return item
|
||||
elif isinstance(item, basestring):
|
||||
# TODO is this good enough?
|
||||
return "'{}'".format(item.replace("'", "''"))
|
||||
else:
|
||||
raise ProgrammingError("Unsupported object {}".format(item))
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
Shared DBAPI test cases
|
||||
"""
|
||||
from pyhive import exc
|
||||
import abc
|
||||
import contextlib
|
||||
import functools
|
||||
import unittest
|
||||
|
||||
|
||||
def with_cursor(fn):
|
||||
"""Pass a cursor to the given function and handle cleanup.
|
||||
|
||||
The cursor is taken from ``self.connect()``.
|
||||
"""
|
||||
@functools.wraps(fn)
|
||||
def wrapped_fn(self, *args, **kwargs):
|
||||
with contextlib.closing(self.connect()) as connection:
|
||||
with contextlib.closing(connection.cursor()) as cursor:
|
||||
fn(self, cursor, *args, **kwargs)
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
class DBAPITestCase(unittest.TestCase):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
__test__ = False
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@with_cursor
|
||||
def test_fetchone(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row')
|
||||
self.assertEqual(cursor.rownumber, 0)
|
||||
self.assertEqual(cursor.fetchone(), [1])
|
||||
self.assertEqual(cursor.rownumber, 1)
|
||||
self.assertIsNone(cursor.fetchone())
|
||||
|
||||
@with_cursor
|
||||
def test_fetchall(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row')
|
||||
self.assertEqual(cursor.fetchall(), [[1]])
|
||||
cursor.execute('SELECT a FROM many_rows ORDER BY a')
|
||||
self.assertEqual(cursor.fetchall(), [[i] for i in xrange(10000)])
|
||||
|
||||
@with_cursor
|
||||
def test_iterator(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row')
|
||||
self.assertEqual(list(cursor), [[1]])
|
||||
self.assertRaises(StopIteration, cursor.next)
|
||||
|
||||
@with_cursor
|
||||
def test_description_initial(self, cursor):
|
||||
self.assertIsNone(cursor.description)
|
||||
|
||||
@with_cursor
|
||||
def test_description_failed(self, cursor):
|
||||
try:
|
||||
cursor.execute('blah_blah')
|
||||
except exc.DatabaseError:
|
||||
pass
|
||||
self.assertIsNone(cursor.description)
|
||||
|
||||
@with_cursor
|
||||
def test_bad_query(self, cursor):
|
||||
def run():
|
||||
cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
|
||||
cursor.fetchone()
|
||||
self.assertRaises(exc.DatabaseError, run)
|
||||
|
||||
@with_cursor
|
||||
def test_concurrent_execution(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row')
|
||||
self.assertRaises(exc.ProgrammingError,
|
||||
lambda: cursor.execute('SELECT * FROM one_row'))
|
||||
|
||||
@with_cursor
|
||||
def test_executemany(self, cursor):
|
||||
for length in 1, 2:
|
||||
cursor.executemany(
|
||||
'SELECT %(x)d FROM one_row',
|
||||
[{'x': i} for i in xrange(1, length + 1)]
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[length]])
|
||||
|
||||
@with_cursor
|
||||
def test_executemany_none(self, cursor):
|
||||
cursor.executemany('should_never_get_used', [])
|
||||
self.assertIsNone(cursor.description)
|
||||
self.assertRaises(exc.ProgrammingError, cursor.fetchone)
|
||||
|
||||
@with_cursor
|
||||
def test_fetchone_no_data(self, cursor):
|
||||
self.assertRaises(exc.ProgrammingError, cursor.fetchone)
|
||||
|
||||
@with_cursor
|
||||
def test_fetchmany(self, cursor):
|
||||
cursor.execute('SELECT * FROM many_rows LIMIT 15')
|
||||
self.assertEqual(cursor.fetchmany(0), [])
|
||||
self.assertEqual(len(cursor.fetchmany(10)), 10)
|
||||
self.assertEqual(len(cursor.fetchmany(10)), 5)
|
||||
|
||||
@with_cursor
|
||||
def test_arraysize(self, cursor):
|
||||
cursor.arraysize = 5
|
||||
cursor.execute('SELECT * FROM many_rows LIMIT 20')
|
||||
self.assertEqual(len(cursor.fetchmany()), 5)
|
||||
|
||||
@with_cursor
|
||||
def test_polling_loop(self, cursor):
|
||||
"""Try to trigger the polling logic in fetchone()"""
|
||||
cursor._poll_interval = 0
|
||||
cursor.execute('SELECT COUNT(*) FROM many_rows')
|
||||
self.assertEqual(cursor.fetchone(), [10000])
|
||||
|
||||
@with_cursor
|
||||
def test_no_params(self, cursor):
|
||||
cursor.execute("SELECT '%(x)s' FROM one_row")
|
||||
self.assertEqual(cursor.fetchall(), [['%(x)s']])
|
||||
|
||||
def test_escape(self):
|
||||
"""Verify that funny characters can be escaped as strings and SELECTed back"""
|
||||
bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
|
||||
self.run_escape_case(bad_str)
|
||||
|
||||
@with_cursor
|
||||
def run_escape_case(self, cursor, bad_str):
|
||||
cursor.execute(
|
||||
'SELECT %d, %s FROM one_row',
|
||||
(1, bad_str)
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[1, bad_str]])
|
||||
cursor.execute(
|
||||
'SELECT %(a)d, %(b)s FROM one_row',
|
||||
{'a': 1, 'b': bad_str}
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[1, bad_str]])
|
||||
|
||||
@with_cursor
|
||||
def test_invalid_params(self, cursor):
|
||||
self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', 'hi'))
|
||||
self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', [{}]))
|
||||
|
||||
def test_open_close(self):
|
||||
with contextlib.closing(self.connect()):
|
||||
pass
|
||||
with contextlib.closing(self.connect()) as connection:
|
||||
with contextlib.closing(connection.cursor()):
|
||||
pass
|
|
@ -0,0 +1,69 @@
|
|||
"""Hive integration tests.
|
||||
|
||||
These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
|
||||
They also require a tables created by make_test_tables.sh.
|
||||
"""
|
||||
from TCLIService import ttypes
|
||||
from pyhive import exc
|
||||
from pyhive import hive
|
||||
from pyhive.tests.dbapi_test_case import DBAPITestCase
|
||||
import mock
|
||||
import contextlib
|
||||
from pyhive.tests.dbapi_test_case import with_cursor
|
||||
|
||||
_HOST = 'localhost'
|
||||
|
||||
|
||||
class TestHive(DBAPITestCase):
|
||||
__test__ = True
|
||||
|
||||
def connect(self):
|
||||
return hive.connect(host=_HOST, username='hadoop')
|
||||
|
||||
@with_cursor
|
||||
def test_description(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row')
|
||||
desc = [('number_of_rows', 'INT_TYPE', None, None, None, None, True)]
|
||||
self.assertEqual(cursor.description, desc)
|
||||
self.assertEqual(cursor.description, desc)
|
||||
|
||||
@with_cursor
|
||||
def test_complex(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row_complex')
|
||||
self.assertEqual(cursor.description, [
|
||||
('a', 'STRING_TYPE', None, None, None, None, True),
|
||||
('b', 'STRING_TYPE', None, None, None, None, True),
|
||||
])
|
||||
self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])
|
||||
|
||||
def test_noops(self):
|
||||
"""The DB-API specification requires that certain actions exist, even though they might not
|
||||
be applicable."""
|
||||
# Wohoo inflating coverage stats!
|
||||
with contextlib.closing(self.connect()) as connection:
|
||||
with contextlib.closing(connection.cursor()) as cursor:
|
||||
self.assertEqual(cursor.rowcount, -1)
|
||||
cursor.setinputsizes([])
|
||||
cursor.setoutputsize(1, 'blah')
|
||||
connection.commit()
|
||||
|
||||
@mock.patch('TCLIService.TCLIService.Client.OpenSession')
|
||||
def test_open_failed(self, open_session):
|
||||
open_session.return_value.serverProtocolVersion = \
|
||||
ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
|
||||
self.assertRaises(exc.OperationalError, self.connect)
|
||||
|
||||
def test_escape(self):
|
||||
# Hive thrift translates newlines into multiple rows. WTF.
|
||||
bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
|
||||
self.run_escape_case(bad_str)
|
||||
|
||||
def test_newlines(self):
|
||||
"""Verify that newlines are passed through in a way that doesn't fail parsing"""
|
||||
# Hive thrift translates newlines into multiple rows. WTF.
|
||||
cursor = self.connect().cursor()
|
||||
cursor.execute(
|
||||
'SELECT %s FROM one_row',
|
||||
(' \r\n \r \n ',)
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[' '], [' '], [' '], [' ']])
|
|
@ -1,137 +1,50 @@
|
|||
"""Presto integration tests.
|
||||
|
||||
These rely on having a Presto+Hadoop cluster set up. They also require a table called one_row.
|
||||
These rely on having a Presto+Hadoop cluster set up.
|
||||
They also require a tables created by make_test_tables.sh.
|
||||
"""
|
||||
from pyhive import exc
|
||||
from pyhive import presto
|
||||
from pyhive.tests.dbapi_test_case import DBAPITestCase
|
||||
from pyhive.tests.dbapi_test_case import with_cursor
|
||||
import mock
|
||||
import unittest
|
||||
|
||||
_HOST = 'localhost'
|
||||
_ONE_ROW_TABLE_NAME = 'one_row'
|
||||
_BIG_TABLE_NAME = 'user'
|
||||
|
||||
|
||||
class Testpresto(unittest.TestCase):
|
||||
def test_fetchone(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select 1 from {}'.format(_ONE_ROW_TABLE_NAME))
|
||||
self.assertEqual(cursor.rownumber, 0)
|
||||
self.assertEqual(cursor.fetchone(), [1])
|
||||
self.assertEqual(cursor.rownumber, 1)
|
||||
self.assertIsNone(cursor.fetchone())
|
||||
class TestPresto(DBAPITestCase):
|
||||
__test__ = True
|
||||
|
||||
def test_fetchall(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))
|
||||
self.assertEqual(cursor.fetchall(), [[1]])
|
||||
|
||||
def test_iterator(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))
|
||||
self.assertEqual(list(cursor), [[1]])
|
||||
self.assertRaises(StopIteration, lambda: cursor.next())
|
||||
def connect(self):
|
||||
return presto.connect(host=_HOST)
|
||||
|
||||
def test_description(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select 1 as foobar from {}'.format(_ONE_ROW_TABLE_NAME))
|
||||
cursor = self.connect().cursor()
|
||||
cursor.execute('SELECT 1 AS foobar FROM one_row')
|
||||
self.assertEqual(cursor.description, [('foobar', 'bigint', None, None, None, None, True)])
|
||||
|
||||
def test_description_initial(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
self.assertIsNone(cursor.description)
|
||||
|
||||
def test_description_failed(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
try:
|
||||
cursor.execute('blah_blah')
|
||||
except presto.DatabaseError:
|
||||
pass
|
||||
self.assertIsNone(cursor.description)
|
||||
|
||||
def test_bad_query(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
|
||||
def run():
|
||||
cursor.execute('select does_not_exist from {}'.format(_ONE_ROW_TABLE_NAME))
|
||||
cursor.fetchone()
|
||||
self.assertRaises(presto.DatabaseError, run)
|
||||
|
||||
def test_concurrent_execution(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))
|
||||
self.assertRaises(presto.ProgrammingError,
|
||||
lambda: cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME)))
|
||||
|
||||
def test_executemany(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.executemany(
|
||||
'select %(x)d from {}'.format(_ONE_ROW_TABLE_NAME),
|
||||
[{'x': 1}, {'x': 2}, {'x': 3}]
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[3]])
|
||||
|
||||
def test_fetchone_no_data(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
self.assertRaises(presto.ProgrammingError, lambda: cursor.fetchone())
|
||||
|
||||
def test_fetchmany(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select * from {} limit 15'.format(_BIG_TABLE_NAME))
|
||||
self.assertEqual(cursor.fetchmany(0), [])
|
||||
self.assertEqual(len(cursor.fetchmany(10)), 10)
|
||||
self.assertEqual(len(cursor.fetchmany(10)), 5)
|
||||
|
||||
def test_arraysize(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.arraysize = 5
|
||||
cursor.execute('select * from {} limit 20'.format(_BIG_TABLE_NAME))
|
||||
self.assertEqual(len(cursor.fetchmany()), 5)
|
||||
|
||||
def test_slow_query(self):
|
||||
"""Trigger the polling logic in fetchone()"""
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute('select count(*) from {}'.format(_BIG_TABLE_NAME))
|
||||
self.assertIsNotNone(cursor.fetchone())
|
||||
|
||||
def test_no_params(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute("select '%(x)s' from {}".format(_ONE_ROW_TABLE_NAME))
|
||||
self.assertEqual(cursor.fetchall(), [['%(x)s']])
|
||||
@with_cursor
|
||||
def test_complex(self, cursor):
|
||||
cursor.execute('SELECT * FROM one_row_complex')
|
||||
self.assertEqual(cursor.description, [
|
||||
('a', 'varchar', None, None, None, None, True),
|
||||
('b', 'varchar', None, None, None, None, True),
|
||||
])
|
||||
self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])
|
||||
|
||||
def test_noops(self):
|
||||
"""The DB-API specification requires that certain actions exist, even though they might not
|
||||
be applicable."""
|
||||
# Wohoo inflating coverage stats!
|
||||
connection = presto.connect(host=_HOST)
|
||||
connection = self.connect()
|
||||
cursor = connection.cursor()
|
||||
self.assertEqual(cursor.rowcount, -1)
|
||||
cursor.setinputsizes([])
|
||||
cursor.setoutputsize(1, 'blah')
|
||||
cursor.close()
|
||||
connection.commit()
|
||||
connection.close()
|
||||
|
||||
def test_escape(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor.execute(
|
||||
"select %d, %s from {}".format(_ONE_ROW_TABLE_NAME),
|
||||
(1, "';")
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[1, "';"]])
|
||||
cursor.execute(
|
||||
"select %(a)d, %(b)s from {}".format(_ONE_ROW_TABLE_NAME),
|
||||
{'a': 1, 'b': "';"}
|
||||
)
|
||||
self.assertEqual(cursor.fetchall(), [[1, "';"]])
|
||||
|
||||
def test_invalid_params(self):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
self.assertRaises(presto.ProgrammingError, lambda: cursor.execute('', 'hi'))
|
||||
self.assertRaises(presto.ProgrammingError, lambda: cursor.execute('', [{}]))
|
||||
|
||||
@mock.patch('requests.post')
|
||||
def test_non_200(self, post):
|
||||
cursor = presto.connect(host=_HOST).cursor()
|
||||
cursor = self.connect().cursor()
|
||||
post.return_value.status_code = 404
|
||||
self.assertRaises(presto.OperationalError,
|
||||
lambda: cursor.execute('show tables'))
|
||||
self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables'))
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
|
||||
|
||||
hive -e 'DROP TABLE IF EXISTS many_rows'
|
||||
hive -e "
|
||||
CREATE TABLE many_rows (
|
||||
a INT
|
||||
) PARTITIONED BY (
|
||||
b STRING
|
||||
) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' STORED AS TEXTFILE"
|
||||
|
||||
temp_file=/tmp/pyhive_test_data_many_rows.tsv
|
||||
seq 0 9999 > $temp_file
|
||||
hive -e "LOAD DATA LOCAL INPATH '$temp_file' INTO TABLE many_rows PARTITION (b='blah')"
|
||||
rm -f $temp_file
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash -eux
|
||||
hive -e 'DROP TABLE IF EXISTS one_row'
|
||||
hive -e 'CREATE TABLE one_row (number_of_rows INT)'
|
||||
hive -e 'INSERT OVERWRITE TABLE one_row SELECT COUNT(*) + 1 FROM one_row'
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash -eux
|
||||
hive -e 'DROP TABLE IF EXISTS one_row_complex'
|
||||
hive -e 'CREATE TABLE one_row_complex (a map<INT, STRING>, b array<INT>)'
|
||||
hive -e "INSERT OVERWRITE TABLE one_row_complex SELECT map(1, 'a', 2, 'b'), array(1, 2, 3) FROM one_row"
|
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash -eux
|
||||
# Hive must be on the path for this script to work.
|
||||
# WARNING: drops and recreates tables called one_row, one_row_complex, and many_rows.
|
||||
|
||||
$(dirname $0)/make_one_row.sh
|
||||
$(dirname $0)/make_one_row_complex.sh
|
||||
$(dirname $0)/make_many_rows.sh
|
Загрузка…
Ссылка в новой задаче