Add Hive backend, refactor common code, fix bugs

This commit is contained in:
Jing Wang 2014-01-31 23:13:31 -08:00
Родитель e4c75c9af8
Коммит e49f70fe93
12 изменённых файлов: 813 добавлений и 365 удалений

Просмотреть файл

@ -28,5 +28,5 @@ print select([func.count('*')], from_obj=user).scalar()
Requirements Requirements
============ ============
- Presto DBAPI: Just a Presto install - Presto DBAPI: Just a Presto install
- Hive DBAPI: Thrift-generated `TCLIService` package - Hive DBAPI: HiveServer2 daemon, `TCLIService`, `thrift`, `sasl`, `thrift_sasl`
- SQLAlchemy integration: `sqlalchemy` version 0.5.8 - SQLAlchemy integration: `sqlalchemy` version 0.5

216
pyhive/common.py Normal file
Просмотреть файл

@ -0,0 +1,216 @@
"""
Package private common utilities. Do not use directly.
"""
from pyhive import exc
import abc
import collections
import time
class DBAPICursor(object):
    """Base class for some common DBAPI logic.

    Subclasses supply ``description`` and ``_fetch_more()`` (both abstract here) and are expected
    to provide ``execute()``, which ``executemany()`` calls. The fetch and iteration methods are
    implemented on top of those.
    """
    __metaclass__ = abc.ABCMeta

    # Query life cycle, stored in ``self._state``.
    _STATE_NONE = 0  # ``fetchone()`` refuses to run in this state ("No query yet")
    _STATE_RUNNING = 1
    _STATE_FINISHED = 2  # the fetch loops stop polling once this is reached

    # Documented default for ``arraysize``. Kept at class level so ``fetchmany()`` works even when
    # a subclass never assigns it (previously this raised AttributeError).
    _arraysize = 1

    def __init__(self, poll_interval=1):
        """
        :param poll_interval: seconds to sleep between fetch attempts that yielded nothing usable
        """
        self._poll_interval = poll_interval
        self._reset_state()

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query"""
        # State to return as part of DBAPI
        self._rownumber = 0

        # Internal helper state
        self._state = self._STATE_NONE
        self._data = collections.deque()
        self._columns = None

    def _fetch_while(self, fn):
        """Call ``_fetch_more()`` for as long as ``fn()`` is true.

        Sleeps ``poll_interval`` seconds between attempts, but only when another attempt is
        actually needed, so the final successful fetch is not followed by a pointless sleep.
        """
        while fn():
            self._fetch_more()
            if fn():
                time.sleep(self._poll_interval)

    @abc.abstractproperty
    def description(self):
        raise NotImplementedError

    def close(self):
        """By default, do nothing"""
        pass

    @abc.abstractmethod
    def _fetch_more(self):
        """Get more results, append it to ``self._data``, and update ``self._state``."""
        raise NotImplementedError

    @property
    def rowcount(self):
        """By default, return -1 to indicate that this is not supported."""
        return -1

    def executemany(self, operation, seq_of_parameters):
        """Prepare a database operation (query or command) and then execute it against all parameter
        sequences or mappings found in the sequence seq_of_parameters.

        Only the final result set is retained.

        Return values are not defined.
        """
        for parameters in seq_of_parameters[:-1]:
            self.execute(operation, parameters)
            # Drive the intermediate query to completion before starting the next one
            while self._state != self._STATE_FINISHED:
                self._fetch_more()
        # An empty sequence executes nothing at all
        if seq_of_parameters:
            self.execute(operation, seq_of_parameters[-1])

    def fetchone(self):
        """Fetch the next row of a query result set, returning a single sequence, or None when no
        more data is available.

        An Error (or subclass) exception is raised if the previous call to execute() did not
        produce any result set or no call was issued yet.
        """
        if self._state == self._STATE_NONE:
            raise exc.ProgrammingError('No query yet')

        # Sleep until we're done or we have some data to return
        self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)

        if not self._data:
            return None
        else:
            self._rownumber += 1
            return self._data.popleft()

    def fetchmany(self, size=None):
        """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
        list of tuples). An empty sequence is returned when no more rows are available.

        The number of rows to fetch per call is specified by the parameter. If it is not given, the
        cursor's arraysize determines the number of rows to be fetched. The method should try to
        fetch as many rows as indicated by the size parameter. If this is not possible due to the
        specified number of rows not being available, fewer rows may be returned.

        An Error (or subclass) exception is raised if the previous call to .execute*() did not
        produce any result set or no call was issued yet.
        """
        if size is None:
            size = self.arraysize
        result = []
        # Counting collected rows avoids depending on the ``xrange`` builtin
        while len(result) < size:
            one = self.fetchone()
            if one is None:
                break
            result.append(one)
        return result

    def fetchall(self):
        """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
        (e.g. a list of tuples).

        An Error (or subclass) exception is raised if the previous call to .execute*() did not
        produce any result set or no call was issued yet.
        """
        result = []
        while True:
            one = self.fetchone()
            if one is None:
                break
            else:
                result.append(one)
        return result

    @property
    def arraysize(self):
        """This read/write attribute specifies the number of rows to fetch at a time with
        ``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time.
        """
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._arraysize = value

    def setinputsizes(self, sizes):
        """Does nothing by default"""
        pass

    def setoutputsize(self, size, column=None):
        """Does nothing by default"""
        pass

    #
    # Optional DB API Extensions
    #

    @property
    def rownumber(self):
        """This read-only attribute should provide the current 0-based index of the cursor in the
        result set.

        The index can be seen as index of the cursor in a sequence (the result set). The next fetch
        operation will fetch the row indexed by .rownumber in that sequence.
        """
        return self._rownumber

    def next(self):
        """Return the next row from the currently executing SQL statement using the same semantics
        as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted.
        """
        one = self.fetchone()
        if one is None:
            raise StopIteration
        else:
            return one

    def __iter__(self):
        """Return self to make cursors compatible to the iteration protocol."""
        return self
class DBAPITypeObject(object):
    """Type object that compares equal to any of the type codes it was built from.

    Adapted from http://www.python.org/dev/peps/pep-0249/#implementation-hints

    The values may be given as separate arguments (``DBAPITypeObject('a', 'b')``) or as a
    single collection (``DBAPITypeObject(['a', 'b'])``); both spellings yield the same object.
    """

    def __init__(self, *values):
        # A list passed as the only argument used to be stored as ONE value, so the object
        # never matched any type code string. Flatten collections so both call styles work.
        flattened = []
        for value in values:
            if isinstance(value, (list, tuple, set, frozenset)):
                flattened.extend(value)
            else:
                flattened.append(value)
        self.values = tuple(flattened)

    def __cmp__(self, other):
        """Return 0 when ``other`` is one of the wrapped values, otherwise a non-zero ordering."""
        if other in self.values:
            return 0
        if other < self.values:
            return 1
        else:
            return -1
class ParamEscaper(object):
    """Turns Python query parameters into values safe to %-interpolate into a SQL string."""

    def escape_args(self, parameters):
        """Escape every parameter, keeping the container shape.

        A dict yields a dict with the same keys; a list or tuple yields a tuple. Anything else
        raises ``ProgrammingError``.
        """
        if isinstance(parameters, dict):
            return dict(
                (key, self.escape_item(value)) for key, value in parameters.iteritems()
            )
        if isinstance(parameters, (list, tuple)):
            return tuple(map(self.escape_item, parameters))
        raise exc.ProgrammingError("Unsupported param format: {}".format(parameters))

    def escape_number(self, item):
        """Numbers are passed through untouched."""
        return item

    def escape_string(self, item):
        """Wrap ``item`` in single quotes, doubling any single quote it contains."""
        # This is good enough when backslashes are literal, newlines are just followed, and the way
        # to escape a single quote is to put two single quotes.
        # (i.e. only special character is single quote)
        doubled = item.replace("'", "''")
        return "'{}'".format(doubled)

    def escape_item(self, item):
        """Dispatch a single parameter to the number or string escaper.

        Raises ``ProgrammingError`` for any other type.
        """
        if isinstance(item, (int, long, float)):
            return self.escape_number(item)
        if isinstance(item, basestring):
            return self.escape_string(item)
        raise exc.ProgrammingError("Unsupported object {}".format(item))

70
pyhive/exc.py Normal file
Просмотреть файл

@ -0,0 +1,70 @@
"""
Package private DB-API exception classes. Do not use directly; the backend modules re-export them.
"""
# Names re-exported by ``from pyhive.exc import *``
__all__ = [
    'Error',
    'Warning',
    'InterfaceError',
    'DatabaseError',
    'InternalError',
    'OperationalError',
    'ProgrammingError',
    'DataError',
    'NotSupportedError',
]
class Error(StandardError):
    """Base class of all other error exceptions.

    Catch this to handle every error with one single except statement.
    """
class Warning(StandardError):
    """Raised for important warnings, e.g. data truncations while inserting."""
class InterfaceError(Error):
    """Raised for errors related to the database interface rather than the database itself."""
class DatabaseError(Error):
    """Raised for errors that are related to the database."""
class InternalError(DatabaseError):
    """Raised when the database encounters an internal error, e.g. the cursor is not valid
    anymore or the transaction is out of sync.
    """
class OperationalError(DatabaseError):
    """Raised for errors related to the database's operation that are not necessarily under
    the control of the programmer.

    Examples: an unexpected disconnect occurs, the data source name is not found, a transaction
    could not be processed, a memory allocation error occurred during processing.
    """
class ProgrammingError(DatabaseError):
    """Raised for programming errors.

    Examples: table not found or already exists, syntax error in the SQL statement, wrong
    number of parameters specified.
    """
class DataError(DatabaseError):
    """Raised for problems with the processed data, e.g. division by zero or a numeric value
    out of range.
    """
class NotSupportedError(DatabaseError):
    """Raised when a method or database API is used which the database does not support.

    Example: requesting a .rollback() on a connection that does not support transaction or has
    transactions turned off.
    """

245
pyhive/hive.py Normal file
Просмотреть файл

@ -0,0 +1,245 @@
"""DB API implementation backed by HiveServer2 (Thrift API)
See http://www.python.org/dev/peps/pep-0249/
Many docstrings in this file are based on the PEP, which is in the public domain.
"""
from TCLIService import TCLIService
from TCLIService import constants
from TCLIService import ttypes
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DBAPI
from pyhive.exc import *
import getpass
import logging
import sasl
import sys
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
import thrift_sasl
# PEP 249 module globals
apilevel = '2.0'
threadsafety = 2 # Threads may share the module and connections.
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
# Module-level logger named after this module; Cursor.execute() logs each query at DEBUG level.
_logger = logging.getLogger(__name__)
class HiveParamEscaper(common.ParamEscaper):
    """Parameter escaper that writes strings as backslash-escaped, single-quoted literals."""

    def escape_string(self, item):
        """Return ``item`` quoted, with backslash, quote, CR, LF and TAB backslash-escaped."""
        # backslashes and single quotes need to be escaped
        # TODO verify against parser
        # The backslash pair must run first, otherwise the backslashes introduced by the later
        # replacements would themselves get doubled.
        substitutions = (
            ('\\', '\\\\'),
            ("'", "\\'"),
            ('\r', '\\r'),
            ('\n', '\\n'),
            ('\t', '\\t'),
        )
        escaped = item
        for raw, replacement in substitutions:
            escaped = escaped.replace(raw, replacement)
        return "'{}'".format(escaped)


# Shared escaper instance used when interpolating query parameters
_escaper = HiveParamEscaper()
def connect(**kwargs):
    """Create a connection to the database.

    All keyword arguments are forwarded to :class:`Connection`; see that class for their
    meaning.

    :returns: a new Connection object
    """
    connection = Connection(**kwargs)
    return connection
class Connection(object):
    """Wraps a Thrift session"""

    def __init__(self, host, port=10000, username=None, configuration=None):
        """Open the transport and a HiveServer2 session.

        :param host: host name passed to the Thrift socket
        :param port: port passed to the Thrift socket
        :param username: user name sent over SASL; defaults to ``getpass.getuser()``
        :param configuration: dict sent as the ``configuration`` of the open-session request;
            defaults to an empty dict
        :raises OperationalError: if the server answers OpenSession with a non-success status
        """
        socket = thrift.transport.TSocket.TSocket(host, port)
        username = username or getpass.getuser()
        configuration = configuration or {}

        def sasl_factory():
            sasl_client = sasl.Client()
            sasl_client.setAttr('username', username)
            # Password doesn't matter in PLAIN mode, just needs to be nonempty.
            sasl_client.setAttr('password', 'x')
            sasl_client.init()
            return sasl_client

        # PLAIN corresponds to hive.server2.authentication=NONE in hive-site.xml
        self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, 'PLAIN', socket)
        protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
        self._client = TCLIService.Client(protocol)

        try:
            self._transport.open()
            open_session_req = ttypes.TOpenSessionReq(
                client_protocol=ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1,
                configuration=configuration,
            )
            response = self._client.OpenSession(open_session_req)
            _check_status(response)
            assert response.sessionHandle is not None, "Expected a session from OpenSession"
            self._sessionHandle = response.sessionHandle
            assert (
                response.serverProtocolVersion
                == ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
            ), "Unable to handle protocol version {}".format(response.serverProtocolVersion)
        except:
            # Deliberately broad: whatever went wrong, release the transport and re-raise.
            self._transport.close()
            raise

    def close(self):
        """Close the underlying session and Thrift transport

        :raises OperationalError: if the server reports a non-success status for CloseSession
        """
        req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
        try:
            response = self._client.CloseSession(req)
        finally:
            # Release the transport even when the CloseSession call itself raises
            self._transport.close()
        _check_status(response)

    def commit(self):
        """Hive does not support transactions"""
        pass

    def cursor(self):
        """Return a new Cursor object using the connection."""
        return Cursor(self)

    @property
    def client(self):
        """The Thrift ``TCLIService`` client bound to this connection's transport."""
        return self._client

    @property
    def sessionHandle(self):
        """The session handle returned by OpenSession."""
        return self._sessionHandle
