зеркало из https://github.com/mozilla/PyHive.git
Fix the tests and make SQLAlchemy 1.2.8 compatible (#219)
* (More) idiomatic code for `fetchmany`, `fetchall` I went to look if `fetchall()` downloads th entire result in memory and found that it was coded in a very low-level C style, while Python has a perfectly good idiomatic approach for such things. So I thought I'd submit a fix :-) Take it or leave it, I won't insist. P.S. I'm editing this from GitHub UI, so haven't run any tests. And although I'm fairly sure it should work as is, consider it formally as a demonstration :-) * Remove unused import * Removing unreleased thrift_sasl comment Hello. I might be removing this prematurely, but `pip search thrift_sasl` now shows 0.3.0. Can this comment be removed? Also, I filled out the Dropbox Contributor License. * Add __enter__, __exit__ methods The methods allow us to use connections and cursors with `with`s statements. * Use a setter for arraysize The reason behind this is that non-integer sizes or None (or 0) do not make sense. Especially with None, which might be a common pattern, this should be prevented as it raises (cryptic) errors on the Thrift size. * Add script to generate new TCLIService files * Add decimal/timestamp conversion * Add the Travis pin * Fix Flake8 violations * Add Codecov badge
This commit is contained in:
Родитель
264cfea4d9
Коммит
e25fc8440a
13
.travis.yml
13
.travis.yml
|
@ -7,17 +7,14 @@ matrix:
|
|||
# One build pulls latest versions dynamically
|
||||
- python: 3.6
|
||||
env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE SQLALCHEMY=sqlalchemy
|
||||
# Others use pinned versions.
|
||||
- python: 3.6
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
|
||||
- python: 3.5
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.0.12
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
|
||||
- python: 3.4
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.0.12
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
|
||||
- python: 2.7
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.0.12
|
||||
# stale stuff we're still using / supporting
|
||||
- python: 2.7
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==0.8.7
|
||||
# exclude: python 3 against old libries
|
||||
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy==1.2.8
|
||||
install:
|
||||
- ./scripts/travis-install.sh
|
||||
- pip install codecov
|
||||
|
|
11
README.rst
11
README.rst
|
@ -1,5 +1,6 @@
|
|||
.. image:: https://travis-ci.org/dropbox/PyHive.svg?branch=master
|
||||
:target: https://travis-ci.org/dropbox/PyHive
|
||||
.. image:: https://img.shields.io/codecov/c/github/dropbox/PyHive.svg
|
||||
|
||||
======
|
||||
PyHive
|
||||
|
@ -100,9 +101,6 @@ PyHive works with
|
|||
- Python 2.7 / Python 3
|
||||
- For Presto: Presto install
|
||||
- For Hive: `HiveServer2 <https://cwiki.apache.org/confluence/display/Hive/Setting+up+HiveServer2>`_ daemon
|
||||
- For Python 3 + Hive + SASL, you currently need to install an unreleased version of ``thrift_sasl``
|
||||
(``pip install git+https://github.com/cloudera/thrift_sasl``).
|
||||
At the time of writing, the latest version of ``thrift_sasl`` was 0.2.1.
|
||||
|
||||
Changelog
|
||||
=========
|
||||
|
@ -137,3 +135,10 @@ Run the following in an environment with Hive/Presto::
|
|||
|
||||
WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and ``many_rows``, plus a
|
||||
database called ``pyhive_test_database``.
|
||||
|
||||
Updating TCLIService
|
||||
====================
|
||||
|
||||
The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the
|
||||
``generate.py`` file can be used: ``python generate.py <TCLIServiceURL>``. When left blank, the
|
||||
version for Hive 2.3 will be downloaded.
|
||||
|
|
|
@ -11,7 +11,6 @@ pytest-timeout==1.2.0
|
|||
# actual dependencies: let things break if a package changes
|
||||
requests>=1.0.0
|
||||
sasl>=0.2.1
|
||||
sqlalchemy>=0.8.7
|
||||
thrift>=0.10.0
|
||||
#thrift_sasl>=0.1.0
|
||||
git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
This file can be used to generate a new version of the TCLIService package
|
||||
using a TCLIService.thrift URL.
|
||||
|
||||
If no URL is specified, the file for Hive 2.3 will be downloaded.
|
||||
|
||||
Usage:
|
||||
|
||||
python generate.py THRIFT_URL
|
||||
|
||||
or
|
||||
|
||||
python generate.py
|
||||
"""
|
||||
import shutil
|
||||
import sys
|
||||
from os import path
|
||||
from urllib.request import urlopen
|
||||
import subprocess
|
||||
|
||||
here = path.abspath(path.dirname(__file__))
|
||||
|
||||
PACKAGE = 'TCLIService'
|
||||
GENERATED = 'gen-py'
|
||||
|
||||
HIVE_SERVER2_URL = \
|
||||
'https://raw.githubusercontent.com/apache/hive/branch-2.3/service-rpc/if/TCLIService.thrift'
|
||||
|
||||
|
||||
def save_url(url):
|
||||
data = urlopen(url).read()
|
||||
file_path = path.join(here, url.rsplit('/', 1)[-1])
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
def main(hive_server2_url):
|
||||
save_url(hive_server2_url)
|
||||
hive_server2_path = path.join(here, hive_server2_url.rsplit('/', 1)[-1])
|
||||
|
||||
subprocess.call(['thrift', '-r', '--gen', 'py', hive_server2_path])
|
||||
shutil.move(path.join(here, PACKAGE), path.join(here, PACKAGE + '.old'))
|
||||
shutil.move(path.join(here, GENERATED, PACKAGE), path.join(here, PACKAGE))
|
||||
shutil.rmtree(path.join(here, PACKAGE + '.old'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = HIVE_SERVER2_URL
|
||||
main(url)
|
|
@ -8,7 +8,6 @@ from __future__ import unicode_literals
|
|||
from builtins import bytes
|
||||
from builtins import int
|
||||
from builtins import object
|
||||
from builtins import range
|
||||
from builtins import str
|
||||
from past.builtins import basestring
|
||||
from pyhive import exc
|
||||
|
@ -16,6 +15,7 @@ import abc
|
|||
import collections
|
||||
import time
|
||||
from future.utils import with_metaclass
|
||||
from itertools import islice
|
||||
|
||||
|
||||
class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
|
||||
|
@ -124,14 +124,7 @@ class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
|
|||
"""
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
result = []
|
||||
for _ in range(size):
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
break
|
||||
else:
|
||||
result.append(one)
|
||||
return result
|
||||
return list(islice(iter(self.fetchone, None), size))
|
||||
|
||||
def fetchall(self):
|
||||
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
|
||||
|
@ -140,14 +133,7 @@ class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
|
|||
An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
|
||||
:py:meth:`execute` did not produce any result set or no call was issued yet.
|
||||
"""
|
||||
result = []
|
||||
while True:
|
||||
one = self.fetchone()
|
||||
if one is None:
|
||||
break
|
||||
else:
|
||||
result.append(one)
|
||||
return result
|
||||
return list(iter(self.fetchone, None))
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
|
|
|
@ -7,6 +7,11 @@ Many docstrings in this file are based on the PEP, which is in the public domain
|
|||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import re
|
||||
from decimal import Decimal
|
||||
|
||||
from TCLIService import TCLIService
|
||||
from TCLIService import constants
|
||||
from TCLIService import ttypes
|
||||
|
@ -31,6 +36,31 @@ paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(na
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
|
||||
|
||||
|
||||
def _parse_timestamp(value):
|
||||
if value:
|
||||
match = _TIMESTAMP_PATTERN.match(value)
|
||||
if match:
|
||||
if match.group(2):
|
||||
format = '%Y-%m-%d %H:%M:%S.%f'
|
||||
# use the pattern to truncate the value
|
||||
value = match.group()
|
||||
else:
|
||||
format = '%Y-%m-%d %H:%M:%S'
|
||||
value = datetime.datetime.strptime(value, format)
|
||||
else:
|
||||
raise Exception(
|
||||
'Cannot convert "{}" into a datetime'.format(value))
|
||||
else:
|
||||
value = None
|
||||
return value
|
||||
|
||||
|
||||
TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
|
||||
"TIMESTAMP_TYPE": _parse_timestamp}
|
||||
|
||||
|
||||
class HiveParamEscaper(common.ParamEscaper):
|
||||
def escape_string(self, item):
|
||||
|
@ -177,6 +207,14 @@ class Connection(object):
|
|||
self._transport.close()
|
||||
raise
|
||||
|
||||
def __enter__(self):
|
||||
"""Transport should already be opened by __init__"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Call close"""
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Close the underlying session and Thrift transport"""
|
||||
req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
|
||||
|
@ -215,7 +253,7 @@ class Cursor(common.DBAPICursor):
|
|||
def __init__(self, connection, arraysize=1000):
|
||||
self._operationHandle = None
|
||||
super(Cursor, self).__init__()
|
||||
self.arraysize = arraysize
|
||||
self._arraysize = arraysize
|
||||
self._connection = connection
|
||||
|
||||
def _reset_state(self):
|
||||
|
@ -230,6 +268,19 @@ class Cursor(common.DBAPICursor):
|
|||
finally:
|
||||
self._operationHandle = None
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
return self._arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
"""Array size cannot be None, and should be an integer"""
|
||||
default_arraysize = 1000
|
||||
try:
|
||||
self._arraysize = int(value) or default_arraysize
|
||||
except TypeError:
|
||||
self._arraysize = default_arraysize
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
"""This read-only attribute is a sequence of 7-item sequences.
|
||||
|
@ -273,6 +324,12 @@ class Cursor(common.DBAPICursor):
|
|||
))
|
||||
return self._description
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Close the operation handle"""
|
||||
self._reset_state()
|
||||
|
@ -320,8 +377,10 @@ class Cursor(common.DBAPICursor):
|
|||
)
|
||||
response = self._connection.client.FetchResults(req)
|
||||
_check_status(response)
|
||||
schema = self.description
|
||||
assert not response.results.rows, 'expected data in columnar format'
|
||||
columns = map(_unwrap_column, response.results.columns)
|
||||
columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
|
||||
zip(response.results.columns, schema)]
|
||||
new_data = list(zip(*columns))
|
||||
self._data += new_data
|
||||
# response.hasMoreRows seems to always be False, so we instead check the number of rows
|
||||
|
@ -402,7 +461,7 @@ for type_id in constants.PRIMITIVE_TYPES:
|
|||
#
|
||||
|
||||
|
||||
def _unwrap_column(col):
|
||||
def _unwrap_column(col, type_=None):
|
||||
"""Return a list of raw values from a TColumn instance."""
|
||||
for attr, wrapper in iteritems(col.__dict__):
|
||||
if wrapper is not None:
|
||||
|
@ -414,6 +473,9 @@ def _unwrap_column(col):
|
|||
for b in range(8):
|
||||
if byte & (1 << b):
|
||||
result[i * 8 + b] = None
|
||||
converter = TYPES_CONVERTER.get(type_, None)
|
||||
if converter and type_:
|
||||
result = [converter(row) if row else row for row in result]
|
||||
return result
|
||||
raise DataError("Got empty column value {}".format(col)) # pragma: no cover
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ which is released under the MIT license.
|
|||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
|
||||
import re
|
||||
|
@ -24,6 +25,9 @@ from sqlalchemy.sql.compiler import SQLCompiler
|
|||
from pyhive import hive
|
||||
from pyhive.common import UniversalSet
|
||||
|
||||
from dateutil.parser import parse
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class HiveStringTypeBase(types.TypeDecorator):
|
||||
"""Translates strings returned by Thrift into something else"""
|
||||
|
@ -40,6 +44,22 @@ class HiveDate(HiveStringTypeBase):
|
|||
def process_result_value(self, value, dialect):
|
||||
return processors.str_to_date(value)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.date()
|
||||
elif isinstance(value, datetime.date):
|
||||
return value
|
||||
elif value is not None:
|
||||
return parse(value).date()
|
||||
else:
|
||||
return None
|
||||
|
||||
return process
|
||||
|
||||
def adapt(self, impltype, **kwargs):
|
||||
return self.impl
|
||||
|
||||
|
||||
class HiveTimestamp(HiveStringTypeBase):
|
||||
"""Translates timestamp strings to datetime objects"""
|
||||
|
@ -48,16 +68,44 @@ class HiveTimestamp(HiveStringTypeBase):
|
|||
def process_result_value(self, value, dialect):
|
||||
return processors.str_to_datetime(value)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value
|
||||
elif value is not None:
|
||||
return parse(value)
|
||||
else:
|
||||
return None
|
||||
|
||||
return process
|
||||
|
||||
def adapt(self, impltype, **kwargs):
|
||||
return self.impl
|
||||
|
||||
|
||||
class HiveDecimal(HiveStringTypeBase):
|
||||
"""Translates strings to decimals"""
|
||||
impl = types.DECIMAL
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
if value is not None:
|
||||
return decimal.Decimal(value)
|
||||
else:
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if isinstance(value, Decimal):
|
||||
return value
|
||||
elif value is not None:
|
||||
return Decimal(value)
|
||||
else:
|
||||
return None
|
||||
|
||||
return process
|
||||
|
||||
def adapt(self, impltype, **kwargs):
|
||||
return self.impl
|
||||
|
||||
|
||||
class HiveIdentifierPreparer(compiler.IdentifierPreparer):
|
||||
|
@ -195,11 +243,6 @@ class HiveDialect(default.DefaultDialect):
|
|||
returns_unicode_strings = True
|
||||
description_encoding = None
|
||||
supports_multivalues_insert = True
|
||||
dbapi_type_map = {
|
||||
'DATE_TYPE': HiveDate(),
|
||||
'TIMESTAMP_TYPE': HiveTimestamp(),
|
||||
'DECIMAL_TYPE': HiveDecimal(),
|
||||
}
|
||||
type_compiler = HiveTypeCompiler
|
||||
|
||||
@classmethod
|
||||
|
@ -215,7 +258,7 @@ class HiveDialect(default.DefaultDialect):
|
|||
'database': url.database or 'default',
|
||||
}
|
||||
kwargs.update(url.query)
|
||||
return ([], kwargs)
|
||||
return [], kwargs
|
||||
|
||||
def get_schema_names(self, connection, **kw):
|
||||
# Equivalent to SHOW DATABASES
|
||||
|
@ -276,6 +319,7 @@ class HiveDialect(default.DefaultDialect):
|
|||
except KeyError:
|
||||
util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name))
|
||||
coltype = types.NullType
|
||||
|
||||
result.append({
|
||||
'name': col_name,
|
||||
'type': coltype,
|
||||
|
|
|
@ -104,7 +104,7 @@ class PrestoDialect(default.DefaultDialect):
|
|||
kwargs['schema'] = db_parts[1]
|
||||
else:
|
||||
raise ValueError("Unexpected database format {}".format(url.database))
|
||||
return ([], kwargs)
|
||||
return [], kwargs
|
||||
|
||||
def get_schema_names(self, connection, **kw):
|
||||
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
|
||||
|
|
|
@ -64,9 +64,9 @@ class DBAPITestCase(with_metaclass(abc.ABCMeta, object)):
|
|||
def test_description_failed(self, cursor):
|
||||
try:
|
||||
cursor.execute('blah_blah')
|
||||
self.assertIsNone(cursor.description)
|
||||
except exc.DatabaseError:
|
||||
pass
|
||||
self.assertIsNone(cursor.description)
|
||||
|
||||
@with_cursor
|
||||
def test_bad_query(self, cursor):
|
||||
|
|
|
@ -8,11 +8,13 @@ from __future__ import absolute_import
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
from decimal import Decimal
|
||||
|
||||
import mock
|
||||
import sasl
|
||||
|
@ -72,13 +74,13 @@ class TestHive(unittest.TestCase, DBAPITestCase):
|
|||
0.5,
|
||||
0.25,
|
||||
'a string',
|
||||
'1970-01-01 00:00:00.0',
|
||||
datetime.datetime(1970, 1, 1, 0, 0),
|
||||
b'123',
|
||||
'[1,2]',
|
||||
'{1:2,3:4}',
|
||||
'{"a":1,"b":2}',
|
||||
'{0:1}',
|
||||
'0.1',
|
||||
Decimal('0.1'),
|
||||
)]
|
||||
self.assertEqual(rows, expected)
|
||||
# catch unicode/str
|
||||
|
@ -159,7 +161,7 @@ class TestHive(unittest.TestCase, DBAPITestCase):
|
|||
subprocess.check_call(['sudo', 'cp', orig_ldap, des])
|
||||
_restart_hs2()
|
||||
with contextlib.closing(hive.connect(
|
||||
host=_HOST, username='existing', auth='LDAP', password='testpw')
|
||||
host=_HOST, username='existing', auth='LDAP', password='testpw')
|
||||
) as connection:
|
||||
with contextlib.closing(connection.cursor()) as cursor:
|
||||
cursor.execute('SELECT * FROM one_row')
|
||||
|
@ -209,6 +211,7 @@ class TestHive(unittest.TestCase, DBAPITestCase):
|
|||
sasl_client.setAttr('password', 'x')
|
||||
sasl_client.init()
|
||||
return sasl_client
|
||||
|
||||
transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
|
||||
conn = hive.connect(thrift_transport=transport)
|
||||
with contextlib.closing(conn):
|
||||
|
|
|
@ -130,6 +130,7 @@ class TestPresto(unittest.TestCase, DBAPITestCase):
|
|||
|
||||
def fail(*args, **kwargs):
|
||||
self.fail("Should not need requests.get after done polling") # pragma: no cover
|
||||
|
||||
with mock.patch('requests.get', fail):
|
||||
self.assertEqual(cursor.fetchall(), [(1,)])
|
||||
|
||||
|
|
|
@ -36,6 +36,25 @@ _ONE_ROW_COMPLEX_CONTENTS = [
|
|||
]
|
||||
|
||||
|
||||
# [
|
||||
# ('boolean', 'boolean', ''),
|
||||
# ('tinyint', 'tinyint', ''),
|
||||
# ('smallint', 'smallint', ''),
|
||||
# ('int', 'int', ''),
|
||||
# ('bigint', 'bigint', ''),
|
||||
# ('float', 'float', ''),
|
||||
# ('double', 'double', ''),
|
||||
# ('string', 'string', ''),
|
||||
# ('timestamp', 'timestamp', ''),
|
||||
# ('binary', 'binary', ''),
|
||||
# ('array', 'array<int>', ''),
|
||||
# ('map', 'map<int,int>', ''),
|
||||
# ('struct', 'struct<a:int,b:int>', ''),
|
||||
# ('union', 'uniontype<int,string>', ''),
|
||||
# ('decimal', 'decimal(10,1)', '')
|
||||
# ]
|
||||
|
||||
|
||||
class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
|
||||
def create_engine(self):
|
||||
return create_engine('hive://localhost:10000/default')
|
||||
|
@ -57,7 +76,7 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
|
|||
def test_dotted_column_names_raw(self, engine, connection):
|
||||
"""When Hive returns a dotted column name, and raw mode is on, nothing should be modified.
|
||||
"""
|
||||
row = connection.execution_options(hive_raw_colnames=True)\
|
||||
row = connection.execution_options(hive_raw_colnames=True) \
|
||||
.execute('SELECT * FROM one_row').fetchone()
|
||||
assert row.keys() == ['one_row.number_of_rows']
|
||||
assert 'number_of_rows' not in row
|
||||
|
@ -70,9 +89,8 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
|
|||
one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
|
||||
self.assertEqual(len(one_row_complex.c), 15)
|
||||
self.assertIsInstance(one_row_complex.c.string, Column)
|
||||
rows = one_row_complex.select().execute().fetchall()
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(list(rows[0]), _ONE_ROW_COMPLEX_CONTENTS)
|
||||
row = one_row_complex.select().execute().fetchone()
|
||||
self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
|
||||
|
||||
# TODO some of these types could be filled in better
|
||||
self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean)
|
||||
|
@ -94,9 +112,8 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
|
|||
@with_engine_connection
|
||||
def test_type_map(self, engine, connection):
|
||||
"""sqlalchemy should use the dbapi_type_map to infer types from raw queries"""
|
||||
rows = connection.execute('SELECT * FROM one_row_complex').fetchall()
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(list(rows[0]), _ONE_ROW_COMPLEX_CONTENTS)
|
||||
row = connection.execute('SELECT * FROM one_row_complex').fetchone()
|
||||
self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
|
||||
|
||||
@with_engine_connection
|
||||
def test_reserved_words(self, engine, connection):
|
||||
|
@ -164,7 +181,7 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
|
|||
row = connection.execute(table.select()).fetchone()
|
||||
self.assertEqual(row.hive_date, datetime.date(1970, 1, 1))
|
||||
self.assertEqual(row.hive_decimal, decimal.Decimal(big_number))
|
||||
self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123))
|
||||
self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000))
|
||||
table.drop()
|
||||
|
||||
@with_engine_connection
|
||||
|
|
3
setup.py
3
setup.py
|
@ -40,6 +40,7 @@ setup(
|
|||
],
|
||||
install_requires=[
|
||||
'future',
|
||||
'python-dateutil',
|
||||
],
|
||||
extras_require={
|
||||
'presto': ['requests>=1.0.0'],
|
||||
|
@ -52,7 +53,7 @@ setup(
|
|||
'pytest-cov',
|
||||
'requests>=1.0.0',
|
||||
'sasl>=0.2.1',
|
||||
'sqlalchemy>=0.8.7',
|
||||
'sqlalchemy>=0.12.0',
|
||||
'thrift>=0.10.0',
|
||||
],
|
||||
cmdclass={'test': PyTest},
|
||||
|
|
Загрузка…
Ссылка в новой задаче