This commit is contained in:
Jing Wang 2016-05-21 15:58:06 -07:00
Родитель f87469058b
Коммит a134942551
11 изменённых файлов: 44 добавлений и 29 удалений

Просмотреть файл

@ -1,6 +1,7 @@
mock>=1.0.0
pytest
pytest>=2.9.0
pytest-cov
pytest-flake8
pytest-random
requests>=1.0.0
sasl>=0.1.3

Просмотреть файл

@ -13,7 +13,6 @@ from past.builtins import basestring
from pyhive import exc
import abc
import collections
import sys
import time
from future.utils import with_metaclass

Просмотреть файл

@ -13,7 +13,7 @@ from TCLIService import ttypes
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
from pyhive.exc import *
from pyhive.exc import * # noqa
import contextlib
import getpass
import logging
@ -42,7 +42,8 @@ class HiveParamEscaper(common.ParamEscaper):
# string formatting here.
if isinstance(item, str):
item = item.decode('utf-8')
return "'{}'".format(item
return "'{}'".format(
item
.replace('\\', '\\\\')
.replace("'", "\\'")
.replace('\r', '\\r')
@ -65,7 +66,8 @@ def connect(*args, **kwargs):
class Connection(object):
"""Wraps a Thrift session"""
def __init__(self, host, port=10000, username=None, database='default', auth='NONE', configuration=None):
def __init__(self, host, port=10000, username=None, database='default', auth='NONE',
configuration=None):
"""Connect to HiveServer2
:param auth: The value of hive.server2.authentication used by HiveServer2
@ -89,22 +91,24 @@ class Connection(object):
# PLAIN corresponds to hive.server2.authentication=NONE in hive-site.xml
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, b'PLAIN', socket)
else:
raise NotImplementedError("Only NONE & NOSASL authentication are supported, got {}".format(auth))
raise NotImplementedError(
"Only NONE & NOSASL authentication are supported, got {}".format(auth))
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
self._client = TCLIService.Client(protocol)
protocolVersion = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
try:
self._transport.open()
open_session_req = ttypes.TOpenSessionReq(
client_protocol=ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1,
client_protocol=protocolVersion,
configuration=configuration,
)
response = self._client.OpenSession(open_session_req)
_check_status(response)
assert(response.sessionHandle is not None), "Expected a session from OpenSession"
assert response.sessionHandle is not None, "Expected a session from OpenSession"
self._sessionHandle = response.sessionHandle
assert(response.serverProtocolVersion == ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1), \
assert response.serverProtocolVersion == protocolVersion, \
"Unable to handle protocol version {}".format(response.serverProtocolVersion)
with contextlib.closing(self.cursor()) as cursor:
cursor.execute('USE `{}`'.format(database))
@ -246,7 +250,7 @@ class Cursor(common.DBAPICursor):
response = self._connection.client.FetchResults(req)
_check_status(response)
# response.hasMoreRows seems to always be False, so we instead check the number of rows
#if not response.hasMoreRows:
# if not response.hasMoreRows:
if not response.results.rows:
self._state = self._STATE_FINISHED
for row in response.results.rows:

Просмотреть файл

@ -11,7 +11,7 @@ from builtins import object
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
from pyhive.exc import *
from pyhive.exc import * # noqa
import base64
import getpass
import logging
@ -126,9 +126,9 @@ class Cursor(common.DBAPICursor):
section below.
"""
# Sleep until we're done or we got the columns
self._fetch_while(lambda:
self._columns is None
and self._state not in (self._STATE_NONE, self._STATE_FINISHED)
self._fetch_while(
lambda: self._columns is None and
self._state not in (self._STATE_NONE, self._STATE_FINISHED)
)
if self._columns is None:
return None

Просмотреть файл

@ -23,6 +23,7 @@ try:
except ImportError:
from sqlalchemy.sql.compiler import DefaultCompiler as SQLCompiler
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
reserved_words = UniversalSet()
@ -97,7 +98,10 @@ class PrestoDialect(default.DefaultDialect):
# call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped
# presto.DatabaseError here.
# Does the table exist?
msg = e.args[0].get('message') if len(e.args) > 0 and isinstance(e.args[0], dict) else None
msg = (
e.args[0].get('message') if len(e.args) > 0 and isinstance(e.args[0], dict)
else None
)
regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
if msg and re.search(regex, msg):
raise exc.NoSuchTableError(table_name)

Просмотреть файл

@ -43,9 +43,11 @@ class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
@with_engine_connection
def test_reflect_no_such_table(self, engine, connection):
"""reflecttable should throw an exception on an invalid table"""
self.assertRaises(NoSuchTableError,
self.assertRaises(
NoSuchTableError,
lambda: Table('this_does_not_exist', MetaData(bind=engine), autoload=True))
self.assertRaises(NoSuchTableError,
self.assertRaises(
NoSuchTableError,
lambda: Table('this_does_not_exist', MetaData(bind=engine),
schema='also_does_not_exist', autoload=True))

Просмотреть файл

@ -53,8 +53,8 @@ class TestPresto(unittest.TestCase, DBAPITestCase):
('array', 'array(bigint)', None, None, None, None, True),
('map', 'map(bigint,bigint)', None, None, None, None, True),
('struct', "row(bigint,bigint)('a','b')", None, None, None, None, True),
#('union', 'varchar', None, None, None, None, True),
#('decimal', 'double', None, None, None, None, True),
# ('union', 'varchar', None, None, None, None, True),
# ('decimal', 'double', None, None, None, None, True),
])
self.assertEqual(cursor.fetchall(), [[
True,
@ -70,8 +70,8 @@ class TestPresto(unittest.TestCase, DBAPITestCase):
[1, 2],
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
#'{0:1}',
#0.1,
# '{0:1}',
# 0.1,
]])
def test_noops(self):

Просмотреть файл

@ -5,8 +5,8 @@ from distutils.version import StrictVersion
from pyhive.sqlalchemy_hive import HiveDate
from pyhive.sqlalchemy_hive import HiveDecimal
from pyhive.sqlalchemy_hive import HiveTimestamp
from pyhive.tests.sqlachemy_test_case import SqlAlchemyTestCase
from pyhive.tests.sqlachemy_test_case import with_engine_connection
from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
from pyhive.tests.sqlalchemy_test_case import with_engine_connection
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import Column
from sqlalchemy.schema import MetaData

Просмотреть файл

@ -1,8 +1,8 @@
from __future__ import absolute_import
from __future__ import unicode_literals
from builtins import str
from pyhive.tests.sqlachemy_test_case import SqlAlchemyTestCase
from pyhive.tests.sqlachemy_test_case import with_engine_connection
from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
from pyhive.tests.sqlalchemy_test_case import with_engine_connection
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import Column
from sqlalchemy.schema import MetaData
@ -45,8 +45,8 @@ class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase):
[1, 2],
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
#'{0:1}',
#0.1,
# '{0:1}',
# 0.1,
])
def test_url_default(self):

Просмотреть файл

@ -2,6 +2,11 @@
tag_build = dev
[pytest]
addopts = --random --tb=short --cov pyhive --cov-report html --cov-report term
addopts = --random --tb=short --cov pyhive --cov-report html --cov-report term --flake8
norecursedirs = env
python_files = test_*.py
flake8-max-line-length = 100
flake8-ignore =
TCLIService/*.py ALL
pyhive/sqlalchemy_backports.py ALL
presto-server-*/** ALL

Просмотреть файл

@ -13,7 +13,7 @@ class PyTest(TestCommand):
self.test_suite = True
def run_tests(self):
#import here, cause outside the eggs aren't loaded
# import here, cause outside the eggs aren't loaded
import pytest
errno = pytest.main(self.test_args)
sys.exit(errno)