class Cursor(common.DBAPICursor):
    """These objects represent a database cursor, which is used to manage the context of a fetch
    operation.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
    visible by other cursors or connections.
    """

    def __init__(self, connection):
        """
        :param connection: the :class:`Connection` whose client and session this cursor uses
        """
        # Both attributes must exist before the base constructor runs, because it calls
        # ``_reset_state()``, which inspects the operation handle.
        self._connection = connection
        self._operationHandle = None
        super(Cursor, self).__init__()

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query"""
        super(Cursor, self)._reset_state()
        self._description = None
        # Release the previous statement on the server instead of just forgetting its handle
        self._close_operation()

    def _close_operation(self):
        """Close the current operation handle, if any, and forget it.

        The handle is cleared before the request is sent, so a failing or repeated close never
        re-sends CloseOperation for the same handle.
        """
        handle, self._operationHandle = self._operationHandle, None
        if handle is not None:
            request = ttypes.TCloseOperationReq(handle)
            response = self._connection.client.CloseOperation(request)
            _check_status(response)

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences contains information describing one result column:

        - name
        - type_code
        - display_size (None in current implementation)
        - internal_size (None in current implementation)
        - precision (None in current implementation)
        - scale (None in current implementation)
        - null_ok (always True in current implementation)

        The type_code can be interpreted by comparing it to the Type Objects specified in the
        section below.
        """
        if self._operationHandle is None:
            return None
        if self._description is None:
            # Fetched lazily on first access and cached until the next reset
            req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
            response = self._connection.client.GetResultSetMetadata(req)
            _check_status(response)
            columns = response.schema.columns
            self._description = []
            for col in columns:
                primary_type_entry = col.typeDesc.types[0]
                if primary_type_entry.primitiveEntry is None:
                    # All fancy stuff maps to string
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
                else:
                    type_id = primary_type_entry.primitiveEntry.type
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
                self._description.append((col.columnName, type_code, None, None, None, None, True))
        return self._description

    def close(self):
        """Close the operation handle. Safe to call more than once."""
        self._close_operation()

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        Return values are not defined.

        :raises ProgrammingError: if a query is still running on this cursor, or the parameters
            cannot be escaped
        :raises OperationalError: if the server reports a non-success status
        """
        if self._state == self._STATE_RUNNING:
            raise ProgrammingError("Already running a query")

        # Prepare statement
        if parameters is None:
            sql = operation
        else:
            sql = operation % _escaper.escape_args(parameters)

        self._reset_state()

        _logger.debug("Query: %s", sql)
        req = ttypes.TExecuteStatementReq(self._connection.sessionHandle, sql)
        response = self._connection.client.ExecuteStatement(req)
        _check_status(response)
        self._operationHandle = response.operationHandle
        # Only mark the cursor as running once the statement was accepted; a rejected statement
        # must not leave the cursor stuck on "Already running a query".
        self._state = self._STATE_RUNNING

    def _fetch_more(self):
        """Send another TFetchResultsReq and update state"""
        assert self._state == self._STATE_RUNNING
        req = ttypes.TFetchResultsReq(
            operationHandle=self._operationHandle,
            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
            maxRows=1000,
        )
        response = self._connection.client.FetchResults(req)
        _check_status(response)
        # response.hasMoreRows seems to always be False, so we instead check the number of rows
        #if not response.hasMoreRows:
        if not response.results.rows:
            self._state = self._STATE_FINISHED
        for row in response.results.rows:
            self._data.append([_unwrap_col_val(val) for val in row.colVals])
#
# Type Objects and Constructors
#
# Export one module-level type object per primitive type, named after the Thrift type name
# (the same strings that ``Cursor.description`` reports as type_code).
for type_id in constants.PRIMITIVE_TYPES:
    name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
    # Pass the name itself: DBAPITypeObject takes its values as *args, so wrapping the name in a
    # list would store the list as the only value and never match a type_code string.
    setattr(sys.modules[__name__], name, DBAPITypeObject(name))
#
# Private utilities
#
def _unwrap_col_val(val):
    """Return the raw value from a TColumnValue instance.

    Walks the non-empty entries of ``TColumnValue.thrift_spec`` and returns the ``value`` of the
    first field that is set on ``val``. Raises ``DataError`` when no field is set.
    """
    for field_spec in ttypes.TColumnValue.thrift_spec:
        if not field_spec:
            # thrift_spec may contain empty placeholder entries
            continue
        _, _, attr, _, _ = field_spec
        wrapper = getattr(val, attr)
        if wrapper:
            return wrapper.value
    raise DataError("Got empty column value {}".format(val))
def _check_status(response):
    """Raise an OperationalError carrying ``response`` unless its status code is success."""
    succeeded = response.status.statusCode == ttypes.TStatusCode.SUCCESS_STATUS
    if not succeeded:
        raise OperationalError(response)

Просмотреть файл

@ -5,12 +5,13 @@ See http://www.python.org/dev/peps/pep-0249/
Many docstrings in this file are based on the PEP, which is in the public domain. Many docstrings in this file are based on the PEP, which is in the public domain.
""" """
import collections from pyhive import common
import exceptions from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DBAPI
from pyhive.exc import *
import getpass import getpass
import logging import logging
import requests import requests
import time
import urlparse import urlparse
@ -20,6 +21,7 @@ threadsafety = 2 # Threads may share the module and connections.
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
_escaper = common.ParamEscaper()
def connect(**kwargs): def connect(**kwargs):
@ -53,16 +55,13 @@ class Connection(object):
return Cursor(**self._params) return Cursor(**self._params)
class Cursor(object): class Cursor(common.DBAPICursor):
"""These objects represent a database cursor, which is used to manage the context of a fetch """These objects represent a database cursor, which is used to manage the context of a fetch
operation. operation.
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
visible by other cursors or connections. visible by other cursors or connections.
""" """
_STATE_NONE = 0
_STATE_RUNNING = 1
_STATE_FINISHED = 2
def __init__(self, host, port='8080', user=None, catalog='hive', schema='default', def __init__(self, host, port='8080', user=None, catalog='hive', schema='default',
poll_interval=1, source='pyhive'): poll_interval=1, source='pyhive'):
@ -76,6 +75,7 @@ class Cursor(object):
update, defaults to a second update, defaults to a second
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page) :param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
""" """
super(Cursor, self).__init__(poll_interval)
# Config # Config
self._host = host self._host = host
self._port = port self._port = port
@ -90,21 +90,10 @@ class Cursor(object):
def _reset_state(self): def _reset_state(self):
"""Reset state about the previous query in preparation for running another query""" """Reset state about the previous query in preparation for running another query"""
# State to return as part of DBAPI super(Cursor, self)._reset_state()
self._rownumber = 0
# Internal helper state
self._state = self._STATE_NONE
self._nextUri = None self._nextUri = None
self._data = collections.deque()
self._columns = None self._columns = None
def _fetch_while(self, fn):
while fn():
self._fetch_more()
if fn():
time.sleep(self._poll_interval)
@property @property
def description(self): def description(self):
"""This read-only attribute is a sequence of 7-item sequences. """This read-only attribute is a sequence of 7-item sequences.
@ -135,15 +124,6 @@ class Cursor(object):
for col in self._columns for col in self._columns
] ]
@property
def rowcount(self):
"""Presto does not support rowcount"""
return -1
def close(self):
"""Presto does not have anything to close"""
pass
def execute(self, operation, parameters=None): def execute(self, operation, parameters=None):
"""Prepare and execute a database operation (query or command). """Prepare and execute a database operation (query or command).
@ -163,7 +143,7 @@ class Cursor(object):
if parameters is None: if parameters is None:
sql = operation sql = operation
else: else:
sql = operation % _escape_args(parameters) sql = operation % _escaper.escape_args(parameters)
self._reset_state() self._reset_state()
@ -176,7 +156,7 @@ class Cursor(object):
self._process_response(response) self._process_response(response)
def _fetch_more(self): def _fetch_more(self):
"""Fetch the next URI and udpate state""" """Fetch the next URI and update state"""
self._process_response(requests.get(self._nextUri)) self._process_response(requests.get(self._nextUri))
def _process_response(self, response): def _process_response(self, response):
@ -199,238 +179,14 @@ class Cursor(object):
assert not self._nextUri, 'Should not have nextUri if failed' assert not self._nextUri, 'Should not have nextUri if failed'
raise DatabaseError(response_json['error']) raise DatabaseError(response_json['error'])
def executemany(self, operation, seq_of_parameters):
"""Prepare a database operation (query or command) and then execute it against all parameter
sequences or mappings found in the sequence seq_of_parameters.
Only the final result set is retained.
Return values are not defined.
"""
for parameters in seq_of_parameters[:-1]:
self.execute(operation, parameters)
while self._state != self._STATE_FINISHED:
self._fetch_more()
self.execute(operation, seq_of_parameters[-1])
def fetchone(self):
"""Fetch the next row of a query result set, returning a single sequence, or None when no
more data is available.
An Error (or subclass) exception is raised if the previous call to execute() did not
produce any result set or no call was issued yet.
"""
if self._state == self._STATE_NONE:
raise ProgrammingError('No query yet')
# Note: all Presto statements produce a result set
# The CREATE TABLE statement produces a single bigint called 'rows'
# Sleep until we're done or we have some data to return
self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)
if not self._data:
return None
else:
self._rownumber += 1
return self._data.popleft()
def fetchmany(self, size=None):
"""Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
list of tuples). An empty sequence is returned when no more rows are available.
The number of rows to fetch per call is specified by the parameter. If it is not given, the
cursor's arraysize determines the number of rows to be fetched. The method should try to
fetch as many rows as indicated by the size parameter. If this is not possible due to the
specified number of rows not being available, fewer rows may be returned.
An Error (or subclass) exception is raised if the previous call to .execute*() did not
produce any result set or no call was issued yet.
"""
if size is None:
size = self.arraysize
result = []
for _ in xrange(size):
one = self.fetchone()
if one is None:
break
else:
result.append(one)
return result
def fetchall(self):
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
(e.g. a list of tuples).
An Error (or subclass) exception is raised if the previous call to .execute*() did not
produce any result set or no call was issued yet.
"""
result = []
while True:
one = self.fetchone()
if one is None:
break
else:
result.append(one)
return result
@property
def arraysize(self):
"""This read/write attribute specifies the number of rows to fetch at a time with
``.fetchmany()``. It defaults to 1 meaning to fetch a single row at a time.
In our current implementation this parameter has no effect on actual fetching.
"""
return self._arraysize
@arraysize.setter
def arraysize(self, value):
self._arraysize = value
def setinputsizes(self, sizes):
"""Does nothing for Presto"""
pass
def setoutputsize(self, size, column=None):
"""Does nothing for Presto"""
pass
#
# Optional DB API Extensions
#
@property
def rownumber(self):
"""This read-only attribute should provide the current 0-based index of the cursor in the
result set.
The index can be seen as index of the cursor in a sequence (the result set). The next fetch
operation will fetch the row indexed by .rownumber in that sequence.
"""
return self._rownumber
def next(self):
"""Return the next row from the currently executing SQL statement using the same semantics
as ``.fetchone()``. A StopIteration exception is raised when the result set is exhausted.
"""
one = self.fetchone()
if one is None:
raise StopIteration
else:
return one
def __iter__(self):
"""Return self to make cursors compatible to the iteration protocol."""
return self
#
# Exceptions
#
class Error(exceptions.StandardError):
"""Exception that is the base class of all other error exceptions.
You can use this to catch all errors with one single except statement.
"""
pass
class Warning(exceptions.StandardError):
"""Exception raised for important warnings like data truncations while inserting, etc."""
pass
class InterfaceError(Error):
"""Exception raised for errors that are related to the database interface rather than the
database itself.
"""
pass
class DatabaseError(Error):
"""Exception raised for errors that are related to the database."""
pass
class InternalError(DatabaseError):
"""Exception raised when the database encounters an internal error, e.g. the cursor is not valid
anymore, the transaction is out of sync, etc."""
pass
class OperationalError(DatabaseError):
"""Exception raised for errors that are related to the database's operation and not necessarily
under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name
is not found, a transaction could not be processed, a memory allocation error occurred during
processing, etc.
"""
pass
class ProgrammingError(DatabaseError):
"""Exception raised for programming errors, e.g. table not found or already exists, syntax error
in the SQL statement, wrong number of parameters specified, etc.
"""
pass
class DataError(DatabaseError):
"""Exception raised for errors that are due to problems with the processed data like division by
zero, numeric value out of range, etc.
"""
pass
class NotSupportedError(DatabaseError):
"""Exception raised in case a method or database API was used which is not supported by the
database, e.g. requesting a .rollback() on a connection that does not support transaction or
has transactions turned off.
"""
pass
# #
# Type Objects and Constructors # Type Objects and Constructors
# #
class DBAPITypeObject(object):
# Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints
def __init__(self, *values):
self.values = values
def __cmp__(self, other):
if other in self.values:
return 0
if other < self.values:
return 1
else:
return -1
# See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java # See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java
FIXED_INT_64 = DBAPITypeObject(['bigint']) FIXED_INT_64 = DBAPITypeObject(['bigint'])
VARIABLE_BINARY = DBAPITypeObject(['varchar']) VARIABLE_BINARY = DBAPITypeObject(['varchar'])
DOUBLE = DBAPITypeObject(['double']) DOUBLE = DBAPITypeObject(['double'])
BOOLEAN = DBAPITypeObject(['boolean']) BOOLEAN = DBAPITypeObject(['boolean'])
#
# Private utilities
#
def _escape_args(parameters):
if isinstance(parameters, dict):
return {k: _escape_item(v) for k, v in parameters.iteritems()}
elif isinstance(parameters, (list, tuple)):
return tuple(_escape_item(x) for x in parameters)
else:
raise ProgrammingError("Unsupported param format: {}".format(parameters))
def _escape_item(item):
if isinstance(item, (int, long, float)):
return item
elif isinstance(item, basestring):
# TODO is this good enough?
return "'{}'".format(item.replace("'", "''"))
else:
raise ProgrammingError("Unsupported object {}".format(item))

