Resolving issues from PR review.

This commit is contained in:
Nate McFeters 2023-01-02 15:07:33 -06:00
Родитель e7d4201977
Коммит d1b10dab62
4 изменённых файлов: 128 добавлений и 65 удалений

Просмотреть файл

@ -1,30 +1,74 @@
import os
import pathlib
import json
from typing import Optional, Union
import json5
import sys
import requests
import msal
import logging
class AuthError(RuntimeError):
"""Handle Auth Errors when attempting device_flow or recovery from cached token.
:param result: The result of the failed attempt
"""
result: Union[dict[str, Optional[int]], dict, dict[str, str], dict[str, Union[int, dict[str, str], str]], None]
def __init__(self, result: Union[
dict[str, Optional[int]],
dict,
dict[str, str],
dict[str, Union[int, dict[str, str], str]],
None
]
) -> None:
super().__init__()
self.result = result
def __str__(self) -> str:
"""Generate and return the string representation of the object.
:return: A string representation of the object
"""
if not self.result.get("error_message"):
return f"AuthError: Fatal error in authentication - {self.result.__str__()}"
else:
return f"AuthError: Fatal error in authentication - {self.result.get('error_message')}"
class AuthHelper:
tenantId = os.environ["tenant_id"]
scope = os.environ["scope"]
authority = os.environ["authority"]
appId = os.environ["appId"]
accessToken = ""
def __init__(self,**kwargs):
for key,value in kwargs.items():
setattr(self,key,value)
def __init__(self, scope: Optional[str], app_id: Optional[str], log: Optional[logging.Logger]):
if not scope:
self.scope = os.environ["scope"]
else:
self.scope = scope
if not app_id:
self.app_id = os.environ["appId"]
else:
self.app_id = app_id
if not log:
self.log = logging.getLogger("ado")
self.log.getChild("AuthHelper")
else:
self.log = log
self.log.getChild("AuthHelper")
def adoAuthenticate(self):
def device_flow_auth(self) -> \
Union[
dict[str, Optional[int]],
dict,
dict[str, str],
dict[str, Union[int, dict[str, str], str]],
None
]:
mscache = msal.SerializableTokenCache()
output = pathlib.Path(__file__).parent / pathlib.Path("token.bin")
output = pathlib.Path.home() / pathlib.Path("tokens") / pathlib.Path("token.bin")
if os.path.exists(output):
print("Deserializing cached credentials.")
mscache.deserialize(open(output, "r").read())
self.log.debug("Deserializing cached credentials.")
mscache.deserialize(open(output, "r", encoding="utf-8").read())
app = msal.PublicClientApplication(
client_id=self.appId,
client_id=self.app_id,
token_cache=mscache
)
accounts = app.get_accounts()
@ -34,35 +78,37 @@ class AuthHelper:
account=accounts[0]
)
if mscache.has_state_changed:
with open(output, "w") as cache_file:
print("Caching credentials.")
with open(output, "w", encoding='utf-8') as cache_file:
self.log.debug("Caching credentials.")
cache_file.write(mscache.serialize())
cache_file.close()
if result is not None:
if "access_token" in result:
print("Found access token.")
self.log.debug("Found access token.")
return result
else:
raise RuntimeError(result.get("error_description"))
print("Initiating device flow.")
raise AuthError(result=result)
self.log.debug("Initiating device flow.")
flow = app.initiate_device_flow(scopes=[self.scope])
if "user_code" not in flow:
raise ValueError("User code not if result. Error: %s" % json5.dumps(flow, indent=5))
print(flow["message"])
raise ValueError("User code not in result. Error: %s" % json.dumps(flow, indent=5))
print(flow["message"]) # Think this one needs to stay as a print so user sees the prompt
result = app.acquire_token_by_device_flow(flow)
if "access_token" in result:
print("Access token acquired.")
if not os.path.exists(output):
with open(output, "w") as cache_file:
print("Writing cached credentials.")
self.log.debug("Access token acquired.")
if not pathlib.Path.is_dir(output.parent):
pathlib.Path.mkdir(output.parent)
if not pathlib.Path.exists(output):
with open(output, "w", encoding='utf-8') as cache_file:
self.log.debug("Writing cached credentials.")
cache_file.write(mscache.serialize())
cache_file.close()
else:
with open(output,"w+") as cache_file:
print("Writing cached credentials.")
with open(output, "w+", encoding='utf-8') as cache_file:
self.log.debug("Writing cached credentials.")
cache_file.write(mscache.serialize())
cache_file.close()
return result
else:
print("No access token found.")
raise RuntimeError(result.get("error_description"))
return ""
self.log.debug("No access token found.")
raise AuthError(result)

Просмотреть файл

