Add auto access token refresh
This commit is contained in:
Родитель
23701a3481
Коммит
d710d3bd46
|
@ -1,25 +1,19 @@
|
|||
from urllib.parse import urlencode
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
from a01.auth import AuthSettings
|
||||
from a01.common import A01Config
|
||||
from a01.models import Run, RunsView
|
||||
from a01.transport import AsyncSession
|
||||
|
||||
|
||||
async def query_run_async(run_id: str) -> Run:
|
||||
endpoint = A01Config().ensure_config().endpoint_uri
|
||||
async with aiohttp.ClientSession(headers={'Authorization': AuthSettings().access_token}) as session:
|
||||
async with session.get(f'{endpoint}/run/{run_id}') as resp:
|
||||
json_body = await resp.json()
|
||||
return Run.from_dict(json_body)
|
||||
async with AsyncSession() as session:
|
||||
return Run.from_dict(session.get_json(f'run/{run_id}'))
|
||||
|
||||
|
||||
async def query_runs_async(**kwargs) -> RunsView:
|
||||
endpoint = A01Config().ensure_config().endpoint_uri
|
||||
async with aiohttp.ClientSession(headers={'Authorization': AuthSettings().access_token}) as session:
|
||||
url = f'{endpoint}/runs'
|
||||
async with AsyncSession() as session:
|
||||
url = 'runs'
|
||||
query = {}
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
|
@ -28,9 +22,8 @@ async def query_runs_async(**kwargs) -> RunsView:
|
|||
if query:
|
||||
url = f'{url}?{urlencode(query)}'
|
||||
|
||||
async with session.get(url) as resp:
|
||||
json_body = await resp.json()
|
||||
return RunsView(runs=[Run.from_dict(each) for each in json_body])
|
||||
json_body = await session.get_json(url)
|
||||
return RunsView(runs=[Run.from_dict(each) for each in json_body])
|
||||
|
||||
|
||||
def query_run(run_id: str) -> Run:
|
||||
|
|
|
@ -1,31 +1,23 @@
|
|||
from typing import List
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
from a01.auth import AuthSettings
|
||||
from a01.common import A01Config
|
||||
from a01.models import Task
|
||||
from a01.transport import AsyncSession
|
||||
|
||||
|
||||
async def query_tasks_async(ids: List[str]) -> List[Task]:
|
||||
results = []
|
||||
endpoint = A01Config().ensure_config().endpoint_uri
|
||||
async with aiohttp.ClientSession(headers={'Authorization': AuthSettings().access_token}) as session:
|
||||
async with AsyncSession() as session:
|
||||
for task_id in ids:
|
||||
async with session.get(f'{endpoint}/task/{task_id}') as resp:
|
||||
json_body = await resp.json()
|
||||
results.append(Task.from_dict(json_body))
|
||||
results.append(Task.from_dict(await session.get_json(f'task/{task_id}')))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def query_tasks_by_run_async(run_id: str) -> List[Task]:
|
||||
endpoint = A01Config().ensure_config().endpoint_uri
|
||||
async with aiohttp.ClientSession(headers={'Authorization': AuthSettings().access_token}) as session:
|
||||
async with session.get(f'{endpoint}/run/{run_id}/tasks') as resp:
|
||||
json_body = await resp.json()
|
||||
return [Task.from_dict(data) for data in json_body]
|
||||
async with AsyncSession() as session:
|
||||
return [Task.from_dict(each) for each in await session.get_json(f'run/{run_id}/tasks')]
|
||||
|
||||
|
||||
def query_tasks(ids: List[str]) -> List[Task]:
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
import sys
|
||||
from logging import getLogger
|
||||
from typing import Union, List
|
||||
|
||||
from aiohttp import ClientSession, ContentTypeError
|
||||
|
||||
from a01.auth import AuthSettings
|
||||
from a01.common import A01Config
|
||||
|
||||
|
||||
class AsyncSession(ClientSession):
|
||||
def __init__(self) -> None:
|
||||
super(AsyncSession, self).__init__()
|
||||
self.auth = AuthSettings()
|
||||
self.endpoint = A01Config().ensure_config().endpoint_uri
|
||||
self.logger = getLogger(__name__)
|
||||
|
||||
def get_path(self, path: str) -> str:
|
||||
return f'{self.endpoint}/{path}'
|
||||
|
||||
async def get_json(self, path: str) -> Union[List, dict, float, str, None]:
|
||||
if self.auth.is_expired and not self.auth.refresh():
|
||||
self.logger.error('Fail to refresh access token. Please login again.')
|
||||
sys.exit(1)
|
||||
|
||||
headers = {'Authorization': self.auth.access_token}
|
||||
|
||||
async with self.get(self.get_path(path), headers=headers) as resp:
|
||||
try:
|
||||
return await resp.json()
|
||||
except ContentTypeError:
|
||||
self.logger.error('Incorrect content type')
|
||||
self.logger.error(await resp.text())
|
||||
raise
|
Загрузка…
Ссылка в новой задаче