Просмотреть файл

@ -0,0 +1,150 @@
"""
Shared DBAPI test cases
"""
from pyhive import exc
import abc
import contextlib
import functools
import unittest
def with_cursor(fn):
    """Decorator that supplies a fresh cursor as the first argument after ``self``.

    The connection comes from ``self.connect()``; both the cursor and the connection are closed
    when the wrapped function finishes. The wrapped function's return value is discarded.
    """
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        opened_connection = contextlib.closing(self.connect())
        with opened_connection as connection:
            opened_cursor = contextlib.closing(connection.cursor())
            with opened_cursor as cursor:
                fn(self, cursor, *args, **kwargs)
    return wrapped_fn
class DBAPITestCase(unittest.TestCase):
    # Backend-independent DB-API test cases. Concrete subclasses implement ``connect()``; every
    # test then runs against whatever connection that returns. The queries use tables named
    # ``one_row`` and ``many_rows``, which have to exist on the server under test.
    __metaclass__ = abc.ABCMeta
    # NOTE(review): presumably stops nose-style collectors from running this abstract base
    # directly -- confirm against the test runner in use.
    __test__ = False

    @abc.abstractmethod
    def connect(self):
        # Return a new DB-API connection to the backend under test
        raise NotImplementedError

    @with_cursor
    def test_fetchone(self, cursor):
        # rownumber advances by one per fetched row; fetchone() returns None once exhausted
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.rownumber, 0)
        self.assertEqual(cursor.fetchone(), [1])
        self.assertEqual(cursor.rownumber, 1)
        self.assertIsNone(cursor.fetchone())

    @with_cursor
    def test_fetchall(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [[1]])
        # Second query on the same cursor; expects many_rows.a to hold 0..9999
        cursor.execute('SELECT a FROM many_rows ORDER BY a')
        self.assertEqual(cursor.fetchall(), [[i] for i in xrange(10000)])

    @with_cursor
    def test_iterator(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(list(cursor), [[1]])
        # An exhausted cursor keeps raising StopIteration
        self.assertRaises(StopIteration, cursor.next)

    @with_cursor
    def test_description_initial(self, cursor):
        # No query has been run yet
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_description_failed(self, cursor):
        # A statement that fails must leave no description behind
        try:
            cursor.execute('blah_blah')
        except exc.DatabaseError:
            pass
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_bad_query(self, cursor):
        # The error may surface from execute() or from the first fetch, so both are wrapped
        def run():
            cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
            cursor.fetchone()
        self.assertRaises(exc.DatabaseError, run)

    @with_cursor
    def test_concurrent_execution(self, cursor):
        # A second execute() before the first result set is consumed is rejected
        cursor.execute('SELECT * FROM one_row')
        self.assertRaises(exc.ProgrammingError,
                          lambda: cursor.execute('SELECT * FROM one_row'))

    @with_cursor
    def test_executemany(self, cursor):
        # Only the result set of the last parameter set is retained
        for length in 1, 2:
            cursor.executemany(
                'SELECT %(x)d FROM one_row',
                [{'x': i} for i in xrange(1, length + 1)]
            )
            self.assertEqual(cursor.fetchall(), [[length]])

    @with_cursor
    def test_executemany_none(self, cursor):
        # With no parameter sets nothing is executed, so the cursor behaves as if unused
        cursor.executemany('should_never_get_used', [])
        self.assertIsNone(cursor.description)
        self.assertRaises(exc.ProgrammingError, cursor.fetchone)

    @with_cursor
    def test_fetchone_no_data(self, cursor):
        # Fetching before any execute() is a programming error
        self.assertRaises(exc.ProgrammingError, cursor.fetchone)

    @with_cursor
    def test_fetchmany(self, cursor):
        # 15 rows available: a size of 0 returns nothing, then 10 rows, then the remaining 5
        cursor.execute('SELECT * FROM many_rows LIMIT 15')
        self.assertEqual(cursor.fetchmany(0), [])
        self.assertEqual(len(cursor.fetchmany(10)), 10)
        self.assertEqual(len(cursor.fetchmany(10)), 5)

    @with_cursor
    def test_arraysize(self, cursor):
        # fetchmany() without a size falls back to cursor.arraysize
        cursor.arraysize = 5
        cursor.execute('SELECT * FROM many_rows LIMIT 20')
        self.assertEqual(len(cursor.fetchmany()), 5)

    @with_cursor
    def test_polling_loop(self, cursor):
        """Try to trigger the polling logic in fetchone()"""
        # A zero interval keeps the test from sleeping between polls
        cursor._poll_interval = 0
        cursor.execute('SELECT COUNT(*) FROM many_rows')
        self.assertEqual(cursor.fetchone(), [10000])

    @with_cursor
    def test_no_params(self, cursor):
        # Without parameters the statement is sent as-is, so format codes stay literal
        cursor.execute("SELECT '%(x)s' FROM one_row")
        self.assertEqual(cursor.fetchall(), [['%(x)s']])

    def test_escape(self):
        """Verify that funny characters can be escaped as strings and SELECTed back"""
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
        self.run_escape_case(bad_str)

    @with_cursor
    def run_escape_case(self, cursor, bad_str):
        # Round-trip ``bad_str`` through both positional and named parameter styles
        cursor.execute(
            'SELECT %d, %s FROM one_row',
            (1, bad_str)
        )
        self.assertEqual(cursor.fetchall(), [[1, bad_str]])
        cursor.execute(
            'SELECT %(a)d, %(b)s FROM one_row',
            {'a': 1, 'b': bad_str}
        )
        self.assertEqual(cursor.fetchall(), [[1, bad_str]])

    @with_cursor
    def test_invalid_params(self, cursor):
        # Neither a bare string nor a dict nested inside a list is an accepted parameter format
        self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', 'hi'))
        self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', [{}]))

    def test_open_close(self):
        # Opening and closing must work both without and with an (unused) cursor
        with contextlib.closing(self.connect()):
            pass
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()):
                pass

