Add MSAL support (#330)
* Use MSAL library support instead of ADAL. --------- Co-authored-by: Ray Luo <rayluo.mba@gmail.com>
This commit is contained in:
Parent
7f6e5ea7b1
Commit
b44a68d213
|
@ -3,8 +3,9 @@
|
|||
Release History
|
||||
===============
|
||||
|
||||
0.0.53 (2022-10-26)
|
||||
0.0.53 (2023-04-11)
|
||||
+++++++++++++++++++
|
||||
* Add MSAL support. Remove ADAL support.
|
||||
* Suppress deprecation warning when detecting pyopenssl existence.
|
||||
|
||||
0.0.52 (2020-11-25)
|
||||
|
|
|
@ -23,14 +23,10 @@ Manually (bleeding edge):
|
|||
|
||||
* Download the repo from [https://github.com/Azure/azure-data-lake-store-python](https://github.com/Azure/azure-data-lake-store-python)
|
||||
|
||||
* checkout the `dev` branch
|
||||
|
||||
* install the requirements (`pip install -r dev_requirements.txt`)
|
||||
|
||||
* install in develop mode (`python setup.py develop`)
|
||||
|
||||
* optionally: build the documentation (including this page) by running `make html` in the docs directory.
|
||||
|
||||
## Auth
|
||||
|
||||
Although users can generate and supply their own tokens to the base file-system
|
||||
|
|
|
@ -27,13 +27,14 @@ if sys.version_info >= (3, 4):
|
|||
else:
|
||||
import urllib
|
||||
|
||||
from .retry import ExponentialRetryPolicy, retry_decorator_for_auth
|
||||
from .retry import ExponentialRetryPolicy
|
||||
|
||||
# 3rd party imports
|
||||
import adal
|
||||
import msal
|
||||
import requests
|
||||
import requests.exceptions
|
||||
|
||||
_http_cache = {} # Useful for MSAL. https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.params.http_cache
|
||||
|
||||
# this is required due to github issue, to ensure we don't lose perf from openPySSL: https://github.com/pyca/pyopenssl/issues/625
|
||||
def enforce_no_py_open_ssl():
|
||||
|
@ -118,13 +119,8 @@ def auth(tenant_id=None, username=None,
|
|||
if not authority:
|
||||
authority = 'https://login.microsoftonline.com/'
|
||||
|
||||
|
||||
if not tenant_id:
|
||||
tenant_id = os.environ.get('azure_tenant_id', "common")
|
||||
|
||||
context = adal.AuthenticationContext(authority +
|
||||
tenant_id)
|
||||
|
||||
if tenant_id is None or client_id is None:
|
||||
raise ValueError("tenant_id and client_id must be supplied for authentication")
|
||||
|
||||
|
@ -136,60 +132,59 @@ def auth(tenant_id=None, username=None,
|
|||
|
||||
if not client_secret:
|
||||
client_secret = os.environ.get('azure_client_secret', None)
|
||||
|
||||
# You can explicitly authenticate with 2fa, or pass in nothing to the auth call
|
||||
# and the user will be prompted to login interactively through a browser.
|
||||
|
||||
@retry_decorator_for_auth(retry_policy=retry_policy)
|
||||
|
||||
scopes = kwargs.get('scopes', ["https://datalake.azure.net/.default"])
|
||||
def get_token_internal():
|
||||
# Internal function used so as to use retry decorator
|
||||
if require_2fa or (username is None and password is None and client_secret is None):
|
||||
code = context.acquire_user_code(resource, client_id)
|
||||
print(code['message'])
|
||||
out = context.acquire_token_with_device_code(resource, code, client_id)
|
||||
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
|
||||
flow = contextPub.initiate_device_flow(scopes=scopes)
|
||||
print(flow['message'])
|
||||
out = contextPub.acquire_token_by_device_flow(flow)
|
||||
elif username and password:
|
||||
out = context.acquire_token_with_username_password(resource, username,
|
||||
password, client_id)
|
||||
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
|
||||
out = contextPub.acquire_token_by_username_password(username=username, password=password, scopes=scopes)
|
||||
elif client_id and client_secret:
|
||||
out = context.acquire_token_with_client_credentials(resource, client_id,
|
||||
client_secret)
|
||||
contextClient = msal.ConfidentialClientApplication(client_id=client_id, authority=authority+tenant_id, client_credential=client_secret, http_cache=_http_cache)
|
||||
out = contextClient.acquire_token_for_client(scopes=scopes)
|
||||
# for service principal, we store the secret in the credential object for use when refreshing.
|
||||
out.update({'secret': client_secret})
|
||||
else:
|
||||
raise ValueError("No authentication method found for credentials")
|
||||
return out
|
||||
|
||||
out = get_token_internal()
|
||||
|
||||
out.update({'access': out['accessToken'], 'resource': resource,
|
||||
'refresh': out.get('refreshToken', False),
|
||||
'time': time.time(), 'tenant': tenant_id, 'client': client_id})
|
||||
if 'error' in out:
|
||||
msg = "MSAL Error: "+out.get('error_description', "")
|
||||
err = DatalakeRESTException(msg)
|
||||
logger.log(logging.ERROR, msg)
|
||||
raise err
|
||||
|
||||
out.update({'access_token': out['access_token'], 'access': out['access_token'], 'resource': resource,
|
||||
'refresh': out.get('refresh_token', False),
|
||||
'time': time.time(), 'tenant': tenant_id, 'client': client_id, 'scopes':scopes})
|
||||
|
||||
return DataLakeCredential(out)
|
||||
|
||||
|
||||
class DataLakeCredential:
|
||||
# Be careful modifying this. DataLakeCredential is a general class in azure, and we have to maintain parity.
|
||||
def __init__(self, token):
|
||||
self.token = token
|
||||
|
||||
def signed_session(self):
|
||||
# type: () -> requests.Session
|
||||
"""Create requests session with any required auth headers applied.
|
||||
|
||||
:rtype: requests.Session
|
||||
"""
|
||||
session = requests.Session()
|
||||
if time.time() - self.token['time'] > self.token['expiresIn'] - 100:
|
||||
if time.time() - self.token['time'] > self.token['expires_in'] - 100:
|
||||
self.refresh_token()
|
||||
|
||||
scheme, token = self.token['tokenType'], self.token['access']
|
||||
session = requests.Session()
|
||||
scheme, token = self.token['token_type'], self.token['access_token']
|
||||
header = "{} {}".format(scheme, token)
|
||||
session.headers['Authorization'] = header
|
||||
return session
|
||||
|
||||
|
||||
def refresh_token(self, authority=None):
|
||||
""" Refresh an expired authorization token
|
||||
|
||||
Parameters
|
||||
----------
|
||||
authority: string
|
||||
|
@ -201,25 +196,31 @@ class DataLakeCredential:
|
|||
if not authority:
|
||||
authority = 'https://login.microsoftonline.com/'
|
||||
|
||||
context = adal.AuthenticationContext(authority +
|
||||
self.token['tenant'])
|
||||
|
||||
tenant_id = self.token['tenant']
|
||||
scopes = self.token['scopes']
|
||||
if self.token.get('secret') and self.token.get('client'):
|
||||
out = context.acquire_token_with_client_credentials(self.token['resource'],
|
||||
self.token['client'],
|
||||
self.token['secret'])
|
||||
out.update({'secret': self.token['secret']})
|
||||
client_id = self.token['client']
|
||||
client_secret = self.token['secret']
|
||||
contextClient = msal.ConfidentialClientApplication(client_id=client_id, authority=authority+tenant_id, client_credential=client_secret, http_cache=_http_cache)
|
||||
out = contextClient.acquire_token_for_client(scopes=scopes)
|
||||
out.update({'secret': client_secret})
|
||||
else:
|
||||
out = context.acquire_token_with_refresh_token(self.token['refresh'],
|
||||
client_id=self.token['client'],
|
||||
resource=self.token['resource'])
|
||||
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
|
||||
out = contextPub.client.obtain_token_by_refresh_token(self.token['refresh'], scopes=scopes)
|
||||
|
||||
if 'error' in out:
|
||||
msg = "MSAL Error: "+out.get('error_description', "")
|
||||
err = DatalakeRESTException(msg)
|
||||
logger.log(logging.ERROR, msg)
|
||||
raise err
|
||||
# common items to update
|
||||
out.update({'access': out['accessToken'],
|
||||
out.update({'access_token': out['access_token'], 'access': out['access_token'],
|
||||
'time': time.time(), 'tenant': self.token['tenant'],
|
||||
'resource': self.token['resource'], 'client': self.token['client']})
|
||||
'resource': self.token['resource'], 'client': self.token['client'], 'scopes':self.token['scopes']})
|
||||
|
||||
self.token = out
|
||||
|
||||
|
||||
class DatalakeRESTInterface:
|
||||
""" Call factory for webHDFS endpoints on ADLS
|
||||
|
||||
|
@ -228,7 +229,7 @@ class DatalakeRESTInterface:
|
|||
store_name: str
|
||||
The name of the Data Lake Store account to execute operations against.
|
||||
token: dict
|
||||
from `auth()` or `refresh_token()` or other ADAL source
|
||||
from `auth()` or `refresh_token()` or other MSAL source
|
||||
url_suffix: str (None)
|
||||
Domain to send REST requests to. The end-point URL is constructed
|
||||
using this and the store_name. If None, use default.
|
||||
|
@ -309,13 +310,10 @@ class DatalakeRESTInterface:
|
|||
return s
|
||||
|
||||
def _check_token(self, retry_policy= None):
|
||||
@retry_decorator_for_auth(retry_policy=retry_policy)
|
||||
def check_token_internal():
|
||||
cur_session = self.token.signed_session()
|
||||
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
|
||||
self.head = {'Authorization': cur_session.headers['Authorization']}
|
||||
self.local.session = None
|
||||
check_token_internal()
|
||||
cur_session = self.token.signed_session()
|
||||
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
|
||||
self.head = {'Authorization': cur_session.headers['Authorization']}
|
||||
self.local.session = None
|
||||
|
||||
def _log_request(self, method, url, op, path, params, headers, retry_count):
|
||||
msg = u"HTTP Request\n{} {}\n".format(method.upper(), url)
|
||||
|
@ -498,7 +496,7 @@ class DatalakeRESTInterface:
|
|||
|
||||
"""
|
||||
Not yet implemented (or not applicable)
|
||||
http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/WebHDFS.html
|
||||
https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/WebHDFS.html
|
||||
|
||||
GETFILECHECKSUM
|
||||
GETHOMEDIRECTORY
|
||||
|
|
|
@ -22,7 +22,7 @@ GLOBAL_EXCEPTION_LOCK = threading.Lock()
|
|||
|
||||
def monitor_exception(exception_queue, process_ids):
|
||||
global GLOBAL_EXCEPTION
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("azure.datalake.store")
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
@ -53,8 +53,8 @@ def log_listener_process(queue):
|
|||
queue.task_done()
|
||||
if record == END_QUEUE_SENTINEL: # We send this as a sentinel to tell the listener to quit.
|
||||
break
|
||||
logger = logging.getLogger(record.name)
|
||||
logger.handlers.clear()
|
||||
logger = logging.getLogger("azure.datalake.store")
|
||||
#logger.handlers.clear()
|
||||
logger.handle(record) # No level or filter logic applied - just do it!
|
||||
except Empty: # Try again
|
||||
pass
|
||||
|
@ -65,7 +65,7 @@ def log_listener_process(queue):
|
|||
|
||||
|
||||
def multi_processor_change_acl(adl, path=None, method_name="", acl_spec="", number_of_sub_process=None):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("azure.datalake.store")
|
||||
|
||||
def launch_processes(number_of_processes):
|
||||
if number_of_processes is None:
|
||||
|
@ -152,8 +152,8 @@ def multi_processor_change_acl(adl, path=None, method_name="", acl_spec="", numb
|
|||
|
||||
|
||||
def processor(adl, file_path_queue, finish_queue_processing_flag, method_name, acl_spec, log_queue, exception_queue):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger("azure.datalake.store")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
removed_default_acl_spec = ",".join([x for x in acl_spec.split(',') if not x.lower().startswith("default")])
|
||||
|
||||
try:
|
||||
|
|
|
@ -74,55 +74,3 @@ class ExponentialRetryPolicy(RetryPolicy):
|
|||
def __backoff(self):
|
||||
time.sleep(self.exponential_retry_interval)
|
||||
self.exponential_retry_interval *= self.exponential_factor
|
||||
|
||||
|
||||
def retry_decorator_for_auth(retry_policy = None):
    """Decorator factory that retries a token-acquisition callable.

    Parameters
    ----------
    retry_policy: RetryPolicy, optional
        Policy consulted via ``should_retry(response, exception, count)``.
        Defaults to ``ExponentialRetryPolicy(max_retries=2)``.

    The wrapped function is re-invoked until it succeeds, the failure is a
    401 (invalid credentials — retrying cannot help), or the policy declines
    further retries. The last exception is re-raised when retries run out.
    """
    import adal
    from requests import HTTPError
    if retry_policy is None:
        retry_policy = ExponentialRetryPolicy(max_retries=2)

    def deco_retry(func):
        @wraps(func)
        def f_retry(*args, **kwargs):
            retry_count = -1
            while True:
                last_exception = None
                # Reset per attempt. Without this, an exception carrying
                # neither 'error_response' nor 'response' raised NameError
                # below, and a response from a previous iteration could leak
                # into this attempt's retry decision.
                response = None
                retry_count += 1
                try:
                    out = func(*args, **kwargs)
                except (adal.adal_error.AdalError, HTTPError) as e:
                    # ADAL error corresponds to everything but 429, which bubbles up HTTP error.
                    last_exception = e
                    logger.exception("Retry count " + str(retry_count) + "Exception :" + str(last_exception))
                    # We don't want to stop retry for any error in parsing the exception. This is a GET operation.
                    try:
                        if hasattr(last_exception, 'error_response'):  # ADAL exception
                            response = response_from_adal_exception(last_exception)
                        if hasattr(last_exception, 'response'):  # HTTP exception i.e 429
                            response = last_exception.response
                    except Exception:
                        # Best-effort parse only; never let diagnostics mask
                        # the original failure.
                        pass
                request_successful = last_exception is None or (response is not None and response.status_code == 401)  # 401 = Invalid credentials
                if request_successful or not retry_policy.should_retry(response, last_exception, retry_count):
                    break
            if last_exception is not None:
                raise last_exception
            return out
        return f_retry

    return deco_retry
|
||||
|
||||
|
||||
def response_from_adal_exception(e):
    """Build a minimal response-like object from an ADAL exception.

    Scans ``str(e)`` for an embedded ``"http error: <code>"`` marker. When
    present, returns a namedtuple exposing a single ``status_code`` field so
    callers (e.g. the retry policy) can treat it like an HTTP response.
    Returns None when the message carries no HTTP status code.
    """
    import re
    from collections import namedtuple

    match = re.search(r"http error: (\d+)", str(e))
    if match is None:
        return None

    # Construct a stand-in response object carrying only the status code.
    Response = namedtuple("Response", ['status_code'])
    return Response(status_code=int(match.group(1)))
|
||||
|
||||
|
|
|
@ -179,7 +179,7 @@ class CountUpDownLatch:
|
|||
self.lock.acquire()
|
||||
self.val -= 1
|
||||
if self.val <= 0:
|
||||
self.lock.notifyAll()
|
||||
self.lock.notify_all()
|
||||
self.lock.release()
|
||||
|
||||
def total_processed(self):
|
||||
|
|
8
setup.py
8
setup.py
|
@ -22,7 +22,7 @@ setup(name='azure-datalake-store',
|
|||
description='Azure Data Lake Store Filesystem Client Library for Python',
|
||||
url='https://github.com/Azure/azure-data-lake-store-python',
|
||||
author='Microsoft Corporation',
|
||||
author_email='ptvshelp@microsoft.com',
|
||||
author_email='Akshat.Harit@microsoft.com',
|
||||
license='MIT License',
|
||||
keywords='azure',
|
||||
classifiers=[
|
||||
|
@ -36,6 +36,10 @@ setup(name='azure-datalake-store',
|
|||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Programming Language :: Python :: 3.11',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
],
|
||||
packages=find_packages(exclude=['tests',
|
||||
|
@ -44,7 +48,7 @@ setup(name='azure-datalake-store',
|
|||
]),
|
||||
install_requires=[
|
||||
'cffi',
|
||||
'adal>=0.4.2',
|
||||
'msal>=1.16.0,<2', # http_cache was introduced in MSAL 1.16.0
|
||||
'requests>=2.20.0',
|
||||
],
|
||||
extras_require={
|
||||
|
|
|
@ -15,7 +15,7 @@ Example: How will the logic behave in case of specific error from server side.
|
|||
|
||||
This was introduced to test the Retry Policy but can be carefully used for other tests as well.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
from requests import ConnectionError, ConnectTimeout, ReadTimeout, Timeout, HTTPError
|
||||
|
||||
|
@ -110,54 +110,3 @@ def __test_retry_error(azure,
|
|||
assert is_exception_expected
|
||||
|
||||
|
||||
@responses.activate
def __test_retry_auth(error_code, error_string, is_exception_expected, total_tries=4, last_try_status=200,
                      last_try_body=None):
    """Exercise auth() retry behaviour against mocked AAD endpoints.

    Queues ``total_tries`` failing token responses (HTTP ``error_code`` with
    an OAuth error body built from ``error_string``), then one final response
    (``last_try_status`` / ``last_try_body``, defaulting to a valid token),
    and asserts that auth() raising or succeeding matches
    ``is_exception_expected``.
    """
    import re, adal

    discovery_pattern = re.compile("https:\/\/login\.microsoftonline\.com\/common\/discovery\/"
                                   "instance\?authorization_endpoint=.+")
    token_url = "https://login.microsoftonline.com/" + settings.TENANT_ID + "/oauth2/token"

    body_discovery = r'{"tenant_discovery_endpoint":"https://login.microsoftonline.com/' + TENANT_ID + \
                     '/.well-known/openid-configuration"}'
    body_error = r'{"error":"' + error_string + r'","error_description":"0","error_codes":[0],"timestamp":"0",' \
                 r'"trace_id":"0","correlation_id":"0"}'
    if last_try_body is None:
        last_try_body = r'{"token_type":"Bearer","expires_in":"1","ext_expires_in":"1","expires_on":"1",' \
                        r'"not_before":"1","resource":"https://datalake.azure.net/","access_token":"a"}'

    # Register the failing attempts first; the responses library replays
    # registered mocks in order.
    for _ in range(total_tries):
        responses.add(responses.GET, discovery_pattern, body=body_discovery, status=200)
        responses.add(responses.POST, token_url, body=body_error, status=error_code)

    # Final attempt: whatever outcome the caller configured.
    responses.add(responses.GET, discovery_pattern, body=body_discovery, status=200)
    responses.add(responses.POST, token_url, body=last_try_body, status=last_try_status)

    try:
        token = auth(tenant_id=TENANT_ID, client_secret='GARBAGE', client_id=CLIENT_ID)
        assert isinstance(token, DataLakeCredential)
        assert not is_exception_expected
    except (HTTPError, adal.adal_error.AdalError):
        assert is_exception_expected
|
||||
|
||||
|
||||
def test_retry_auth_401():
    """A 401 (invalid credentials) must surface as an error without retrying."""
    __test_retry_auth(error_code=401, error_string=r'invalid_client',
                      is_exception_expected=True, total_tries=1)
|
||||
|
||||
|
||||
def test_retry_auth_400():
    """A single 400 followed by success: the retry recovers, no exception."""
    __test_retry_auth(error_code=400, error_string=r'invalid_client',
                      is_exception_expected=False, total_tries=1)
|
||||
|
||||
|
||||
def test_retry_auth_104():
    """Connection errors (104) are retried: one failure recovers; six exhaust
    the retry budget and raise."""
    __test_retry_auth(error_code=104, error_string=r'Connection Error',
                      is_exception_expected=False, total_tries=1)
    __test_retry_auth(error_code=104, error_string=r'Connection Error',
                      is_exception_expected=True, total_tries=6)
|
||||
|
||||
|
||||
def test_retry_auth_429():
    """Throttling (429) is retried: two failures recover; six exhaust the
    retry budget and raise."""
    __test_retry_auth(error_code=429, error_string=r'Too many requests',
                      is_exception_expected=False, total_tries=2)
    __test_retry_auth(error_code=429, error_string=r'Too many requests',
                      is_exception_expected=True, total_tries=6)
|
||||
|
||||
|
||||
def test_retry_auth_501():
    """A single 501 followed by success: the retry recovers, no exception."""
    __test_retry_auth(error_code=501, error_string=r'invalid_client',
                      is_exception_expected=False, total_tries=1)
|
||||
|
|
|
@ -42,7 +42,9 @@ def linecount(infile):
|
|||
|
||||
|
||||
@contextmanager
|
||||
def setup_tree(azure):
|
||||
def setup_tree(azure, test_dir=None):
|
||||
if test_dir == None:
|
||||
test_dir = working_dir()
|
||||
for directory in ['', 'data/a', 'data/b']:
|
||||
azure.mkdir(test_dir / directory)
|
||||
for filename in ['x.csv', 'y.csv', 'z.txt']:
|
||||
|
@ -55,9 +57,7 @@ def setup_tree(azure):
|
|||
try:
|
||||
yield
|
||||
finally:
|
||||
for path in azure.ls(test_dir, invalidate_cache=False):
|
||||
if azure.exists(path, invalidate_cache=False):
|
||||
azure.rm(path, recursive=True)
|
||||
azure.rm(test_dir, recursive=True)
|
||||
|
||||
|
||||
def create_remote_csv(fs, name, columns, colwidth, lines):
|
||||
|
@ -246,7 +246,7 @@ def test_download_overwrite(tempdir, azure):
|
|||
|
||||
with pytest.raises(OSError) as e:
|
||||
ADLDownloader(azure, test_dir, tempdir, 1, 2**24, run=False)
|
||||
assert tempdir in str(e)
|
||||
assert os.path.split(tempdir)[1] in str(e)
|
||||
|
||||
|
||||
@my_vcr.use_cassette
|
||||
|
@ -506,15 +506,13 @@ def test_set_acl_recusrive(azure):
|
|||
def check_acl_perms(path, permission):
|
||||
current_acl = azure.get_acl_status(path)
|
||||
acl_user_entry = [s for s in current_acl['entries'] if acluser in s]
|
||||
assert len(acl_user_entry) == 1
|
||||
assert len(acl_user_entry) == 1, "Path: "+path + " Acls: " + str(acl_user_entry)
|
||||
assert acl_user_entry[0].split(':')[-1] == permission
|
||||
|
||||
files = list(azure.walk(test_dir))
|
||||
directories = list(set([x[0] for x in map(os.path.split, files)]))
|
||||
|
||||
permission = "rwx"
|
||||
azure.set_acl(test_dir, acl_spec=set_acl_base + "user:"+acluser+":"+permission, recursive=True, number_of_sub_process=2)
|
||||
|
||||
for path in files+directories:
|
||||
check_acl_perms(path, permission)
|
||||
|
||||
|
|
|
@ -113,11 +113,13 @@ def create_files(azure, number_of_files, root_path = working_dir(), prefix=''):
|
|||
@contextmanager
|
||||
def azure_teardown(fs):
|
||||
try:
|
||||
fs.mkdir(working_dir())
|
||||
yield
|
||||
finally:
|
||||
# this is a best effort. If there is an error attempting to delete during cleanup,
|
||||
# print it, but it should not cause the test to fail.
|
||||
try:
|
||||
fs.rm(working_dir(), recursive=True)
|
||||
for path in fs.ls(working_dir(), invalidate_cache=False):
|
||||
if fs.exists(path, invalidate_cache=False):
|
||||
fs.rm(path, recursive=True)
|
||||
|
|
Loading…
Reference in new issue