Fix the tests and make SQLAlchemy 1.2.8 compatible (#219)

* (More) idiomatic code for `fetchmany`, `fetchall`

I went to check whether `fetchall()` loads the entire result into memory and found that it was coded in a very low-level C style, while Python has a perfectly good idiomatic approach for such things. So I thought I'd submit a fix :-) Take it or leave it, I won't insist.

P.S. I'm editing this from GitHub UI, so haven't run any tests. And although I'm fairly sure it should work as is, consider it formally as a demonstration :-)

* Remove unused import

* Removing unreleased thrift_sasl comment

Hello. I might be removing this prematurely, but `pip search thrift_sasl` now shows 0.3.0. Can this comment be removed?

Also, I filled out the Dropbox Contributor License.

* Add __enter__, __exit__ methods

These methods allow connections and cursors to be used in `with`
statements.

* Use a setter for arraysize

The reason behind this is that non-integer sizes or None (or 0) do not
make sense.

None in particular, which might be a common pattern, should be
prevented because it raises (cryptic) errors on the Thrift side.

* Add script to generate new TCLIService files

* Add decimal/timestamp conversion

* Add the Travis pin

* Fix Flake8 violations

* Add Codecov badge
This commit is contained in:
Fokko Driesprong 2018-06-12 02:10:59 +02:00 коммит произвёл Gabriel Silk
Родитель 264cfea4d9
Коммит e25fc8440a
13 изменённых файлов: 222 добавлений и 55 удалений

Просмотреть файл

@ -7,17 +7,14 @@ matrix:
# One build pulls latest versions dynamically
- python: 3.6
env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE SQLALCHEMY=sqlalchemy
# Others use pinned versions.
- python: 3.6
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
- python: 3.5
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.0.12
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
- python: 3.4
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.0.12
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
- python: 2.7
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.0.12
# stale stuff we're still using / supporting
- python: 2.7
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==0.8.7
# exclude: python 3 against old libries
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
install:
- ./scripts/travis-install.sh
- pip install codecov

Просмотреть файл

@ -1,5 +1,6 @@
.. image:: https://travis-ci.org/dropbox/PyHive.svg?branch=master
:target: https://travis-ci.org/dropbox/PyHive
.. image:: https://img.shields.io/codecov/c/github/dropbox/PyHive.svg
======
PyHive
@ -100,9 +101,6 @@ PyHive works with
- Python 2.7 / Python 3
- For Presto: Presto install
- For Hive: `HiveServer2 <https://cwiki.apache.org/confluence/display/Hive/Setting+up+HiveServer2>`_ daemon
- For Python 3 + Hive + SASL, you currently need to install an unreleased version of ``thrift_sasl``
(``pip install git+https://github.com/cloudera/thrift_sasl``).
At the time of writing, the latest version of ``thrift_sasl`` was 0.2.1.
Changelog
=========
@ -137,3 +135,10 @@ Run the following in an environment with Hive/Presto::
WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and ``many_rows``, plus a
database called ``pyhive_test_database``.
Updating TCLIService
====================
The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the
``generate.py`` file can be used: ``python generate.py <TCLIServiceURL>``. When left blank, the
version for Hive 2.3 will be downloaded.

Просмотреть файл

@ -11,7 +11,6 @@ pytest-timeout==1.2.0
# actual dependencies: let things break if a package changes
requests>=1.0.0
sasl>=0.2.1
sqlalchemy>=0.8.7
thrift>=0.10.0
#thrift_sasl>=0.1.0
git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches

52
generate.py Normal file
Просмотреть файл

@ -0,0 +1,52 @@
"""
This file can be used to generate a new version of the TCLIService package
using a TCLIService.thrift URL.
If no URL is specified, the file for Hive 2.3 will be downloaded.
Usage:
python generate.py THRIFT_URL
or
python generate.py
"""
import shutil
import sys
from os import path
from urllib.request import urlopen
import subprocess
# Directory containing this script; the downloaded IDL file and the package
# directories handled by ``main`` are all resolved against it.
here = path.abspath(path.dirname(__file__))
# Name of the Python package directory that ``main`` replaces.
PACKAGE = 'TCLIService'
# Directory in which ``main`` expects the freshly generated package
# (NOTE(review): presumably the output folder of ``thrift --gen py`` - confirm).
GENERATED = 'gen-py'
# Default IDL location, used when no URL is given on the command line.
HIVE_SERVER2_URL = \
'https://raw.githubusercontent.com/apache/hive/branch-2.3/service-rpc/if/TCLIService.thrift'
def save_url(url):
    """Download ``url`` and write it next to this script.

    The local file name is the last path segment of the URL, so
    ``.../if/TCLIService.thrift`` is stored as ``TCLIService.thrift``.

    :param url: address of the file to download.
    :returns: absolute path of the file that was written.
    """
    file_path = path.join(here, url.rsplit('/', 1)[-1])
    # Use the response as a context manager so the HTTP connection is
    # released even when reading fails part-way through.
    with urlopen(url) as response:
        data = response.read()
    with open(file_path, 'wb') as f:
        f.write(data)
    return file_path
def main(hive_server2_url):
    """Regenerate the ``TCLIService`` package from a Thrift IDL URL.

    Downloads the IDL file next to this script, runs the ``thrift`` compiler
    on it and swaps the freshly generated package in for the existing one.

    :param hive_server2_url: URL of the ``TCLIService.thrift`` file to compile.
    :raises subprocess.CalledProcessError: if ``thrift`` exits with a non-zero
        status; the existing package is left untouched in that case.
    """
    save_url(hive_server2_url)
    hive_server2_path = path.join(here, hive_server2_url.rsplit('/', 1)[-1])

    # check_call rather than call: a failed compile must stop the script
    # *before* the current package is moved away, otherwise the checkout is
    # left without a TCLIService package at all.
    # cwd=here: thrift writes its ``gen-py`` output into the working
    # directory, and the moves below look for it next to this script.
    subprocess.check_call(
        ['thrift', '-r', '--gen', 'py', hive_server2_path], cwd=here)

    package_path = path.join(here, PACKAGE)
    backup_path = package_path + '.old'
    # Keep the old package until the new one is in place, then drop it.
    shutil.move(package_path, backup_path)
    shutil.move(path.join(here, GENERATED, PACKAGE), package_path)
    shutil.rmtree(backup_path)
if __name__ == '__main__':
    # An explicit URL on the command line wins; otherwise fall back to the
    # bundled default IDL location.
    url = sys.argv[1] if len(sys.argv) > 1 else HIVE_SERVER2_URL
    main(url)

Просмотреть файл

@ -8,7 +8,6 @@ from __future__ import unicode_literals
from builtins import bytes
from builtins import int
from builtins import object
from builtins import range
from builtins import str
from past.builtins import basestring
from pyhive import exc
@ -16,6 +15,7 @@ import abc
import collections
import time
from future.utils import with_metaclass
from itertools import islice
class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
@ -124,14 +124,7 @@ class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
"""
if size is None:
size = self.arraysize
result = []
for _ in range(size):
one = self.fetchone()
if one is None:
break
else:
result.append(one)
return result
return list(islice(iter(self.fetchone, None), size))
def fetchall(self):
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
@ -140,14 +133,7 @@ class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
:py:meth:`execute` did not produce any result set or no call was issued yet.
"""
result = []
while True:
one = self.fetchone()
if one is None:
break
else:
result.append(one)
return result
return list(iter(self.fetchone, None))
@property
def arraysize(self):

Просмотреть файл

@ -7,6 +7,11 @@ Many docstrings in this file are based on the PEP, which is in the public domain
from __future__ import absolute_import
from __future__ import unicode_literals
import datetime
import re
from decimal import Decimal
from TCLIService import TCLIService
from TCLIService import constants
from TCLIService import ttypes
@ -31,6 +36,31 @@ paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(na
_logger = logging.getLogger(__name__)
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')


def _parse_timestamp(value):
    """Convert a timestamp string into a :class:`datetime.datetime`.

    :param value: text such as ``'1970-01-01 00:00:00.0'``; falsy values
        (``None``, ``''``) are treated as missing.
    :returns: the parsed ``datetime``, or ``None`` for a falsy ``value``.
    :raises ValueError: if ``value`` does not start with a timestamp.
    """
    if not value:
        return None

    match = _TIMESTAMP_PATTERN.match(value)
    if match is None:
        raise ValueError(
            'Cannot convert "{}" into a datetime'.format(value))

    if match.group(2):
        # The pattern captures at most six fractional digits, which is all
        # ``%f`` accepts, so parse the matched text instead of ``value`` to
        # drop any extra precision.
        return datetime.datetime.strptime(match.group(), '%Y-%m-%d %H:%M:%S.%f')
    return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')


# Maps a column type name to the callable that converts its raw values.
TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
                   "TIMESTAMP_TYPE": _parse_timestamp}
class HiveParamEscaper(common.ParamEscaper):
def escape_string(self, item):
@ -177,6 +207,14 @@ class Connection(object):
self._transport.close()
raise
def __enter__(self):
"""Enter a ``with`` block and return this connection.

Nothing is opened here: the transport should already be opened by ``__init__``.
"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Call ``close()`` when the ``with`` block ends.

The exception arguments are ignored and ``None`` is returned, so an exception
raised inside the block is not suppressed.
"""
self.close()
def close(self):
"""Close the underlying session and Thrift transport"""
req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
@ -215,7 +253,7 @@ class Cursor(common.DBAPICursor):
def __init__(self, connection, arraysize=1000):
self._operationHandle = None
super(Cursor, self).__init__()
self.arraysize = arraysize
self._arraysize = arraysize
self._connection = connection
def _reset_state(self):
@ -230,6 +268,19 @@ class Cursor(common.DBAPICursor):
finally:
self._operationHandle = None
@property
def arraysize(self):
    """Stored row-count setting of this cursor; always a positive ``int``
    when assigned through the setter."""
    return self._arraysize


@arraysize.setter
def arraysize(self, value):
    """Store ``value`` as a positive integer, falling back to the default.

    Anything that cannot be turned into a positive ``int`` (``None``, ``0``,
    negative numbers, non-numeric strings, NaN/infinity) is replaced by
    the default of 1000 rather than being stored or raising.
    """
    default_arraysize = 1000
    try:
        size = int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None and other non-numbers; ValueError: text such as
        # 'abc' or NaN; OverflowError: infinity.
        size = default_arraysize
    # Zero and negative sizes make no sense as a row count.
    self._arraysize = size if size > 0 else default_arraysize
@property
def description(self):
"""This read-only attribute is a sequence of 7-item sequences.
@ -273,6 +324,12 @@ class Cursor(common.DBAPICursor):
))
return self._description
# Context-manager entry: hand back this cursor so it can be bound with ``as``.
def __enter__(self):
return self
# Context-manager exit: always close the cursor. Returns None, so an exception
# raised inside the ``with`` block is not suppressed.
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
"""Close the operation handle"""
self._reset_state()
@ -320,8 +377,10 @@ class Cursor(common.DBAPICursor):
)
response = self._connection.client.FetchResults(req)
_check_status(response)
schema = self.description
assert not response.results.rows, 'expected data in columnar format'
columns = map(_unwrap_column, response.results.columns)
columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
zip(response.results.columns, schema)]
new_data = list(zip(*columns))
self._data += new_data
# response.hasMoreRows seems to always be False, so we instead check the number of rows
@ -402,7 +461,7 @@ for type_id in constants.PRIMITIVE_TYPES:
#
def _unwrap_column(col):
def _unwrap_column(col, type_=None):
"""Return a list of raw values from a TColumn instance."""
for attr, wrapper in iteritems(col.__dict__):
if wrapper is not None:
@ -414,6 +473,9 @@ def _unwrap_column(col):
for b in range(8):
if byte & (1 << b):
result[i * 8 + b] = None
converter = TYPES_CONVERTER.get(type_, None)
if converter and type_:
result = [converter(row) if row else row for row in result]
return result
raise DataError("Got empty column value {}".format(col)) # pragma: no cover

Просмотреть файл

@ -8,6 +8,7 @@ which is released under the MIT license.
from __future__ import absolute_import
from __future__ import unicode_literals
import datetime
import decimal
import re
@ -24,6 +25,9 @@ from sqlalchemy.sql.compiler import SQLCompiler
from pyhive import hive
from pyhive.common import UniversalSet
from dateutil.parser import parse
from decimal import Decimal
class HiveStringTypeBase(types.TypeDecorator):
"""Translates strings returned by Thrift into something else"""
@ -40,6 +44,22 @@ class HiveDate(HiveStringTypeBase):
def process_result_value(self, value, dialect):
return processors.str_to_date(value)
def result_processor(self, dialect, coltype):
    """Build the converter that turns a fetched value into a ``date``.

    ``dialect`` and ``coltype`` are not consulted; the returned callable
    depends only on the value it is given.
    """
    def process(value):
        # Missing values stay missing.
        if value is None:
            return None
        # ``datetime`` must be tested before ``date`` because it is a
        # subclass of it; only the calendar part is kept.
        if isinstance(value, datetime.datetime):
            return value.date()
        if isinstance(value, datetime.date):
            return value
        # Anything else is handed to the date parser.
        return parse(value).date()

    return process
# Return this type's own ``impl``; ``impltype`` and ``kwargs`` are ignored.
# NOTE(review): presumably bypasses SQLAlchemy's default type adaptation - confirm.
def adapt(self, impltype, **kwargs):
return self.impl
class HiveTimestamp(HiveStringTypeBase):
"""Translates timestamp strings to datetime objects"""
@ -48,16 +68,44 @@ class HiveTimestamp(HiveStringTypeBase):
def process_result_value(self, value, dialect):
return processors.str_to_datetime(value)
def result_processor(self, dialect, coltype):
    """Build the converter that turns a fetched value into a ``datetime``.

    ``dialect`` and ``coltype`` are not consulted; the returned callable
    depends only on the value it is given.
    """
    def process(value):
        # ``None`` and ready-made datetimes pass through untouched.
        if value is None or isinstance(value, datetime.datetime):
            return value
        # Anything else is handed to the date parser.
        return parse(value)

    return process
# Return this type's own ``impl``; ``impltype`` and ``kwargs`` are ignored.
# NOTE(review): presumably bypasses SQLAlchemy's default type adaptation - confirm.
def adapt(self, impltype, **kwargs):
return self.impl
class HiveDecimal(HiveStringTypeBase):
"""Translates strings to decimals"""
impl = types.DECIMAL
def process_result_value(self, value, dialect):
if value is None:
return None
else:
if value is not None:
return decimal.Decimal(value)
else:
return None
def result_processor(self, dialect, coltype):
    """Build the converter that turns a fetched value into a ``Decimal``.

    ``dialect`` and ``coltype`` are not consulted; the returned callable
    depends only on the value it is given.
    """
    def process(value):
        # ``None`` and ready-made Decimals pass through untouched.
        if value is None or isinstance(value, Decimal):
            return value
        return Decimal(value)

    return process
# Return this type's own ``impl``; ``impltype`` and ``kwargs`` are ignored.
# NOTE(review): presumably bypasses SQLAlchemy's default type adaptation - confirm.
def adapt(self, impltype, **kwargs):
return self.impl
class HiveIdentifierPreparer(compiler.IdentifierPreparer):
@ -195,11 +243,6 @@ class HiveDialect(default.DefaultDialect):
returns_unicode_strings = True
description_encoding = None
supports_multivalues_insert = True
dbapi_type_map = {
'DATE_TYPE': HiveDate(),
'TIMESTAMP_TYPE': HiveTimestamp(),
'DECIMAL_TYPE': HiveDecimal(),
}
type_compiler = HiveTypeCompiler
@classmethod
@ -215,7 +258,7 @@ class HiveDialect(default.DefaultDialect):
'database': url.database or 'default',
}
kwargs.update(url.query)
return ([], kwargs)
return [], kwargs
def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
@ -276,6 +319,7 @@ class HiveDialect(default.DefaultDialect):
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name))
coltype = types.NullType
result.append({
'name': col_name,
'type': coltype,

Просмотреть файл

@ -104,7 +104,7 @@ class PrestoDialect(default.DefaultDialect):
kwargs['schema'] = db_parts[1]
else:
raise ValueError("Unexpected database format {}".format(url.database))
return ([], kwargs)
return [], kwargs
def get_schema_names(self, connection, **kw):
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]

Просмотреть файл

@ -64,9 +64,9 @@ class DBAPITestCase(with_metaclass(abc.ABCMeta, object)):
def test_description_failed(self, cursor):
try:
cursor.execute('blah_blah')
self.assertIsNone(cursor.description)
except exc.DatabaseError:
pass
self.assertIsNone(cursor.description)
@with_cursor
def test_bad_query(self, cursor):

Просмотреть файл

@ -8,11 +8,13 @@ from __future__ import absolute_import
from __future__ import unicode_literals
import contextlib
import datetime
import os
import socket
import subprocess
import time
import unittest
from decimal import Decimal
import mock
import sasl
@ -72,13 +74,13 @@ class TestHive(unittest.TestCase, DBAPITestCase):
0.5,
0.25,
'a string',
'1970-01-01 00:00:00.0',
datetime.datetime(1970, 1, 1, 0, 0),
b'123',
'[1,2]',
'{1:2,3:4}',
'{"a":1,"b":2}',
'{0:1}',
'0.1',
Decimal('0.1'),
)]
self.assertEqual(rows, expected)
# catch unicode/str
@ -159,7 +161,7 @@ class TestHive(unittest.TestCase, DBAPITestCase):
subprocess.check_call(['sudo', 'cp', orig_ldap, des])
_restart_hs2()
with contextlib.closing(hive.connect(
host=_HOST, username='existing', auth='LDAP', password='testpw')
host=_HOST, username='existing', auth='LDAP', password='testpw')
) as connection:
with contextlib.closing(connection.cursor()) as cursor:
cursor.execute('SELECT * FROM one_row')
@ -209,6 +211,7 @@ class TestHive(unittest.TestCase, DBAPITestCase):
sasl_client.setAttr('password', 'x')
sasl_client.init()
return sasl_client
transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
conn = hive.connect(thrift_transport=transport)
with contextlib.closing(conn):

Просмотреть файл

@ -130,6 +130,7 @@ class TestPresto(unittest.TestCase, DBAPITestCase):
def fail(*args, **kwargs):
self.fail("Should not need requests.get after done polling") # pragma: no cover
with mock.patch('requests.get', fail):
self.assertEqual(cursor.fetchall(), [(1,)])

Просмотреть файл

@ -36,6 +36,25 @@ _ONE_ROW_COMPLEX_CONTENTS = [
]
# [
# ('boolean', 'boolean', ''),
# ('tinyint', 'tinyint', ''),
# ('smallint', 'smallint', ''),
# ('int', 'int', ''),
# ('bigint', 'bigint', ''),
# ('float', 'float', ''),
# ('double', 'double', ''),
# ('string', 'string', ''),
# ('timestamp', 'timestamp', ''),
# ('binary', 'binary', ''),
# ('array', 'array<int>', ''),
# ('map', 'map<int,int>', ''),
# ('struct', 'struct<a:int,b:int>', ''),
# ('union', 'uniontype<int,string>', ''),
# ('decimal', 'decimal(10,1)', '')
# ]
class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
def create_engine(self):
return create_engine('hive://localhost:10000/default')
@ -57,7 +76,7 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
def test_dotted_column_names_raw(self, engine, connection):
"""When Hive returns a dotted column name, and raw mode is on, nothing should be modified.
"""
row = connection.execution_options(hive_raw_colnames=True)\
row = connection.execution_options(hive_raw_colnames=True) \
.execute('SELECT * FROM one_row').fetchone()
assert row.keys() == ['one_row.number_of_rows']
assert 'number_of_rows' not in row
@ -70,9 +89,8 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
self.assertEqual(len(one_row_complex.c), 15)
self.assertIsInstance(one_row_complex.c.string, Column)
rows = one_row_complex.select().execute().fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(list(rows[0]), _ONE_ROW_COMPLEX_CONTENTS)
row = one_row_complex.select().execute().fetchone()
self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
# TODO some of these types could be filled in better
self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean)
@ -94,9 +112,8 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
@with_engine_connection
def test_type_map(self, engine, connection):
"""sqlalchemy should use the dbapi_type_map to infer types from raw queries"""
rows = connection.execute('SELECT * FROM one_row_complex').fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(list(rows[0]), _ONE_ROW_COMPLEX_CONTENTS)
row = connection.execute('SELECT * FROM one_row_complex').fetchone()
self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
@with_engine_connection
def test_reserved_words(self, engine, connection):
@ -164,7 +181,7 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
row = connection.execute(table.select()).fetchone()
self.assertEqual(row.hive_date, datetime.date(1970, 1, 1))
self.assertEqual(row.hive_decimal, decimal.Decimal(big_number))
self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123))
self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000))
table.drop()
@with_engine_connection

Просмотреть файл

@ -40,6 +40,7 @@ setup(
],
install_requires=[
'future',
'python-dateutil',
],
extras_require={
'presto': ['requests>=1.0.0'],
@ -52,7 +53,7 @@ setup(
'pytest-cov',
'requests>=1.0.0',
'sasl>=0.2.1',
'sqlalchemy>=0.8.7',
'sqlalchemy>=0.12.0',
'thrift>=0.10.0',
],
cmdclass={'test': PyTest},