69
pyhive/tests/test_hive.py Normal file
Просмотреть файл

@ -0,0 +1,69 @@
"""Hive integration tests.
These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
They also require the tables created by make_test_tables.sh.
"""
from TCLIService import ttypes
from pyhive import exc
from pyhive import hive
from pyhive.tests.dbapi_test_case import DBAPITestCase
import mock
import contextlib
from pyhive.tests.dbapi_test_case import with_cursor
_HOST = 'localhost'
class TestHive(DBAPITestCase):
    """Runs the shared DB-API test cases, plus Hive-specific ones, against HiveServer2."""
    __test__ = True

    def connect(self):
        """Return a new Hive connection to ``_HOST`` as user ``hadoop``."""
        return hive.connect(host=_HOST, username='hadoop')

    @with_cursor
    def test_description(self, cursor):
        cursor.execute('SELECT * FROM one_row')

        desc = [('number_of_rows', 'INT_TYPE', None, None, None, None, True)]
        self.assertEqual(cursor.description, desc)
        # Read it twice: the second access must return the same description
        self.assertEqual(cursor.description, desc)

    @with_cursor
    def test_complex(self, cursor):
        # Non-primitive column types are reported as STRING_TYPE
        cursor.execute('SELECT * FROM one_row_complex')
        self.assertEqual(cursor.description, [
            ('a', 'STRING_TYPE', None, None, None, None, True),
            ('b', 'STRING_TYPE', None, None, None, None, True),
        ])
        self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])

    def test_noops(self):
        """The DB-API specification requires that certain actions exist, even though they might not
        be applicable."""
        # Wohoo inflating coverage stats!
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                self.assertEqual(cursor.rowcount, -1)
                cursor.setinputsizes([])
                cursor.setoutputsize(1, 'blah')
                connection.commit()

    @mock.patch('TCLIService.TCLIService.Client.OpenSession')
    def test_open_failed(self, open_session):
        # The mocked response carries a mock status rather than a success status, so opening
        # the connection is expected to fail with OperationalError.
        open_session.return_value.serverProtocolVersion = \
            ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
        self.assertRaises(exc.OperationalError, self.connect)

    def test_escape(self):
        # Hive thrift translates newlines into multiple rows. WTF.
        # Hence this override drops \n and \r from the shared test string.
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
        self.run_escape_case(bad_str)

    @with_cursor
    def test_newlines(self, cursor):
        """Verify that newlines are passed through in a way that doesn't fail parsing"""
        # Hive thrift translates newlines into multiple rows. WTF.
        # The cursor comes from with_cursor so it and its connection are closed afterwards.
        cursor.execute(
            'SELECT %s FROM one_row',
            (' \r\n \r \n ',)
        )
        self.assertEqual(cursor.fetchall(), [[' '], [' '], [' '], [' ']])