@ -1,58 +1,58 @@
import os
import pathlib
import sys
import simple_ado
import argparse
import logging
from auth_helper import AuthHelper
from auth_helper import AuthHelper, AuthError
from simple_ado import ADOException
from simple_ado.exceptions import ADOHTTPException
def main():
logger = logging.getLogger("test")
logging.basicConfig(level=logging.DEBUG,handlers=[logging.StreamHandler(sys.stdout)])
logger = logging.getLogger("ado.device_flow_test")
app_id = os.environ["appId"]
scope = os.environ["scope"]
project_id = os.environ["SIMPLE_ADO_PROJECT_ID"]
repo_id = os.environ["SIMPLE_ADO_REPO_ID"]
token = os.environ["SIMPLE_ADO_BASE_TOKEN"]
username = os.environ["SIMPLE_ADO_USERNAME"]
tenant = os.environ["SIMPLE_ADO_TENANT"]
output = pathlib.Path(__file__).parent / pathlib.Path(outputDir) / pathlib.Path(repo_id + ".zip")
output = pathlib.Path.home() / pathlib.Path(outputDir) / pathlib.Path(repo_id + ".zip")
if not pathlib.Path.is_dir(pathlib.Path(output).parent):
pathlib.Path.mkdir(pathlib.Path(output).parent)
print("* Fetching the repo: " + repoUrlStr)
logger.debug("* Fetching the repo: " + repoUrlStr)
try:
ah = AuthHelper()
token = ah.adoAuthenticate()
ah = AuthHelper(scope=scope,app_id=app_id,log=logger)
token = ah.device_flow_auth()
print("** Setting up ADOHTTPClient with " + tenant)
logger.debug("** Setting up ADOHTTPClient with " + tenant)
http = simple_ado.http_client.ADOHTTPClient(tenant=tenant,
credentials=token,
user_agent="test",
log = logger
)
git_client = simple_ado.git.ADOGitClient(http_client=http, log=logger)
print("** Getting Repository: " + repo_id + " from " + project_id)
logger.debug("** Getting Repository: " + repo_id + " from " + project_id)
repo = git_client.get_repository(repository_id=repo_id, project_id=project_id)
branch = repo["defaultBranch"]
branch = branch.split("/")[-1]
#callback=None
#if progress:
# callback=handle_progress
#zip = git_client.download_zip(output_path=output, repository_id=repo["id"], branch=branch, project_id=project_id,callback=callback)
zip = git_client.download_zip(output_path=output, repository_id=repo["id"], branch=branch, project_id=project_id)
logger.debug("Completed.")
except ADOHTTPException as e:
print("ADOHTTPException " + str(e.response.status_code) + " on: ")
print("e.message = " + e.message)
print("e.response = " + str(e.response.content))
print("e.request.url = " + e.response.request.url + " path: " + e.response.request.path_url)
logger.critical("ADOHTTPException " + str(e.response.status_code) + " on: ")
logger.critical("e.message = " + e.message)
logger.critical("e.response = " + str(e.response.content))
logger.critical("e.request.url = " + e.response.request.url + " path: " + e.response.request.path_url)
if e.response.request.body:
print("e.request.body = " + e.response.request.body)
logger.critical("e.request.body = " + e.response.request.body)
except ADOException as e:
if "The output path already exists" in str(e):
print("Skipping for " + repoUrlStr + " it already exists.")
logger.debug("Skipping for " + repoUrlStr + " it already exists.")
pass
else:
print("ADOException " + str(e) + " on: ")
logger.critical("ADOException " + str(e) + " on: ")
except AuthError as e:
logger.debug(str(e))
if __name__=="__main__":
parser = argparse.ArgumentParser(

Просмотреть файл

@ -8,7 +8,7 @@
import enum
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Callable
import urllib.parse
from simple_ado.base_client import ADOBaseClient

Просмотреть файл

@ -17,7 +17,6 @@ from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_
from simple_ado.exceptions import ADOException, ADOHTTPException
from simple_ado.models import PatchOperation
# pylint: disable=invalid-name
ADOThread = Dict[str, Any]
ADOResponse = Any
@ -59,7 +58,7 @@ class ADOHTTPClient:
log: logging.Logger
tenant: str
extra_headers: Dict[str, str]
credentials: Tuple[str, str]
credentials = None # Modify away from tuple to expand for device_flow
_not_before: Optional[datetime.datetime]
_session: requests.Session
@ -67,7 +66,7 @@ class ADOHTTPClient:
self,
*,
tenant: str,
credentials: Tuple[str, str],
credentials: None, # Modify away from tuple to support device_flow
user_agent: str,
log: logging.Logger,
extra_headers: Optional[Dict[str, str]] = None,
@ -78,6 +77,7 @@ class ADOHTTPClient:
self.tenant = tenant
self.credentials = credentials
self._not_before = None
self._session = requests.Session()
@ -200,12 +200,29 @@ class ADOHTTPClient:
:returns: The raw response object from the API
"""
self._wait()
headers = self.construct_headers(additional_headers=additional_headers)
response = self._session.get(
request_url, auth=self.credentials, headers=headers, stream=stream
)
response = None
# Modified to support PAT-based Authentication
if isinstance(self.credentials, tuple) and len(self.credentials) == 2:
headers = self.construct_headers(additional_headers=additional_headers)
response = self._session.get(
request_url, auth=self.credentials, headers=headers, stream=stream
)
# Since requests does not support BearerAuth, this is a hacky way to do it.
elif isinstance(self.credentials, dict) and len(self.credentials) == 3:
if additional_headers:
additional_headers['Authorization'] = 'Bearer ' + self.credentials["access_token"]
else:
additional_headers = {"Authorization":"Bearer " + self.credentials["access_token"]}
headers = self.construct_headers(additional_headers=additional_headers)
response = self._session.get(
request_url, auth=None, headers=headers, stream=stream
)
else:
self.log.critical("len(self.credentials) == " + str(len(self.credentials)))
self.log.critical("type(self.credentials) == " + str(type(self.credentials)))
raise ValueError("Unknown authentication type. Modify simple_ado/http_client.py to support.")
if not response:
raise ADOHTTPException("Unable to get a response from authentication phase. See simple_ado/http_client.py.")
self._track_rate_limit(response)
if response.status_code == 429:
@ -448,4 +465,4 @@ class ADOHTTPClient:
for header_name, header_value in additional_headers.items():
headers[header_name] = header_value
return headers
return headers