* Use MSAL library support instead of ADAL.

---------

Co-authored-by: Ray Luo <rayluo.mba@gmail.com>
This commit is contained in:
akharit 2023-04-24 12:30:51 -07:00 коммит произвёл GitHub
Родитель 7f6e5ea7b1
Коммит b44a68d213
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 75 добавлений и 179 удалений

Просмотреть файл

@@ -3,8 +3,9 @@
Release History
===============
0.0.53 (2022-10-26)
0.0.53 (2023-04-11)
+++++++++++++++++++
* Add MSAL support. Remove ADAL support
* Suppress deprecation warning when detecting pyopenssl existence.
0.0.52 (2020-11-25)

Просмотреть файл

@@ -23,14 +23,10 @@ Manually (bleeding edge):
* Download the repo from [https://github.com/Azure/azure-data-lake-store-python](https://github.com/Azure/azure-data-lake-store-python)
* checkout the `dev` branch
* install the requirements (`pip install -r dev_requirements.txt`)
* install in develop mode (`python setup.py develop`)
* optionally: build the documentation (including this page) by running `make html` in the docs directory.
## Auth
Although users can generate and supply their own tokens to the base file-system

Просмотреть файл

@@ -27,13 +27,14 @@ if sys.version_info >= (3, 4):
else:
import urllib
from .retry import ExponentialRetryPolicy, retry_decorator_for_auth
from .retry import ExponentialRetryPolicy
# 3rd party imports
import adal
import msal
import requests
import requests.exceptions
_http_cache = {} # Useful for MSAL. https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.params.http_cache
# this is required due to github issue, to ensure we don't lose perf from openPySSL: https://github.com/pyca/pyopenssl/issues/625
def enforce_no_py_open_ssl():
@@ -118,13 +119,8 @@ def auth(tenant_id=None, username=None,
if not authority:
authority = 'https://login.microsoftonline.com/'
if not tenant_id:
tenant_id = os.environ.get('azure_tenant_id', "common")
context = adal.AuthenticationContext(authority +
tenant_id)
if tenant_id is None or client_id is None:
raise ValueError("tenant_id and client_id must be supplied for authentication")
@@ -136,60 +132,59 @@ def auth(tenant_id=None, username=None,
if not client_secret:
client_secret = os.environ.get('azure_client_secret', None)
# You can explicitly authenticate with 2fa, or pass in nothing to the auth call
# and the user will be prompted to login interactively through a browser.
@retry_decorator_for_auth(retry_policy=retry_policy)
scopes = kwargs.get('scopes', ["https://datalake.azure.net/.default"])
def get_token_internal():
# Internal function used so as to use retry decorator
if require_2fa or (username is None and password is None and client_secret is None):
code = context.acquire_user_code(resource, client_id)
print(code['message'])
out = context.acquire_token_with_device_code(resource, code, client_id)
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
flow = contextPub.initiate_device_flow(scopes=scopes)
print(flow['message'])
out = contextPub.acquire_token_by_device_flow(flow)
elif username and password:
out = context.acquire_token_with_username_password(resource, username,
password, client_id)
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
out = contextPub.acquire_token_by_username_password(username=username, password=password, scopes=scopes)
elif client_id and client_secret:
out = context.acquire_token_with_client_credentials(resource, client_id,
client_secret)
contextClient = msal.ConfidentialClientApplication(client_id=client_id, authority=authority+tenant_id, client_credential=client_secret, http_cache=_http_cache)
out = contextClient.acquire_token_for_client(scopes=scopes)
# for service principal, we store the secret in the credential object for use when refreshing.
out.update({'secret': client_secret})
else:
raise ValueError("No authentication method found for credentials")
return out
out = get_token_internal()
out.update({'access': out['accessToken'], 'resource': resource,
'refresh': out.get('refreshToken', False),
'time': time.time(), 'tenant': tenant_id, 'client': client_id})
if 'error' in out:
msg = "MSAL Error: "+out.get('error_description', "")
err = DatalakeRESTException(msg)
logger.log(logging.ERROR, msg)
raise err
out.update({'access_token': out['access_token'], 'access': out['access_token'], 'resource': resource,
'refresh': out.get('refresh_token', False),
'time': time.time(), 'tenant': tenant_id, 'client': client_id, 'scopes':scopes})
return DataLakeCredential(out)
class DataLakeCredential:
# Be careful modifying this. DataLakeCredential is a general class in azure, and we have to maintain parity.
def __init__(self, token):
self.token = token
def signed_session(self):
# type: () -> requests.Session
"""Create requests session with any required auth headers applied.
:rtype: requests.Session
"""
session = requests.Session()
if time.time() - self.token['time'] > self.token['expiresIn'] - 100:
if time.time() - self.token['time'] > self.token['expires_in'] - 100:
self.refresh_token()
scheme, token = self.token['tokenType'], self.token['access']
session = requests.Session()
scheme, token = self.token['token_type'], self.token['access_token']
header = "{} {}".format(scheme, token)
session.headers['Authorization'] = header
return session
def refresh_token(self, authority=None):
""" Refresh an expired authorization token
Parameters
----------
authority: string
@@ -201,25 +196,31 @@ class DataLakeCredential:
if not authority:
authority = 'https://login.microsoftonline.com/'
context = adal.AuthenticationContext(authority +
self.token['tenant'])
tenant_id = self.token['tenant']
scopes = self.token['scopes']
if self.token.get('secret') and self.token.get('client'):
out = context.acquire_token_with_client_credentials(self.token['resource'],
self.token['client'],
self.token['secret'])
out.update({'secret': self.token['secret']})
client_id = self.token['client']
client_secret = self.token['secret']
contextClient = msal.ConfidentialClientApplication(client_id=client_id, authority=authority+tenant_id, client_credential=client_secret, http_cache=_http_cache)
out = contextClient.acquire_token_for_client(scopes=scopes)
out.update({'secret': client_secret})
else:
out = context.acquire_token_with_refresh_token(self.token['refresh'],
client_id=self.token['client'],
resource=self.token['resource'])
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
out = contextPub.client.obtain_token_by_refresh_token(self.token['refresh'], scopes=scopes)
if 'error' in out:
msg = "MSAL Error: "+out.get('error_description', "")
err = DatalakeRESTException(msg)
logger.log(logging.ERROR, msg)
raise err
# common items to update
out.update({'access': out['accessToken'],
out.update({'access_token': out['access_token'], 'access': out['access_token'],
'time': time.time(), 'tenant': self.token['tenant'],
'resource': self.token['resource'], 'client': self.token['client']})
'resource': self.token['resource'], 'client': self.token['client'], 'scopes':self.token['scopes']})
self.token = out
class DatalakeRESTInterface:
""" Call factory for webHDFS endpoints on ADLS
@@ -228,7 +229,7 @@ class DatalakeRESTInterface:
store_name: str
The name of the Data Lake Store account to execute operations against.
token: dict
from `auth()` or `refresh_token()` or other ADAL source
from `auth()` or `refresh_token()` or other MSAL source
url_suffix: str (None)
Domain to send REST requests to. The end-point URL is constructed
using this and the store_name. If None, use default.
@@ -309,13 +310,10 @@ class DatalakeRESTInterface:
return s
def _check_token(self, retry_policy= None):
@retry_decorator_for_auth(retry_policy=retry_policy)
def check_token_internal():
cur_session = self.token.signed_session()
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
self.head = {'Authorization': cur_session.headers['Authorization']}
self.local.session = None
check_token_internal()
cur_session = self.token.signed_session()
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
self.head = {'Authorization': cur_session.headers['Authorization']}
self.local.session = None
def _log_request(self, method, url, op, path, params, headers, retry_count):
msg = u"HTTP Request\n{} {}\n".format(method.upper(), url)
@@ -498,7 +496,7 @@ class DatalakeRESTInterface:
"""
Not yet implemented (or not applicable)
http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/WebHDFS.html
https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/WebHDFS.html
GETFILECHECKSUM
GETHOMEDIRECTORY

Просмотреть файл

@@ -22,7 +22,7 @@ GLOBAL_EXCEPTION_LOCK = threading.Lock()
def monitor_exception(exception_queue, process_ids):
global GLOBAL_EXCEPTION
logger = logging.getLogger(__name__)
logger = logging.getLogger("azure.datalake.store")
while True:
try:
@@ -53,8 +53,8 @@ def log_listener_process(queue):
queue.task_done()
if record == END_QUEUE_SENTINEL: # We send this as a sentinel to tell the listener to quit.
break
logger = logging.getLogger(record.name)
logger.handlers.clear()
logger = logging.getLogger("azure.datalake.store")
#logger.handlers.clear()
logger.handle(record) # No level or filter logic applied - just do it!
except Empty: # Try again
pass
@@ -65,7 +65,7 @@ def log_listener_process(queue):
def multi_processor_change_acl(adl, path=None, method_name="", acl_spec="", number_of_sub_process=None):
logger = logging.getLogger(__name__)
logger = logging.getLogger("azure.datalake.store")
def launch_processes(number_of_processes):
if number_of_processes is None:
@@ -152,8 +152,8 @@ def multi_processor_change_acl(adl, path=None, method_name="", acl_spec="", numb
def processor(adl, file_path_queue, finish_queue_processing_flag, method_name, acl_spec, log_queue, exception_queue):
logger = logging.getLogger(__name__)
logger = logging.getLogger("azure.datalake.store")
logger.setLevel(logging.DEBUG)
removed_default_acl_spec = ",".join([x for x in acl_spec.split(',') if not x.lower().startswith("default")])
try:

Просмотреть файл

@@ -74,55 +74,3 @@ class ExponentialRetryPolicy(RetryPolicy):
def __backoff(self):
time.sleep(self.exponential_retry_interval)
self.exponential_retry_interval *= self.exponential_factor
def retry_decorator_for_auth(retry_policy = None):
import adal
from requests import HTTPError
if retry_policy is None:
retry_policy = ExponentialRetryPolicy(max_retries=2)
def deco_retry(func):
@wraps(func)
def f_retry(*args, **kwargs):
retry_count = -1
while True:
last_exception = None
retry_count += 1
try:
out = func(*args, **kwargs)
except (adal.adal_error.AdalError, HTTPError) as e:
# ADAL error corresponds to everything but 429, which bubbles up HTTP error.
last_exception = e
logger.exception("Retry count " + str(retry_count) + "Exception :" + str(last_exception))
# We don't want to stop retry for any error in parsing the exception. This is a GET operation.
try:
if hasattr(last_exception, 'error_response'): # ADAL exception
response = response_from_adal_exception(last_exception)
if hasattr(last_exception, 'response'): # HTTP exception i.e 429
response = last_exception.response
except:
pass
request_successful = last_exception is None or (response is not None and response.status_code == 401) # 401 = Invalid credentials
if request_successful or not retry_policy.should_retry(response, last_exception, retry_count):
break
if last_exception is not None:
raise last_exception
return out
return f_retry
return deco_retry
def response_from_adal_exception(e):
import re
from collections import namedtuple
http_code = re.search(r"http error: (\d+)", str(e))
if http_code is not None: # Add status_code to response object for use in should_retry
status_code = [int(http_code.group(1))]
Response = namedtuple("Response", ['status_code'])
response = Response(
*status_code) # Construct response object with adal exception response and http code
return response

Просмотреть файл

@@ -179,7 +179,7 @@ class CountUpDownLatch:
self.lock.acquire()
self.val -= 1
if self.val <= 0:
self.lock.notifyAll()
self.lock.notify_all()
self.lock.release()
def total_processed(self):

Просмотреть файл

@@ -22,7 +22,7 @@ setup(name='azure-datalake-store',
description='Azure Data Lake Store Filesystem Client Library for Python',
url='https://github.com/Azure/azure-data-lake-store-python',
author='Microsoft Corporation',
author_email='ptvshelp@microsoft.com',
author_email='Akshat.Harit@microsoft.com',
license='MIT License',
keywords='azure',
classifiers=[
@@ -36,6 +36,10 @@ setup(name='azure-datalake-store',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'License :: OSI Approved :: MIT License',
],
packages=find_packages(exclude=['tests',
@@ -44,7 +48,7 @@ setup(name='azure-datalake-store',
]),
install_requires=[
'cffi',
'adal>=0.4.2',
'msal>=1.16.0,<2', # http_cache was introduced in MSAL 1.16.0
'requests>=2.20.0',
],
extras_require={

Просмотреть файл

@@ -15,7 +15,7 @@ Example: How will the logic behave in case of specific error from server side.
This was introduced to test the Retry Policy but can be carefully used for other tests as well.
"""
import pytest
import responses
from requests import ConnectionError, ConnectTimeout, ReadTimeout, Timeout, HTTPError
@@ -110,54 +110,3 @@ def __test_retry_error(azure,
assert is_exception_expected
@responses.activate
def __test_retry_auth(error_code, error_string, is_exception_expected, total_tries=4, last_try_status=200,
last_try_body=None):
import re, adal
end_point_discovery = re.compile("https:\/\/login\.microsoftonline\.com\/common\/discovery\/"
"instance\?authorization_endpoint=.+")
mock_url_auth = "https://login.microsoftonline.com/" + settings.TENANT_ID + "/oauth2/token"
body_discovery = r'{"tenant_discovery_endpoint":"https://login.microsoftonline.com/' + TENANT_ID + \
'/.well-known/openid-configuration"}'
body_error = r'{"error":"' + error_string + r'","error_description":"0","error_codes":[0],"timestamp":"0",' \
r'"trace_id":"0","correlation_id":"0"}'
if last_try_body is None:
last_try_body = r'{"token_type":"Bearer","expires_in":"1","ext_expires_in":"1","expires_on":"1",' \
r'"not_before":"1","resource":"https://datalake.azure.net/","access_token":"a"}'
while total_tries > 0:
responses.add(responses.GET, end_point_discovery, body=body_discovery, status=200)
responses.add(responses.POST, mock_url_auth, body=body_error, status=error_code)
total_tries -= 1
responses.add(responses.GET, end_point_discovery, body=body_discovery, status=200)
responses.add(responses.POST, mock_url_auth, body=last_try_body, status=last_try_status)
try:
token = auth(tenant_id=TENANT_ID, client_secret='GARBAGE', client_id=CLIENT_ID)
assert isinstance(token, DataLakeCredential)
assert not is_exception_expected
except (HTTPError, adal.adal_error.AdalError):
assert is_exception_expected
def test_retry_auth_401():
__test_retry_auth(error_code=401, error_string=r'invalid_client', total_tries=1, is_exception_expected=True)
def test_retry_auth_400():
__test_retry_auth(error_code=400, error_string=r'invalid_client', total_tries=1, is_exception_expected=False)
def test_retry_auth_104():
__test_retry_auth(error_code=104, error_string=r'Connection Error', total_tries=1, is_exception_expected=False )
__test_retry_auth(error_code=104, error_string=r'Connection Error', is_exception_expected=True, total_tries=6)
def test_retry_auth_429():
__test_retry_auth(error_code=429, error_string=r'Too many requests', total_tries=2, is_exception_expected=False)
__test_retry_auth(error_code=429, error_string=r'Too many requests', is_exception_expected=True, total_tries=6)
def test_retry_auth_501():
__test_retry_auth(error_code=501, error_string=r'invalid_client', total_tries=1, is_exception_expected=False)

Просмотреть файл

@@ -42,7 +42,9 @@ def linecount(infile):
@contextmanager
def setup_tree(azure):
def setup_tree(azure, test_dir=None):
if test_dir == None:
test_dir = working_dir()
for directory in ['', 'data/a', 'data/b']:
azure.mkdir(test_dir / directory)
for filename in ['x.csv', 'y.csv', 'z.txt']:
@@ -55,9 +57,7 @@ def setup_tree(azure):
try:
yield
finally:
for path in azure.ls(test_dir, invalidate_cache=False):
if azure.exists(path, invalidate_cache=False):
azure.rm(path, recursive=True)
azure.rm(test_dir, recursive=True)
def create_remote_csv(fs, name, columns, colwidth, lines):
@@ -246,7 +246,7 @@ def test_download_overwrite(tempdir, azure):
with pytest.raises(OSError) as e:
ADLDownloader(azure, test_dir, tempdir, 1, 2**24, run=False)
assert tempdir in str(e)
assert os.path.split(tempdir)[1] in str(e)
@my_vcr.use_cassette
@@ -506,15 +506,13 @@ def test_set_acl_recusrive(azure):
def check_acl_perms(path, permission):
current_acl = azure.get_acl_status(path)
acl_user_entry = [s for s in current_acl['entries'] if acluser in s]
assert len(acl_user_entry) == 1
assert len(acl_user_entry) == 1, "Path: "+path + " Acls: " + str(acl_user_entry)
assert acl_user_entry[0].split(':')[-1] == permission
files = list(azure.walk(test_dir))
directories = list(set([x[0] for x in map(os.path.split, files)]))
permission = "rwx"
azure.set_acl(test_dir, acl_spec=set_acl_base + "user:"+acluser+":"+permission, recursive=True, number_of_sub_process=2)
for path in files+directories:
check_acl_perms(path, permission)

Просмотреть файл

@@ -113,11 +113,13 @@ def create_files(azure, number_of_files, root_path = working_dir(), prefix=''):
@contextmanager
def azure_teardown(fs):
try:
fs.mkdir(working_dir())
yield
finally:
# this is a best effort. If there is an error attempting to delete during cleanup,
# print it, but it should not cause the test to fail.
try:
fs.rm(working_dir(), recursive=True)
for path in fs.ls(working_dir(), invalidate_cache=False):
if fs.exists(path, invalidate_cache=False):
fs.rm(path, recursive=True)