#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation
#
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# stdlib imports
import argparse
import asyncio
import base64
import datetime
import enum
import hashlib
import json
import logging
import logging.handlers
import os
import pathlib
try:
    import pwd
except ImportError:
    pass
import queue
import random
import subprocess
import sys
import threading
import time
from typing import Tuple
# non-stdlib imports
import azure.common
import azure.cosmosdb.table as azuretable
import azure.storage.blob as azureblob

logger = None
# global defines
_ON_WINDOWS = sys.platform == 'win32'
_CONTAINER_MODE = None
_DOCKER_CONFIG_FILE = os.path.join(
    os.environ['AZ_BATCH_TASK_WORKING_DIR'], '.docker', 'config.json')
_DOCKER_TAG = 'docker:'
_SINGULARITY_TAG = 'singularity:'
_NODEID = os.environ['AZ_BATCH_NODE_ID']
_NODE_ROOT_DIR = os.environ['AZ_BATCH_NODE_ROOT_DIR']
try:
    _SINGULARITY_CACHE_DIR = pathlib.Path(os.environ['SINGULARITY_CACHEDIR'])
except KeyError:
    _SINGULARITY_CACHE_DIR = None
try:
    _SINGULARITY_SYPGP_DIR = pathlib.Path(os.environ['SINGULARITY_SYPGPDIR'])
except KeyError:
    _SINGULARITY_SYPGP_DIR = None
try:
    _AZBATCH_USER = pwd.getpwnam('_azbatch')
except NameError:
    _AZBATCH_USER = None
_PARTITION_KEY = None
_MAX_VMLIST_PROPERTIES = 13
_MAX_VMLIST_IDS_PER_PROPERTY = 800
_DOCKER_AUTH_MAP = None
_DOCKER_AUTH_MAP_LOCK = threading.Lock()
_DIRECTDL_LOCK = threading.Lock()
_CONCURRENT_DOWNLOADS_ALLOWED = 10
_RECORD_PERF = int(os.getenv('SHIPYARD_TIMING', default='0'))
# mutable global state
_CBHANDLES = {}
_BLOB_LEASES = {}
_PREFIX = None
_STORAGE_CONTAINERS = {
    'blob_globalresources': None,
    'table_images': None,
    'table_globalresources': None,
}
_DIRECTDL_QUEUE = queue.Queue()
_DIRECTDL_KEY_FINGERPRINT_DICT = dict()
_DIRECTDL_DOWNLOADING = set()
_GR_DONE = False
_THREAD_EXCEPTIONS = []
_DOCKER_PULL_ERRORS = frozenset((
    'toomanyrequests',
    'connection reset by peer',
    'error pulling image configuration',
    'error parsing http 404 response body',
    'received unexpected http status',
    'tls handshake timeout',
))


class ContainerMode(enum.Enum):
    DOCKER = 1
    SINGULARITY = 2


class StandardStreamLogger:
    """Standard Stream Logger"""

    def __init__(self, level):
        """Standard Stream ctor
        :param callable level: logger method that messages are forwarded to
        """
        self.level = level

    def write(self, message: str) -> None:
        """Write a message to the stream
        :param str message: message to write
        """
        if message != '\n':
            self.level(message)

    def flush(self) -> None:
        """Flush stream"""
        # no-op: messages are forwarded to the logger as they are written
        pass


def _setup_logger(mode: str, log_dir: str) -> None:
    """Set up logger
    :param str mode: container mode
    :param str log_dir: log directory
    """
    if not os.path.isdir(log_dir):
        invalid_log_dir = log_dir
        log_dir = os.environ['AZ_BATCH_TASK_WORKING_DIR']
        print('log directory "{}" '.format(invalid_log_dir) +
              'is not valid: using "{}"'.format(log_dir))
    logger_suffix = '' if mode is None else '-{}'.format(mode)
    logger_name = 'cascade{}-{}'.format(
        logger_suffix, datetime.datetime.now().strftime('%Y%m%dT%H%M%S'))
    global logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    logloc = pathlib.Path(log_dir, '{}.log'.format(logger_name))
    handler = logging.handlers.RotatingFileHandler(
        str(logloc), maxBytes=10485760, backupCount=5)
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03dZ %(levelname)s %(filename)s::%(funcName)s:'
        '%(lineno)d %(process)d:%(threadName)s %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # redirect stderr to logger
    sys.stderr = StandardStreamLogger(logger.error)
    logger.info('logger initialized, log file: {}'.format(logloc))
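
# Example (hypothetical values, for illustration only): calling
# _setup_logger('docker', '/var/log/shipyard') produces a rotating log file
# named like 'cascade-docker-20240101T120000.log' in that directory, capped
# at 10 MiB per file with 5 backups.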


def _setup_storage_names(sep: str) -> None:
    """Set up storage names
    :param str sep: storage container prefix
    """
    global _PARTITION_KEY, _PREFIX
    # transform pool id if necessary
    poolid = os.environ['AZ_BATCH_POOL_ID'].lower()
    autopool = os.environ.get('SHIPYARD_AUTOPOOL', default=None)
    # remove guid portion of pool id if autopool
    if autopool is not None:
        poolid = poolid[:-37]
    # set partition key
    batchaccount = os.environ['AZ_BATCH_ACCOUNT_NAME'].lower()
    _PARTITION_KEY = '{}${}'.format(batchaccount, poolid)
    # set container names
    if sep is None or len(sep) == 0:
        raise ValueError('storage_entity_prefix is invalid')
    _STORAGE_CONTAINERS['blob_globalresources'] = '-'.join(
        (sep + 'gr', batchaccount, poolid))
    _STORAGE_CONTAINERS['table_images'] = sep + 'images'
    _STORAGE_CONTAINERS['table_globalresources'] = sep + 'gr'
    _PREFIX = sep
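
# Example (hypothetical values, for illustration only): with sep='shipyard',
# a Batch account 'mybatch' and pool 'mypool', the names above resolve to:
#   _PARTITION_KEY                                -> 'mybatch$mypool'
#   _STORAGE_CONTAINERS['blob_globalresources']   -> 'shipyardgr-mybatch-mypool'
#   _STORAGE_CONTAINERS['table_images']           -> 'shipyardimages'
#   _STORAGE_CONTAINERS['table_globalresources']  -> 'shipyardgr'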


def _create_credentials() -> tuple:
    """Create storage credentials
    :rtype: tuple
    :return: (blob_client, table_client)
    """
    sa, ep, sakey = os.environ['SHIPYARD_STORAGE_ENV'].split(':')
    blob_client = azureblob.BlockBlobService(
        account_name=sa,
        account_key=sakey,
        endpoint_suffix=ep)
    table_client = azuretable.TableService(
        account_name=sa,
        account_key=sakey,
        endpoint_suffix=ep)
    return blob_client, table_client
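
# The split above expects SHIPYARD_STORAGE_ENV to be a colon-delimited triplet
# of storage account name, endpoint suffix and account key, e.g. (hypothetical
# values): SHIPYARD_STORAGE_ENV='mystorageacct:core.windows.net:<account-key>'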


async def _record_perf_async(
        loop: asyncio.BaseEventLoop, event: str, message: str) -> None:
    """Record timing metric async
    :param asyncio.BaseEventLoop loop: event loop
    :param str event: event
    :param str message: message
    """
    if not _RECORD_PERF:
        return
    proc = await asyncio.create_subprocess_shell(
        './perf.py cascade {ev} --prefix {pr} --message "{msg}"'.format(
            ev=event, pr=_PREFIX, msg=message), loop=loop)
    await proc.wait()
    if proc.returncode != 0:
        logger.error(
            'could not record perf to storage for event: {}'.format(event))


def _record_perf(event: str, message: str) -> None:
    """Record timing metric
    :param str event: event
    :param str message: message
    """
    if not _RECORD_PERF:
        return
    subprocess.check_call(
        './perf.py cascade {ev} --prefix {pr} --message "{msg}"'.format(
            ev=event, pr=_PREFIX, msg=message), shell=True)


def _renew_blob_lease(
        loop: asyncio.BaseEventLoop,
        blob_client: azureblob.BlockBlobService,
        container_key: str, resource: str, blob_name: str):
    """Renew a storage blob lease
    :param asyncio.BaseEventLoop loop: event loop
    :param azureblob.BlockBlobService blob_client: blob client
    :param str container_key: blob container index into _STORAGE_CONTAINERS
    :param str resource: resource
    :param str blob_name: blob name
    """
    try:
        lease_id = blob_client.renew_blob_lease(
            container_name=_STORAGE_CONTAINERS[container_key],
            blob_name=blob_name,
            lease_id=_BLOB_LEASES[resource],
        )
    except azure.common.AzureException as e:
        logger.exception(e)
        _BLOB_LEASES.pop(resource)
        _CBHANDLES.pop(resource)
    else:
        _BLOB_LEASES[resource] = lease_id
        _CBHANDLES[resource] = loop.call_later(
            15, _renew_blob_lease, loop, blob_client, container_key, resource,
            blob_name)


def scantree(path):
    """Recursively scan a directory tree
    :param str path: path to scan
    :rtype: os.DirEntry
    :return: DirEntry via generator
    """
    for entry in os.scandir(path):
        yield entry
        if entry.is_dir(follow_symlinks=False):
            yield from scantree(entry.path)
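
# Usage sketch (illustrative): iterate every file and directory under a root,
# as done for chown fixups in download_monitor_async below:
#   for entry in scantree('/some/cache/dir'):
#       if entry.is_file():
#           print(entry.path)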


def get_container_image_name_from_resource(resource: str) -> Tuple[str, str]:
    """Get container image from resource id
    :param str resource: resource
    :rtype: tuple
    :return: (type, image name)
    """
    if resource.startswith(_DOCKER_TAG):
        return (
            'docker',
            resource[len(_DOCKER_TAG):]
        )
    elif resource.startswith(_SINGULARITY_TAG):
        return (
            'singularity',
            resource[len(_SINGULARITY_TAG):]
        )
    else:
        raise ValueError('invalid resource: {}'.format(resource))
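
# Example resource id mappings (illustrative):
#   'docker:alpine:3.9'            -> ('docker', 'alpine:3.9')
#   'singularity:library://lolcow' -> ('singularity', 'library://lolcow')
# Any other prefix raises ValueError.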


def is_container_resource(resource: str) -> bool:
    """Check if resource is a container resource
    :param str resource: resource
    :rtype: bool
    :return: is a supported resource
    """
    if (resource.startswith(_DOCKER_TAG) or
            resource.startswith(_SINGULARITY_TAG)):
        return True
    return False


def compute_resource_hash(resource: str) -> str:
    """Calculate compute resource hash
    :param str resource: resource
    :rtype: str
    :return: hash of resource
    """
    return hashlib.sha1(resource.encode('utf8')).hexdigest()
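
# Note: the hash is the 40-character lowercase hex SHA1 of the full resource
# string, e.g. compute_resource_hash('docker:alpine:3.9') yields a 40-char
# digest used below as the image table RowKey and the lease blob name prefix.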


def _singularity_image_name_on_disk(name: str) -> str:
    """Convert a singularity URI to an on disk sif name
    :param str name: Singularity image name
    :rtype: str
    :return: singularity image name on disk
    """
    docker = False
    if name.startswith('shub://'):
        name = name[7:]
    elif name.startswith('library://'):
        name = name[10:]
    elif name.startswith('oras://'):
        name = name[7:]
    elif name.startswith('docker://'):
        docker = True
        name = name[9:]
    # singularity only uses the final portion
    name = name.split('/')[-1]
    name = name.replace('/', '-')
    if docker:
        name = name.replace(':', '-')
        name = '{}.sif'.format(name)
    else:
        tmp = name.split(':')
        if len(tmp) > 1:
            name = '{}_{}.sif'.format(tmp[0], tmp[1])
        else:
            name = '{}_latest.sif'.format(name)
    return name
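
# Illustrative conversions performed by the function above:
#   'docker://alpine:3.9'               -> 'alpine-3.9.sif'
#   'library://user/collection/img:1.0' -> 'img_1.0.sif'
#   'shub://vsoch/hello-world'          -> 'hello-world_latest.sif'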


def singularity_image_path_on_disk(name: str) -> pathlib.Path:
    """Get a singularity image path on disk
    :param str name: Singularity image name
    :rtype: pathlib.Path
    :return: singularity image path on disk
    """
    return _SINGULARITY_CACHE_DIR / _singularity_image_name_on_disk(name)


def singularity_image_name_to_key_file_name(name: str) -> str:
    """Convert a singularity image to its key file name
    :param str name: Singularity image name
    :rtype: str
    :return: key file name of the singularity image
    """
    hash_image_name = compute_resource_hash(name)
    key_file_name = 'public-{}.asc'.format(hash_image_name)
    return key_file_name
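
# Example (illustrative): for an image 'oras://myreg.example.com/img:1.0'
# (hypothetical name) the key file name is 'public-<sha1-of-image-name>.asc',
# where the sha1 is computed over the full image string as above.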


class ContainerImageSaveThread(threading.Thread):
    """Container Image Save Thread"""

    def __init__(
            self, blob_client: azureblob.BlockBlobService,
            table_client: azuretable.TableService,
            resource: str, blob_name: str, nglobalresources: int):
        """ContainerImageSaveThread ctor
        :param azureblob.BlockBlobService blob_client: blob client
        :param azuretable.TableService table_client: table client
        :param str resource: resource
        :param str blob_name: resource blob name
        :param int nglobalresources: number of global resources
        """
        threading.Thread.__init__(self)
        self.blob_client = blob_client
        self.table_client = table_client
        self.resource = resource
        self.blob_name = blob_name
        self.nglobalresources = nglobalresources
        # add to downloading set
        with _DIRECTDL_LOCK:
            _DIRECTDL_DOWNLOADING.add(self.resource)

    def run(self) -> None:
        """Thread main run function"""
        try:
            self._pull_and_save()
        except Exception as ex:
            logger.exception(ex)
            _THREAD_EXCEPTIONS.append(ex)
        finally:
            # cancel callback
            try:
                _CBHANDLES[self.resource].cancel()
            except KeyError as e:
                logger.exception(e)
            _CBHANDLES.pop(self.resource)
            # release blob lease
            try:
                self.blob_client.release_blob_lease(
                    container_name=_STORAGE_CONTAINERS['blob_globalresources'],
                    blob_name=self.blob_name,
                    lease_id=_BLOB_LEASES[self.resource],
                )
            except azure.common.AzureException as e:
                logger.exception(e)
            _BLOB_LEASES.pop(self.resource)
            logger.debug(
                'blob lease released for {}'.format(self.resource))
            # remove from downloading set
            with _DIRECTDL_LOCK:
                _DIRECTDL_DOWNLOADING.remove(self.resource)

    def _check_pull_output_overload(self, stderr: str) -> bool:
        """Check output for registry overload errors
        :param str stderr: stderr
        :rtype: bool
        :return: if error appears to be overload from registry
        """
        return any([x in stderr for x in _DOCKER_PULL_ERRORS])

    def _get_singularity_credentials(self, image: str) -> tuple:
        """Get the username and the password of the registry of a given
        Singularity image
        :param str image: image for which we want the username and the
            password
        :rtype: tuple
        :return: username and password
        """
        global _DOCKER_AUTH_MAP
        registry_type, _, image_name = image.partition('://')
        if registry_type != 'docker' and registry_type != 'oras':
            return None, None
        docker_config_data = {}
        with _DOCKER_AUTH_MAP_LOCK:
            if _DOCKER_AUTH_MAP is None:
                with open(_DOCKER_CONFIG_FILE) as docker_config_file:
                    docker_config_data = json.load(docker_config_file)
                try:
                    _DOCKER_AUTH_MAP = docker_config_data['auths']
                except KeyError:
                    _DOCKER_AUTH_MAP = {}
        registry = image_name.partition('/')[0]
        try:
            b64auth = _DOCKER_AUTH_MAP[registry]['auth']
        except KeyError:
            return None, None
        auth = base64.b64decode(b64auth).decode('utf-8')
        username, _, password = auth.partition(':')
        return username, password
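
    # The lookup above expects the standard Docker client config layout at
    # _DOCKER_CONFIG_FILE, e.g. (hypothetical registry and credentials):
    #   {
    #       "auths": {
    #           "myregistry.example.com": {
    #               "auth": "<base64 of 'username:password'>"
    #           }
    #       }
    #   }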

    def _get_singularity_pull_cmd(self, image: str) -> str:
        """Get singularity pull command
        :param str image: image to pull
        :rtype: str
        :return: pull command for the singularity image
        """
        # if we have a key_fingerprint we need to pull
        # the key to our keyring
        image_out_path = singularity_image_path_on_disk(image)
        key_file_path = pathlib.Path(
            singularity_image_name_to_key_file_name(image))
        username, password = self._get_singularity_credentials(image)
        if username is not None and password is not None:
            credentials_command_argument = (
                '--docker-username {} --docker-password {} '.format(
                    username, password))
        else:
            credentials_command_argument = ''
        if image in _DIRECTDL_KEY_FINGERPRINT_DICT:
            singularity_pull_cmd = (
                'singularity pull -F ' +
                credentials_command_argument +
                '{} {}'.format(image_out_path, image))
            key_fingerprint = _DIRECTDL_KEY_FINGERPRINT_DICT[image]
            if key_file_path.is_file():
                key_import_cmd = ('singularity key import {}'
                                  .format(key_file_path))
                fingerprint_check_cmd = (
                    'key_fingerprint=$({} | '.format(key_import_cmd) +
                    'grep -o "fingerprint \\(\\S*\\)" | ' +
                    'grep -o "\\S*$" | sed -e "s/\\(.*\\)/\\U\\1/"); ' +
                    'if [ ${key_fingerprint} != ' +
                    '"{}" ]; '.format(key_fingerprint.upper()) +
                    'then (>&2 echo "aborting: fingerprint of ' +
                    'key file $key_fingerprint does not match ' +
                    'fingerprint provided {}")'.format(key_fingerprint) +
                    ' && exit 1; fi')
                cmd = (key_import_cmd + ' && ' + fingerprint_check_cmd +
                       ' && ' + singularity_pull_cmd)
            else:
                key_pull_cmd = ('singularity key pull {}'
                                .format(key_fingerprint))
                cmd = key_pull_cmd + ' && ' + singularity_pull_cmd
            # if the image was pulled from oras we need to manually
            # verify the image
            if image.startswith('oras://'):
                singularity_verify_cmd = ('singularity verify {}'
                                          .format(image_out_path))
                cmd = cmd + ' && ' + singularity_verify_cmd
        else:
            cmd = ('singularity pull -U -F ' +
                   credentials_command_argument +
                   '{} {}'.format(image_out_path, image))
        return cmd

    def _pull(self, grtype: str, image: str) -> tuple:
        """Container image pull
        :param str grtype: global resource type
        :param str image: image to pull
        :rtype: tuple
        :return: tuple of return code, stdout, stderr
        """
        if grtype == 'docker':
            cmd = 'docker pull {}'.format(image)
        elif grtype == 'singularity':
            cmd = self._get_singularity_pull_cmd(image)
        logger.debug('pulling command: {}'.format(cmd))
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            universal_newlines=True)
        stdout, stderr = proc.communicate()
        return proc.returncode, stdout, stderr

    def _pull_and_save(self) -> None:
        """Thread main logic for pulling and saving a container image"""
        grtype, image = get_container_image_name_from_resource(self.resource)
        _record_perf('pull-start', 'grtype={},img={}'.format(grtype, image))
        start = datetime.datetime.now()
        logger.info('pulling {} image {}'.format(grtype, image))
        backoff = random.randint(2, 5)
        while True:
            rc, stdout, stderr = self._pull(grtype, image)
            if rc == 0:
                break
            elif self._check_pull_output_overload(stderr.lower()):
                logger.error(
                    'Too many requests issued to registry server, '
                    'retrying...')
                backoff = backoff << 1
                endbackoff = backoff << 1
                if endbackoff >= 300:
                    endbackoff = 300
                if backoff > endbackoff:
                    backoff = endbackoff
                time.sleep(random.randint(backoff, endbackoff))
                # reset if backoff reaches 5 min
                if backoff >= 300:
                    backoff = random.randint(2, 5)
            else:
                raise RuntimeError(
                    '{} pull failed: stdout={} stderr={}'.format(
                        grtype, stdout, stderr))
        diff = (datetime.datetime.now() - start).total_seconds()
        logger.debug('took {} sec to pull {} image {}'.format(
            diff, grtype, image))
        # register resource
        _merge_resource(
            self.table_client, self.resource, self.nglobalresources)
        # get image size
        try:
            if grtype == 'docker':
                output = subprocess.check_output(
                    'docker images {}'.format(image), shell=True)
                size = ' '.join(output.decode('utf-8').split()[-2:])
            elif grtype == 'singularity':
                imgpath = singularity_image_path_on_disk(image)
                size = imgpath.stat().st_size
            _record_perf(
                'pull-end', 'grtype={},img={},diff={},size={}'.format(
                    grtype, image, diff, size))
        except subprocess.CalledProcessError as ex:
            logger.exception(ex)
            _record_perf('pull-end', 'grtype={},img={},diff={}'.format(
                grtype, image, diff))
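
    # Retry/backoff sketch for the loop in _pull_and_save above: the initial
    # backoff is a random 2-5 seconds; on each registry-overload error it is
    # doubled and the thread sleeps for a random interval between backoff and
    # 2*backoff, capped at 300 seconds; once the cap is reached the backoff
    # resets to a random 2-5 seconds. Any other pull failure raises
    # RuntimeError immediately.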


async def _direct_download_resources_async(
        loop: asyncio.BaseEventLoop,
        blob_client: azureblob.BlockBlobService,
        table_client: azuretable.TableService,
        nglobalresources: int) -> None:
    """Direct download resource logic
    :param asyncio.BaseEventLoop loop: event loop
    :param azureblob.BlockBlobService blob_client: blob client
    :param azuretable.TableService table_client: table client
    :param int nglobalresources: number of global resources
    """
    # ensure we are not downloading too many sources at once
    with _DIRECTDL_LOCK:
        if len(_DIRECTDL_DOWNLOADING) > _CONCURRENT_DOWNLOADS_ALLOWED:
            return
    # retrieve a resource from dl queue
    _seen = set()
    while True:
        try:
            resource = _DIRECTDL_QUEUE.get()
        except queue.Empty:
            break
        else:
            if resource in _seen:
                _DIRECTDL_QUEUE.put(resource)
                resource = None
                break
            _seen.add(resource)
        with _DIRECTDL_LOCK:
            if resource not in _DIRECTDL_DOWNLOADING:
                break
            else:
                _DIRECTDL_QUEUE.put(resource)
                resource = None
    del _seen
    # attempt to get a blob lease
    if resource is not None:
        lease_id = None
        blob_name = None
        for i in range(0, _CONCURRENT_DOWNLOADS_ALLOWED):
            blob_name = '{}.{}'.format(compute_resource_hash(resource), i)
            try:
                lease_id = blob_client.acquire_blob_lease(
                    container_name=_STORAGE_CONTAINERS['blob_globalresources'],
                    blob_name=blob_name,
                    lease_duration=60,
                )
                break
            except azure.common.AzureConflictHttpError:
                blob_name = None
                pass
        if lease_id is None:
            logger.debug(
                'no available blobs to lease for resource: {}'.format(
                    resource))
            _DIRECTDL_QUEUE.put(resource)
            return
        # create lease renew callback
        logger.debug('blob lease {} acquired for resource {}'.format(
            lease_id, resource))
        _BLOB_LEASES[resource] = lease_id
        _CBHANDLES[resource] = loop.call_later(
            15, _renew_blob_lease, loop, blob_client, 'blob_globalresources',
            resource, blob_name)
    if resource is None:
        return
    # pull and save container image in thread
    if is_container_resource(resource):
        thr = ContainerImageSaveThread(
            blob_client, table_client, resource, blob_name, nglobalresources)
        thr.start()
    else:
        # TODO download via blob, explode uri to get container/blob
        # use download to path into /tmp and move to directory
        raise NotImplementedError()
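
# Lease blob naming used above: each global resource has
# _CONCURRENT_DOWNLOADS_ALLOWED candidate lease blobs named
# '<sha1-of-resource>.0' .. '<sha1-of-resource>.<N-1>' in the
# 'blob_globalresources' container; a node must win a 60-second lease on one
# of them before pulling, which bounds how many nodes pull the same resource
# concurrently.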


def _merge_resource(
        table_client: azuretable.TableService,
        resource: str, nglobalresources: int) -> None:
    """Merge resource to the image table
    :param azuretable.TableService table_client: table client
    :param str resource: resource to add to the image table
    :param int nglobalresources: number of global resources
    """
    # merge resource to the image table
    entity = {
        'PartitionKey': _PARTITION_KEY,
        'RowKey': compute_resource_hash(resource),
        'Resource': resource,
        'VmList0': _NODEID,
    }
    logger.debug('merging entity {} to the image table'.format(entity))
    try:
        table_client.insert_entity(
            _STORAGE_CONTAINERS['table_images'], entity=entity)
    except azure.common.AzureConflictHttpError:
        while True:
            entity = table_client.get_entity(
                _STORAGE_CONTAINERS['table_images'],
                entity['PartitionKey'], entity['RowKey'])
            # merge VmList into entity
            evms = []
            for i in range(0, _MAX_VMLIST_PROPERTIES):
                prop = 'VmList{}'.format(i)
                if prop in entity:
                    evms.extend(entity[prop].split(','))
            if _NODEID in evms:
                break
            evms.append(_NODEID)
            for i in range(0, _MAX_VMLIST_PROPERTIES):
                prop = 'VmList{}'.format(i)
                start = i * _MAX_VMLIST_IDS_PER_PROPERTY
                end = start + _MAX_VMLIST_IDS_PER_PROPERTY
                if end > len(evms):
                    end = len(evms)
                if start < end:
                    entity[prop] = ','.join(evms[start:end])
                else:
                    entity[prop] = None
            etag = entity['etag']
            entity.pop('etag')
            try:
                table_client.merge_entity(
                    _STORAGE_CONTAINERS['table_images'], entity=entity,
                    if_match=etag)
                break
            except azure.common.AzureHttpError as ex:
                if ex.status_code != 412:
                    raise
    logger.info('entity {} merged to the image table'.format(entity))
    global _GR_DONE
    if not _GR_DONE:
        try:
            entities = table_client.query_entities(
                _STORAGE_CONTAINERS['table_images'],
                filter='PartitionKey eq \'{}\''.format(_PARTITION_KEY))
        except azure.common.AzureMissingResourceHttpError:
            entities = []
        count = 0
        for entity in entities:
            for i in range(0, _MAX_VMLIST_PROPERTIES):
                prop = 'VmList{}'.format(i)
                mode_prefix = _CONTAINER_MODE.name.lower() + ':'
                if (prop in entity and _NODEID in entity[prop] and
                        entity['Resource'].startswith(mode_prefix)):
                    count += 1
        if count == nglobalresources:
            _record_perf(
                'gr-done',
                'nglobalresources={}'.format(nglobalresources))
            _GR_DONE = True
            logger.info('all {} global resources of container mode "{}" loaded'
                        .format(nglobalresources,
                                _CONTAINER_MODE.name.lower()))
        else:
            logger.info('{}/{} global resources of container mode "{}" loaded'
                        .format(count, nglobalresources,
                                _CONTAINER_MODE.name.lower()))
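
# Sizing note for the merge above: node ids are packed into the VmListN entity
# properties in groups of _MAX_VMLIST_IDS_PER_PROPERTY (800), across at most
# _MAX_VMLIST_PROPERTIES (13) properties, i.e. up to 10400 node ids can be
# recorded per image entity.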


def _unmerge_resources(
        table_client: azuretable.TableService) -> None:
    """Remove node from the image table
    :param azuretable.TableService table_client: table client
    """
    logger.debug('removing node {} from the image table for container mode {}'
                 .format(_NODEID, _CONTAINER_MODE.name.lower()))
    try:
        entities = table_client.query_entities(
            _STORAGE_CONTAINERS['table_images'],
            filter='PartitionKey eq \'{}\''.format(_PARTITION_KEY))
    except azure.common.AzureMissingResourceHttpError:
        entities = []
    mode_prefix = _CONTAINER_MODE.name.lower() + ':'
    for entity in entities:
        if entity['Resource'].startswith(mode_prefix):
            _unmerge_resource(table_client, entity)
    logger.info('node {} removed from the image table for container mode {}'
                .format(_NODEID, _CONTAINER_MODE.name.lower()))


def _unmerge_resource(
        table_client: azuretable.TableService, entity: dict) -> None:
    """Remove node from entity
    :param azuretable.TableService table_client: table client
    :param dict entity: image table entity to remove the node from
    """
    while True:
        entity = table_client.get_entity(
            _STORAGE_CONTAINERS['table_images'],
            entity['PartitionKey'], entity['RowKey'])
        # merge VmList into entity
        evms = []
        for i in range(0, _MAX_VMLIST_PROPERTIES):
            prop = 'VmList{}'.format(i)
            if prop in entity:
                evms.extend(entity[prop].split(','))
        if _NODEID in evms:
            evms.remove(_NODEID)
        for i in range(0, _MAX_VMLIST_PROPERTIES):
            prop = 'VmList{}'.format(i)
            start = i * _MAX_VMLIST_IDS_PER_PROPERTY
            end = start + _MAX_VMLIST_IDS_PER_PROPERTY
            if end > len(evms):
                end = len(evms)
            if start < end:
                entity[prop] = ','.join(evms[start:end])
            else:
                entity[prop] = None
        etag = entity['etag']
        entity.pop('etag')
        try:
            table_client.update_entity(
                _STORAGE_CONTAINERS['table_images'], entity=entity,
                if_match=etag)
            break
        except azure.common.AzureHttpError as ex:
            if ex.status_code != 412:
                raise


async def download_monitor_async(
        loop: asyncio.BaseEventLoop,
        blob_client: azureblob.BlockBlobService,
        table_client: azuretable.TableService,
        nglobalresources: int) -> None:
    """Download monitor
    :param asyncio.BaseEventLoop loop: event loop
    :param azureblob.BlockBlobService blob_client: blob client
    :param azuretable.TableService table_client: table client
    :param int nglobalresources: number of global resources
    """
    while not _GR_DONE:
        # check if there are any direct downloads
        if _DIRECTDL_QUEUE.qsize() > 0:
            await _direct_download_resources_async(
                loop, blob_client, table_client, nglobalresources)
        # check for any thread exceptions
        if len(_THREAD_EXCEPTIONS) > 0:
            logger.critical('Thread exceptions encountered, terminating')
            # raise first exception
            raise _THREAD_EXCEPTIONS[0]
        # sleep to avoid pinning cpu
        await asyncio.sleep(1)
    # fixup filemodes/ownership for singularity images
    if (_SINGULARITY_CACHE_DIR is not None and
            _AZBATCH_USER is not None):
        if _SINGULARITY_CACHE_DIR.exists():
            logger.info('chown all files in {}'.format(
                _SINGULARITY_CACHE_DIR))
            for file in scantree(str(_SINGULARITY_CACHE_DIR)):
                os.chown(
                    str(file.path),
                    _AZBATCH_USER[2],
                    _AZBATCH_USER[3]
                )
        else:
            logger.warning(
                'singularity cache dir {} does not exist'.format(
                    _SINGULARITY_CACHE_DIR))
    # fixup filemodes/ownership for singularity keys
    if (_SINGULARITY_SYPGP_DIR is not None and
            _AZBATCH_USER is not None):
        if _SINGULARITY_SYPGP_DIR.exists():
            logger.info('chown all files in {}'.format(
                _SINGULARITY_SYPGP_DIR))
            for file in scantree(str(_SINGULARITY_SYPGP_DIR)):
                os.chown(
                    str(file.path),
                    _AZBATCH_USER[2],
                    _AZBATCH_USER[3]
                )
        else:
            logger.warning(
                'singularity sypgp dir {} does not exist'.format(
                    _SINGULARITY_SYPGP_DIR))


def distribute_global_resources(
        loop: asyncio.BaseEventLoop,
        blob_client: azureblob.BlockBlobService,
        table_client: azuretable.TableService) -> None:
    """Distribute global services/resources
    :param asyncio.BaseEventLoop loop: event loop
    :param azureblob.BlockBlobService blob_client: blob client
    :param azuretable.TableService table_client: table client
    """
    # remove node from the image table because cascade relies on it to know
    # when its work is done
    _unmerge_resources(table_client)
    # get globalresources from table
    try:
        entities = table_client.query_entities(
            _STORAGE_CONTAINERS['table_globalresources'],
            filter='PartitionKey eq \'{}\''.format(_PARTITION_KEY))
    except azure.common.AzureMissingResourceHttpError:
        entities = []
    nentities = 0
    for ent in entities:
        resource = ent['Resource']
        grtype, image = get_container_image_name_from_resource(resource)
        if grtype == _CONTAINER_MODE.name.lower():
            nentities += 1
            _DIRECTDL_QUEUE.put(resource)
            key_fingerprint = ent.get('KeyFingerprint', None)
            if key_fingerprint is not None:
                _DIRECTDL_KEY_FINGERPRINT_DICT[image] = key_fingerprint
        else:
            logger.info('skipping resource {}: '.format(resource) +
                        'not matching container mode "{}"'
                        .format(_CONTAINER_MODE.name.lower()))
    if nentities == 0:
        logger.info('no global resources specified')
        return
    logger.info('{} global resources matching container mode "{}"'
                .format(nentities, _CONTAINER_MODE.name.lower()))
    # run async func in loop
    loop.run_until_complete(download_monitor_async(
        loop, blob_client, table_client, nentities))


def main():
    """Main function"""
    # get command-line args
    args = parseargs()

    _setup_logger(args.mode, args.log_directory)

    global _CONCURRENT_DOWNLOADS_ALLOWED, _CONTAINER_MODE

    # set up concurrent source downloads
    if args.concurrent is None:
        raise ValueError('concurrent source downloads is not specified')
    try:
        _CONCURRENT_DOWNLOADS_ALLOWED = int(args.concurrent)
    except ValueError:
        _CONCURRENT_DOWNLOADS_ALLOWED = None
    if (_CONCURRENT_DOWNLOADS_ALLOWED is None or
            _CONCURRENT_DOWNLOADS_ALLOWED <= 0):
        raise ValueError('concurrent source downloads is invalid: {}'
                         .format(args.concurrent))
    logger.info('max concurrent downloads: {}'.format(
        _CONCURRENT_DOWNLOADS_ALLOWED))

    # get event loop
    if _ON_WINDOWS:
        loop = asyncio.ProactorEventLoop()
        asyncio.set_event_loop(loop)
    else:
        loop = asyncio.get_event_loop()
    loop.set_debug(True)

    # set up container mode
    if args.mode is None:
        raise ValueError('container mode is not specified')
    if args.mode == 'docker':
        _CONTAINER_MODE = ContainerMode.DOCKER
    elif args.mode == 'singularity':
        _CONTAINER_MODE = ContainerMode.SINGULARITY
    else:
        raise ValueError('container mode is invalid: {}'.format(args.mode))
    logger.info('container mode: {}'.format(_CONTAINER_MODE.name))

    # set up storage names
    _setup_storage_names(args.prefix)
    del args

    # create storage credentials
    blob_client, table_client = _create_credentials()

    # distribute global resources
    distribute_global_resources(loop, blob_client, table_client)


def parseargs():
    """Parse program arguments
    :rtype: argparse.Namespace
    :return: parsed arguments
    """
    parser = argparse.ArgumentParser(
        description='Cascade: Batch Shipyard File/Image Replicator')
    parser.set_defaults(concurrent=None, mode=None)
    parser.add_argument(
        '--concurrent',
        help='concurrent source downloads')
    parser.add_argument(
        '--mode', help='container mode (docker/singularity)')
    parser.add_argument(
        '--prefix', help='storage container prefix')
    parser.add_argument(
        '--log-directory', help='directory to store log files')
    return parser.parse_args()
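
# Example invocation (hypothetical values): as launched on a compute node,
# something like
#   ./cascade.py --mode docker --concurrent 10 --prefix shipyard \
#       --log-directory "$AZ_BATCH_TASK_WORKING_DIR"
# replicates all docker-mode global resources with at most 10 concurrent
# downloads.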


if __name__ == '__main__':
    main()