Просмотреть файл

@ -1,137 +1,50 @@
"""Presto integration tests. """Presto integration tests.
These rely on having a Presto+Hadoop cluster set up. They also require a table called one_row. These rely on having a Presto+Hadoop cluster set up.
They also require the tables created by make_test_tables.sh.
""" """
from pyhive import exc
from pyhive import presto from pyhive import presto
from pyhive.tests.dbapi_test_case import DBAPITestCase
from pyhive.tests.dbapi_test_case import with_cursor
import mock import mock
import unittest
_HOST = 'localhost' _HOST = 'localhost'
_ONE_ROW_TABLE_NAME = 'one_row'
_BIG_TABLE_NAME = 'user'
class Testpresto(unittest.TestCase): class TestPresto(DBAPITestCase):
def test_fetchone(self): __test__ = True
cursor = presto.connect(host=_HOST).cursor()
cursor.execute('select 1 from {}'.format(_ONE_ROW_TABLE_NAME))
self.assertEqual(cursor.rownumber, 0)
self.assertEqual(cursor.fetchone(), [1])
self.assertEqual(cursor.rownumber, 1)
self.assertIsNone(cursor.fetchone())
def test_fetchall(self): def connect(self):
cursor = presto.connect(host=_HOST).cursor() return presto.connect(host=_HOST)
cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))
self.assertEqual(cursor.fetchall(), [[1]])
def test_iterator(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))
self.assertEqual(list(cursor), [[1]])
self.assertRaises(StopIteration, lambda: cursor.next())
def test_description(self): def test_description(self):
cursor = presto.connect(host=_HOST).cursor() cursor = self.connect().cursor()
cursor.execute('select 1 as foobar from {}'.format(_ONE_ROW_TABLE_NAME)) cursor.execute('SELECT 1 AS foobar FROM one_row')
self.assertEqual(cursor.description, [('foobar', 'bigint', None, None, None, None, True)]) self.assertEqual(cursor.description, [('foobar', 'bigint', None, None, None, None, True)])
def test_description_initial(self): @with_cursor
cursor = presto.connect(host=_HOST).cursor() def test_complex(self, cursor):
self.assertIsNone(cursor.description) cursor.execute('SELECT * FROM one_row_complex')
self.assertEqual(cursor.description, [
def test_description_failed(self): ('a', 'varchar', None, None, None, None, True),
cursor = presto.connect(host=_HOST).cursor() ('b', 'varchar', None, None, None, None, True),
try: ])
cursor.execute('blah_blah') self.assertEqual(cursor.fetchall(), [['{1:"a",2:"b"}', '[1,2,3]']])
except presto.DatabaseError:
pass
self.assertIsNone(cursor.description)
def test_bad_query(self):
cursor = presto.connect(host=_HOST).cursor()
def run():
cursor.execute('select does_not_exist from {}'.format(_ONE_ROW_TABLE_NAME))
cursor.fetchone()
self.assertRaises(presto.DatabaseError, run)
def test_concurrent_execution(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME))
self.assertRaises(presto.ProgrammingError,
lambda: cursor.execute('select count(*) from {}'.format(_ONE_ROW_TABLE_NAME)))
def test_executemany(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.executemany(
'select %(x)d from {}'.format(_ONE_ROW_TABLE_NAME),
[{'x': 1}, {'x': 2}, {'x': 3}]
)
self.assertEqual(cursor.fetchall(), [[3]])
def test_fetchone_no_data(self):
cursor = presto.connect(host=_HOST).cursor()
self.assertRaises(presto.ProgrammingError, lambda: cursor.fetchone())
def test_fetchmany(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.execute('select * from {} limit 15'.format(_BIG_TABLE_NAME))
self.assertEqual(cursor.fetchmany(0), [])
self.assertEqual(len(cursor.fetchmany(10)), 10)
self.assertEqual(len(cursor.fetchmany(10)), 5)
def test_arraysize(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.arraysize = 5
cursor.execute('select * from {} limit 20'.format(_BIG_TABLE_NAME))
self.assertEqual(len(cursor.fetchmany()), 5)
def test_slow_query(self):
"""Trigger the polling logic in fetchone()"""
cursor = presto.connect(host=_HOST).cursor()
cursor.execute('select count(*) from {}'.format(_BIG_TABLE_NAME))
self.assertIsNotNone(cursor.fetchone())
def test_no_params(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.execute("select '%(x)s' from {}".format(_ONE_ROW_TABLE_NAME))
self.assertEqual(cursor.fetchall(), [['%(x)s']])
def test_noops(self): def test_noops(self):
"""The DB-API specification requires that certain actions exist, even though they might not """The DB-API specification requires that certain actions exist, even though they might not
be applicable.""" be applicable."""
# Wohoo inflating coverage stats! # Wohoo inflating coverage stats!
connection = presto.connect(host=_HOST) connection = self.connect()
cursor = connection.cursor() cursor = connection.cursor()
self.assertEqual(cursor.rowcount, -1) self.assertEqual(cursor.rowcount, -1)
cursor.setinputsizes([]) cursor.setinputsizes([])
cursor.setoutputsize(1, 'blah') cursor.setoutputsize(1, 'blah')
cursor.close()
connection.commit() connection.commit()
connection.close()
def test_escape(self):
cursor = presto.connect(host=_HOST).cursor()
cursor.execute(
"select %d, %s from {}".format(_ONE_ROW_TABLE_NAME),
(1, "';")
)
self.assertEqual(cursor.fetchall(), [[1, "';"]])
cursor.execute(
"select %(a)d, %(b)s from {}".format(_ONE_ROW_TABLE_NAME),
{'a': 1, 'b': "';"}
)
self.assertEqual(cursor.fetchall(), [[1, "';"]])
def test_invalid_params(self):
cursor = presto.connect(host=_HOST).cursor()
self.assertRaises(presto.ProgrammingError, lambda: cursor.execute('', 'hi'))
self.assertRaises(presto.ProgrammingError, lambda: cursor.execute('', [{}]))
@mock.patch('requests.post') @mock.patch('requests.post')
def test_non_200(self, post): def test_non_200(self, post):
cursor = presto.connect(host=_HOST).cursor() cursor = self.connect().cursor()
post.return_value.status_code = 404 post.return_value.status_code = 404
self.assertRaises(presto.OperationalError, self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables'))
lambda: cursor.execute('show tables'))

14
scripts/make_many_rows.sh Executable file
Просмотреть файл

@ -0,0 +1,14 @@
#!/bin/bash -eux
# Recreates the many_rows test table and loads it with the integers 0..9999,
# all in the partition b='blah'. Hive must be on the path.
# WARNING: drops any existing table called many_rows.
# -e/-u/-x match the sibling scripts: stop at the first failing command instead
# of carrying on with a half-built table.
hive -e 'DROP TABLE IF EXISTS many_rows'
hive -e "
CREATE TABLE many_rows (
a INT
) PARTITIONED BY (
b STRING
) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' STORED AS TEXTFILE"
temp_file=/tmp/pyhive_test_data_many_rows.tsv
# Remove the scratch file on every exit path; with -e a failed step would
# otherwise skip a trailing rm.
trap 'rm -f "$temp_file"' EXIT
seq 0 9999 > "$temp_file"
hive -e "LOAD DATA LOCAL INPATH '$temp_file' INTO TABLE many_rows PARTITION (b='blah')"

4
scripts/make_one_row.sh Executable file
Просмотреть файл

@ -0,0 +1,4 @@
#!/bin/bash -eux
# Recreates the one_row test table: a single INT column called number_of_rows.
# WARNING: drops any existing table called one_row.
hive -e 'DROP TABLE IF EXISTS one_row'
hive -e 'CREATE TABLE one_row (number_of_rows INT)'
# The table was just created, so COUNT(*) is taken over an empty table; the
# intent is to end up with exactly one row holding COUNT(*) + 1.
hive -e 'INSERT OVERWRITE TABLE one_row SELECT COUNT(*) + 1 FROM one_row'

Просмотреть файл

@ -0,0 +1,4 @@
#!/bin/bash -eux
# Recreates the one_row_complex test table with a map column (a) and an array column (b).
# WARNING: drops any existing table called one_row_complex.
hive -e 'DROP TABLE IF EXISTS one_row_complex'
hive -e 'CREATE TABLE one_row_complex (a map<INT, STRING>, b array<INT>)'
# Rows are produced by selecting constants FROM one_row, so one_row must already exist.
hive -e "INSERT OVERWRITE TABLE one_row_complex SELECT map(1, 'a', 2, 'b'), array(1, 2, 3) FROM one_row"

7
scripts/make_test_tables.sh Executable file
Просмотреть файл

@ -0,0 +1,7 @@
#!/bin/bash -eux
# Hive must be on the path for this script to work.
# WARNING: drops and recreates tables called one_row, one_row_complex, and many_rows.
# Resolve the sibling scripts relative to this file; quoting keeps the path intact
# when it contains spaces.
script_dir=$(dirname "$0")
# make_one_row.sh runs first because make_one_row_complex.sh selects FROM one_row.
"$script_dir/make_one_row.sh"
"$script_dir/make_one_row_complex.sh"
"$script_dir/make_many_rows.sh"