Resolving issues from PR review.
This commit is contained in:
Родитель
e7d4201977
Коммит
d1b10dab62
110
auth_helper.py
110
auth_helper.py
|
@ -1,30 +1,74 @@
|
|||
import os
|
||||
import pathlib
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
import json5
|
||||
import sys
|
||||
import requests
|
||||
import msal
|
||||
import logging
|
||||
|
||||
|
||||
class AuthError(RuntimeError):
|
||||
|
||||
"""Handle Auth Errors when attempting device_flow or recovery from cached token.
|
||||
:param result: The result of the failed attempt
|
||||
"""
|
||||
|
||||
result: Union[dict[str, Optional[int]], dict, dict[str, str], dict[str, Union[int, dict[str, str], str]], None]
|
||||
|
||||
def __init__(self, result: Union[
|
||||
dict[str, Optional[int]],
|
||||
dict,
|
||||
dict[str, str],
|
||||
dict[str, Union[int, dict[str, str], str]],
|
||||
None
|
||||
]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.result = result
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Generate and return the string representation of the object.
|
||||
:return: A string representation of the object
|
||||
"""
|
||||
if not self.result.get("error_message"):
|
||||
return f"AuthError: Fatal error in authentication - {self.result.__str__()}"
|
||||
else:
|
||||
return f"AuthError: Fatal error in authentication - {self.result.get('error_message')}"
|
||||
|
||||
|
||||
class AuthHelper:
|
||||
tenantId = os.environ["tenant_id"]
|
||||
scope = os.environ["scope"]
|
||||
authority = os.environ["authority"]
|
||||
appId = os.environ["appId"]
|
||||
accessToken = ""
|
||||
def __init__(self,**kwargs):
|
||||
for key,value in kwargs.items():
|
||||
setattr(self,key,value)
|
||||
def __init__(self, scope: Optional[str], app_id: Optional[str], log: Optional[logging.Logger]):
|
||||
if not scope:
|
||||
self.scope = os.environ["scope"]
|
||||
else:
|
||||
self.scope = scope
|
||||
if not app_id:
|
||||
self.app_id = os.environ["appId"]
|
||||
else:
|
||||
self.app_id = app_id
|
||||
if not log:
|
||||
self.log = logging.getLogger("ado")
|
||||
self.log.getChild("AuthHelper")
|
||||
else:
|
||||
self.log = log
|
||||
self.log.getChild("AuthHelper")
|
||||
|
||||
def adoAuthenticate(self):
|
||||
def device_flow_auth(self) -> \
|
||||
Union[
|
||||
dict[str, Optional[int]],
|
||||
dict,
|
||||
dict[str, str],
|
||||
dict[str, Union[int, dict[str, str], str]],
|
||||
None
|
||||
]:
|
||||
mscache = msal.SerializableTokenCache()
|
||||
output = pathlib.Path(__file__).parent / pathlib.Path("token.bin")
|
||||
output = pathlib.Path.home() / pathlib.Path("tokens") / pathlib.Path("token.bin")
|
||||
if os.path.exists(output):
|
||||
print("Deserializing cached credentials.")
|
||||
mscache.deserialize(open(output, "r").read())
|
||||
self.log.debug("Deserializing cached credentials.")
|
||||
mscache.deserialize(open(output, "r", encoding="utf-8").read())
|
||||
|
||||
app = msal.PublicClientApplication(
|
||||
client_id=self.appId,
|
||||
client_id=self.app_id,
|
||||
token_cache=mscache
|
||||
)
|
||||
accounts = app.get_accounts()
|
||||
|
@ -34,35 +78,37 @@ class AuthHelper:
|
|||
account=accounts[0]
|
||||
)
|
||||
if mscache.has_state_changed:
|
||||
with open(output, "w") as cache_file:
|
||||
print("Caching credentials.")
|
||||
with open(output, "w", encoding='utf-8') as cache_file:
|
||||
self.log.debug("Caching credentials.")
|
||||
cache_file.write(mscache.serialize())
|
||||
cache_file.close()
|
||||
if result is not None:
|
||||
if "access_token" in result:
|
||||
print("Found access token.")
|
||||
self.log.debug("Found access token.")
|
||||
return result
|
||||
else:
|
||||
raise RuntimeError(result.get("error_description"))
|
||||
print("Initiating device flow.")
|
||||
raise AuthError(result=result)
|
||||
self.log.debug("Initiating device flow.")
|
||||
flow = app.initiate_device_flow(scopes=[self.scope])
|
||||
if "user_code" not in flow:
|
||||
raise ValueError("User code not if result. Error: %s" % json5.dumps(flow, indent=5))
|
||||
print(flow["message"])
|
||||
raise ValueError("User code not in result. Error: %s" % json.dumps(flow, indent=5))
|
||||
print(flow["message"]) # Think this one needs to stay as a print so user sees the prompt
|
||||
result = app.acquire_token_by_device_flow(flow)
|
||||
if "access_token" in result:
|
||||
print("Access token acquired.")
|
||||
if not os.path.exists(output):
|
||||
with open(output, "w") as cache_file:
|
||||
print("Writing cached credentials.")
|
||||
self.log.debug("Access token acquired.")
|
||||
if not pathlib.Path.is_dir(output.parent):
|
||||
pathlib.Path.mkdir(output.parent)
|
||||
if not pathlib.Path.exists(output):
|
||||
with open(output, "w", encoding='utf-8') as cache_file:
|
||||
self.log.debug("Writing cached credentials.")
|
||||
cache_file.write(mscache.serialize())
|
||||
cache_file.close()
|
||||
else:
|
||||
with open(output,"w+") as cache_file:
|
||||
print("Writing cached credentials.")
|
||||
with open(output, "w+", encoding='utf-8') as cache_file:
|
||||
self.log.debug("Writing cached credentials.")
|
||||
cache_file.write(mscache.serialize())
|
||||
cache_file.close()
|
||||
return result
|
||||
else:
|
||||
print("No access token found.")
|
||||
raise RuntimeError(result.get("error_description"))
|
||||
return ""
|
||||
self.log.debug("No access token found.")
|
||||
raise AuthError(result)
|
||||
|
|
|
@ -1,58 +1,58 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
import sys
|
||||
import simple_ado
|
||||
import argparse
|
||||
import logging
|
||||
from auth_helper import AuthHelper
|
||||
from auth_helper import AuthHelper, AuthError
|
||||
from simple_ado import ADOException
|
||||
from simple_ado.exceptions import ADOHTTPException
|
||||
|
||||
def main():
|
||||
logger = logging.getLogger("test")
|
||||
logging.basicConfig(level=logging.DEBUG,handlers=[logging.StreamHandler(sys.stdout)])
|
||||
logger = logging.getLogger("ado.device_flow_test")
|
||||
app_id = os.environ["appId"]
|
||||
scope = os.environ["scope"]
|
||||
project_id = os.environ["SIMPLE_ADO_PROJECT_ID"]
|
||||
repo_id = os.environ["SIMPLE_ADO_REPO_ID"]
|
||||
token = os.environ["SIMPLE_ADO_BASE_TOKEN"]
|
||||
username = os.environ["SIMPLE_ADO_USERNAME"]
|
||||
tenant = os.environ["SIMPLE_ADO_TENANT"]
|
||||
output = pathlib.Path(__file__).parent / pathlib.Path(outputDir) / pathlib.Path(repo_id + ".zip")
|
||||
output = pathlib.Path.home() / pathlib.Path(outputDir) / pathlib.Path(repo_id + ".zip")
|
||||
if not pathlib.Path.is_dir(pathlib.Path(output).parent):
|
||||
pathlib.Path.mkdir(pathlib.Path(output).parent)
|
||||
print("* Fetching the repo: " + repoUrlStr)
|
||||
logger.debug("* Fetching the repo: " + repoUrlStr)
|
||||
|
||||
try:
|
||||
ah = AuthHelper()
|
||||
token = ah.adoAuthenticate()
|
||||
ah = AuthHelper(scope=scope,app_id=app_id,log=logger)
|
||||
token = ah.device_flow_auth()
|
||||
|
||||
print("** Setting up ADOHTTPClient with " + tenant)
|
||||
logger.debug("** Setting up ADOHTTPClient with " + tenant)
|
||||
http = simple_ado.http_client.ADOHTTPClient(tenant=tenant,
|
||||
credentials=token,
|
||||
user_agent="test",
|
||||
log = logger
|
||||
)
|
||||
git_client = simple_ado.git.ADOGitClient(http_client=http, log=logger)
|
||||
print("** Getting Repository: " + repo_id + " from " + project_id)
|
||||
logger.debug("** Getting Repository: " + repo_id + " from " + project_id)
|
||||
repo = git_client.get_repository(repository_id=repo_id, project_id=project_id)
|
||||
branch = repo["defaultBranch"]
|
||||
branch = branch.split("/")[-1]
|
||||
#callback=None
|
||||
#if progress:
|
||||
# callback=handle_progress
|
||||
#zip = git_client.download_zip(output_path=output, repository_id=repo["id"], branch=branch, project_id=project_id,callback=callback)
|
||||
zip = git_client.download_zip(output_path=output, repository_id=repo["id"], branch=branch, project_id=project_id)
|
||||
logger.debug("Completed.")
|
||||
except ADOHTTPException as e:
|
||||
print("ADOHTTPException " + str(e.response.status_code) + " on: ")
|
||||
print("e.message = " + e.message)
|
||||
print("e.response = " + str(e.response.content))
|
||||
print("e.request.url = " + e.response.request.url + " path: " + e.response.request.path_url)
|
||||
logger.critical("ADOHTTPException " + str(e.response.status_code) + " on: ")
|
||||
logger.critical("e.message = " + e.message)
|
||||
logger.critical("e.response = " + str(e.response.content))
|
||||
logger.critical("e.request.url = " + e.response.request.url + " path: " + e.response.request.path_url)
|
||||
if e.response.request.body:
|
||||
print("e.request.body = " + e.response.request.body)
|
||||
logger.critical("e.request.body = " + e.response.request.body)
|
||||
except ADOException as e:
|
||||
if "The output path already exists" in str(e):
|
||||
print("Skipping for " + repoUrlStr + " it already exists.")
|
||||
logger.debug("Skipping for " + repoUrlStr + " it already exists.")
|
||||
pass
|
||||
else:
|
||||
print("ADOException " + str(e) + " on: ")
|
||||
logger.critical("ADOException " + str(e) + " on: ")
|
||||
except AuthError as e:
|
||||
logger.debug(str(e))
|
||||
|
||||
if __name__=="__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
import enum
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
import urllib.parse
|
||||
|
||||
from simple_ado.base_client import ADOBaseClient
|
||||
|
|
|
@ -17,7 +17,6 @@ from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_
|
|||
from simple_ado.exceptions import ADOException, ADOHTTPException
|
||||
from simple_ado.models import PatchOperation
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
ADOThread = Dict[str, Any]
|
||||
ADOResponse = Any
|
||||
|
@ -59,7 +58,7 @@ class ADOHTTPClient:
|
|||
log: logging.Logger
|
||||
tenant: str
|
||||
extra_headers: Dict[str, str]
|
||||
credentials: Tuple[str, str]
|
||||
credentials = None # Modify away from tuple to expand for device_flow
|
||||
_not_before: Optional[datetime.datetime]
|
||||
_session: requests.Session
|
||||
|
||||
|
@ -67,7 +66,7 @@ class ADOHTTPClient:
|
|||
self,
|
||||
*,
|
||||
tenant: str,
|
||||
credentials: Tuple[str, str],
|
||||
credentials: None, # Modify away from tuple to support device_flow
|
||||
user_agent: str,
|
||||
log: logging.Logger,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
|
@ -78,6 +77,7 @@ class ADOHTTPClient:
|
|||
|
||||
self.tenant = tenant
|
||||
self.credentials = credentials
|
||||
|
||||
self._not_before = None
|
||||
|
||||
self._session = requests.Session()
|
||||
|
@ -200,12 +200,29 @@ class ADOHTTPClient:
|
|||
:returns: The raw response object from the API
|
||||
"""
|
||||
self._wait()
|
||||
|
||||
headers = self.construct_headers(additional_headers=additional_headers)
|
||||
response = self._session.get(
|
||||
request_url, auth=self.credentials, headers=headers, stream=stream
|
||||
)
|
||||
|
||||
response = None
|
||||
# Modified to support PAT-based Authentication
|
||||
if isinstance(self.credentials, tuple) and len(self.credentials) == 2:
|
||||
headers = self.construct_headers(additional_headers=additional_headers)
|
||||
response = self._session.get(
|
||||
request_url, auth=self.credentials, headers=headers, stream=stream
|
||||
)
|
||||
# Since requests does not support BearerAuth, this is a hacky way to do it.
|
||||
elif isinstance(self.credentials, dict) and len(self.credentials) == 3:
|
||||
if additional_headers:
|
||||
additional_headers['Authorization'] = 'Bearer ' + self.credentials["access_token"]
|
||||
else:
|
||||
additional_headers = {"Authorization":"Bearer " + self.credentials["access_token"]}
|
||||
headers = self.construct_headers(additional_headers=additional_headers)
|
||||
response = self._session.get(
|
||||
request_url, auth=None, headers=headers, stream=stream
|
||||
)
|
||||
else:
|
||||
self.log.critical("len(self.credentials) == " + str(len(self.credentials)))
|
||||
self.log.critical("type(self.credentials) == " + str(type(self.credentials)))
|
||||
raise ValueError("Unknown authentication type. Modify simple_ado/http_client.py to support.")
|
||||
if not response:
|
||||
raise ADOHTTPException("Unable to get a response from authentication phase. See simple_ado/http_client.py.")
|
||||
self._track_rate_limit(response)
|
||||
|
||||
if response.status_code == 429:
|
||||
|
@ -448,4 +465,4 @@ class ADOHTTPClient:
|
|||
for header_name, header_value in additional_headers.items():
|
||||
headers[header_name] = header_value
|
||||
|
||||
return headers
|
||||
return headers
|
Загрузка…
Ссылка в новой задаче