#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation
#
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# stdlib imports
import argparse
import asyncio
import concurrent.futures
import datetime
import hashlib
import json
import logging
import logging.handlers
import multiprocessing
import pathlib
import pickle
import random
import subprocess
import threading
from typing import (
    Any,
    Dict,
    Generator,
    List,
    Optional,
    Set,
    Tuple,
)
# non-stdlib imports
import azure.batch
import azure.batch.models as batchmodels
import azure.common  # for the Azure*HttpError exception types used below
import azure.cosmosdb.table
import azure.mgmt.compute
import azure.mgmt.resource
import azure.mgmt.storage
import azure.storage.blob
import azure.storage.queue
import dateutil.tz
import msrestazure.azure_active_directory
import msrestazure.azure_cloud

# create logger
logger = logging.getLogger(__name__)
# global defines
_MEGABYTE = 1048576
_RDMA_INSTANCES = frozenset((
    'standard_a8', 'standard_a9',
))
_RDMA_INSTANCE_SUFFIXES = frozenset((
    'r', 'rs', 'rs_v2', 'rs_v3',
))
_GPU_INSTANCE_PREFIXES = frozenset((
    'standard_nc', 'standard_nd', 'standard_nv',
))
_POOL_NATIVE_METADATA_NAME = 'BATCH_SHIPYARD_NATIVE_CONTAINER_POOL'
# TODO allow these maximums to be configurable
_MAX_EXECUTOR_WORKERS = min((multiprocessing.cpu_count() * 4, 32))
_MAX_TIMESPAN_POOL_UPDATE = datetime.timedelta(seconds=60)
_MAX_TIMESPAN_NODE_COUNTS_UPDATE = datetime.timedelta(seconds=10)
_MAX_TIMESPAN_ACTIVE_TASKS_COUNT_UPDATE = datetime.timedelta(seconds=20)
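# note: the timespans above act as cache TTLs for pool metadata held by
# FederationPool below: full CloudPool refreshes at most every 60 seconds,
# node state counts at most every 10 seconds, and aggregate active task
# counts at most every 20 seconds (see the *_requires_update properties)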


def _setup_logger(log) -> None:
    """Set up logger"""
    log.setLevel(logging.DEBUG)
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        '%(asctime)s %(levelname)s %(name)s:%(funcName)s:%(lineno)d '
        '%(message)s')
    formatter.default_msec_format = '%s.%03d'
    handler.setFormatter(formatter)
    log.addHandler(handler)


def max_workers_for_executor(iterable: Any) -> int:
    """Get max number of workers for executor given an iterable
    :param iterable: an iterable
    :return: number of workers for executor
    """
    return min((len(iterable), _MAX_EXECUTOR_WORKERS))


def is_none_or_empty(obj: Any) -> bool:
    """Determine if object is None or empty
    :param obj: object
    :return: if object is None or empty
    """
    return obj is None or len(obj) == 0


def is_not_empty(obj: Any) -> bool:
    """Determine if object is not None and its length is > 0
    :param obj: object
    :return: if object is not None and length is > 0
    """
    return obj is not None and len(obj) > 0


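# note: with as_string=True, the value below is a compact ISO 8601-style UTC
# timestamp with millisecond precision, e.g. '20180901T120000.123Z'
# (illustrative value)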
def datetime_utcnow(as_string: bool = False) -> datetime.datetime:
    """Returns a datetime now with UTC timezone
    :param as_string: return as ISO8601 extended string
    :return: datetime object representing now with UTC timezone
    """
    dt = datetime.datetime.now(dateutil.tz.tzutc())
    if as_string:
        return dt.strftime('%Y%m%dT%H%M%S.%f')[:-3] + 'Z'
    else:
        return dt


def hash_string(strdata: str) -> str:
    """Hash a string
    :param strdata: string data to hash
    :return: hexdigest
    """
    return hashlib.sha1(strdata.encode('utf8')).hexdigest()


def hash_federation_id(federation_id: str) -> str:
    """Hash a federation id
    :param federation_id: federation id
    :return: hashed federation id
    """
    return hash_string(federation_id)


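# note: the pool capability checks below match purely on VM size naming
# conventions, e.g. is_rdma_pool('standard_h16r') -> True (suffix 'r') and
# is_gpu_pool('standard_nc6') -> True (prefix 'standard_nc')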
def is_rdma_pool(vm_size: str) -> bool:
    """Check if pool is IB/RDMA capable
    :param vm_size: vm size
    :return: if rdma is present
    """
    vsl = vm_size.lower()
    if vsl in _RDMA_INSTANCES:
        return True
    elif any(vsl.endswith(x) for x in _RDMA_INSTANCE_SUFFIXES):
        return True
    return False


def is_gpu_pool(vm_size: str) -> bool:
    """Check if pool is GPU capable
    :param vm_size: vm size
    :return: if gpus are present
    """
    vsl = vm_size.lower()
    return any(vsl.startswith(x) for x in _GPU_INSTANCE_PREFIXES)


def get_temp_disk_for_node_agent(node_agent: str) -> str:
    """Get temp disk location for node agent
    :param node_agent: node agent
    :return: temp disk location
    """
    if node_agent.startswith('batch.node.ubuntu'):
        return '/mnt'
    elif node_agent.startswith('batch.node.windows'):
        return 'D:\\batch'
    else:
        return '/mnt/resource'


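# note: the constraint classes below map a job's constraints configuration
# dict onto attributes; a minimal sketch of the expected shape (keys inferred
# from the .get() calls below, values illustrative):
#   {'pool': {'location': 'eastus', 'autoscale': {'allow': True}},
#    'compute_node': {'vm_size': 'standard_d2_v2', 'cores': {'amount': 2}},
#    'task': {'auto_complete': False}}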
class PoolConstraints():
    def __init__(self, constraints: Dict[str, Any]) -> None:
        autoscale = constraints.get('autoscale', {})
        self.autoscale_allow = autoscale.get('allow')
        self.autoscale_exclusive = autoscale.get('exclusive')
        self.custom_image_arm_id = constraints.get('custom_image_arm_id')
        self.location = constraints.get('location')
        lp = constraints.get('low_priority_nodes', {})
        self.low_priority_nodes_allow = lp.get('allow')
        self.low_priority_nodes_exclusive = lp.get('exclusive')
        matb = constraints.get('max_active_task_backlog', {})
        self.max_active_task_backlog_ratio = matb.get('ratio')
        self.max_active_task_backlog_autoscale_exempt = matb.get(
            'autoscale_exempt')
        self.native = constraints.get('native')
        self.virtual_network_arm_id = constraints.get('virtual_network_arm_id')
        self.windows = constraints.get('windows')
        self.registries = constraints.get('registries')


class ComputeNodeConstraints():
    def __init__(self, constraints: Dict[str, Any]) -> None:
        self.vm_size = constraints.get('vm_size')
        cores = constraints.get('cores', {})
        self.cores = cores.get('amount')
        if self.cores is not None:
            self.cores = int(self.cores)
        self.core_variance = cores.get('schedulable_variance')
        memory = constraints.get('memory', {})
        self.memory = memory.get('amount')
        if self.memory is not None:
            # normalize to MB, e.g. '2g' -> 2048 and '512m' -> 512
            suffix = self.memory[-1].lower()
            self.memory = int(self.memory[:-1])
            if suffix == 'b':
                self.memory /= _MEGABYTE
            elif suffix == 'k':
                self.memory /= 1024
            elif suffix == 'g':
                self.memory *= 1024
            elif suffix == 't':
                self.memory *= _MEGABYTE
            elif suffix != 'm':
                # 'm' (already MB) requires no conversion; any other
                # suffix is invalid
                raise ValueError(
                    'invalid memory constraint suffix: {}'.format(suffix))
        self.memory_variance = memory.get('schedulable_variance')
        self.exclusive = constraints.get('exclusive')
        self.gpu = constraints.get('gpu')
        self.infiniband = constraints.get('infiniband')


class TaskConstraints():
    def __init__(self, constraints: Dict[str, Any]) -> None:
        self.auto_complete = constraints.get('auto_complete')
        self.has_multi_instance = constraints.get('has_multi_instance')
        self.has_task_dependencies = constraints.get('has_task_dependencies')
        instance_counts = constraints.get('instance_counts', {})
        self.instance_counts_max = instance_counts.get('max')
        self.instance_counts_total = instance_counts.get('total')
        self.merge_task_id = constraints.get('merge_task_id')
        self.tasks_per_recurrence = constraints.get('tasks_per_recurrence')


class Constraints():
    def __init__(self, constraints: Dict[str, Any]) -> None:
        self.pool = PoolConstraints(constraints['pool'])
        self.compute_node = ComputeNodeConstraints(constraints['compute_node'])
        self.task = TaskConstraints(constraints['task'])


class TaskNaming():
    def __init__(self, naming: Dict[str, Any]) -> None:
        self.prefix = naming.get('prefix')
        self.padding = naming.get('padding')


class Credentials():
    def __init__(self, config: Dict[str, Any]) -> None:
        """Ctor for Credentials
        :param config: configuration
        """
        # set attr from config
        self.storage_account = config['storage']['account']
        self.storage_account_rg = config['storage']['resource_group']
        # get cloud object
        self.cloud = Credentials.convert_cloud_type(config['aad_cloud'])
        # get aad creds
        self.arm_creds = self.create_msi_credentials()
        self.batch_creds = self.create_msi_credentials(
            resource_id=self.cloud.endpoints.batch_resource_id)
        # get subscription id
        self.sub_id = self.get_subscription_id()
        logger.debug('created msi auth for sub id: {}'.format(self.sub_id))
        # get storage account key and endpoint
        self.storage_account_key, self.storage_account_ep = \
            self.get_storage_account_key()
        logger.debug('storage account {} -> rg: {} ep: {}'.format(
            self.storage_account, self.storage_account_rg,
            self.storage_account_ep))

    @staticmethod
    def convert_cloud_type(cloud_type: str) -> msrestazure.azure_cloud.Cloud:
        """Convert cloud type string to object
        :param cloud_type: cloud type to convert
        :return: cloud object
        """
        if cloud_type == 'public':
            cloud = msrestazure.azure_cloud.AZURE_PUBLIC_CLOUD
        elif cloud_type == 'china':
            cloud = msrestazure.azure_cloud.AZURE_CHINA_CLOUD
        elif cloud_type == 'germany':
            cloud = msrestazure.azure_cloud.AZURE_GERMAN_CLOUD
        elif cloud_type == 'usgov':
            cloud = msrestazure.azure_cloud.AZURE_US_GOV_CLOUD
        else:
            raise ValueError('unknown cloud_type: {}'.format(cloud_type))
        return cloud

    def get_subscription_id(self) -> str:
        """Get subscription id for the ARM creds
        :return: subscription id
        """
        client = azure.mgmt.resource.SubscriptionClient(self.arm_creds)
        return next(client.subscriptions.list()).subscription_id

    def create_msi_credentials(
            self,
            resource_id: str = None
    ) -> msrestazure.azure_active_directory.MSIAuthentication:
        """Create MSI credentials
        :param resource_id: resource id to auth against
        :return: MSI auth object
        """
        if is_not_empty(resource_id):
            creds = msrestazure.azure_active_directory.MSIAuthentication(
                cloud_environment=self.cloud,
                resource=resource_id,
            )
        else:
            creds = msrestazure.azure_active_directory.MSIAuthentication(
                cloud_environment=self.cloud,
            )
        return creds

    def get_storage_account_key(self) -> Tuple[str, str]:
        """Retrieve the storage account key and endpoint
        :return: tuple of key, endpoint
        """
        client = azure.mgmt.storage.StorageManagementClient(
            self.arm_creds, self.sub_id,
            base_url=self.cloud.endpoints.resource_manager)
        ep = None
        if is_not_empty(self.storage_account_rg):
            acct = client.storage_accounts.get_properties(
                self.storage_account_rg, self.storage_account)
            # derive the endpoint suffix from the primary blob endpoint,
            # e.g. 'https://sa.blob.core.windows.net/' -> 'core.windows.net'
            ep = '.'.join(
                acct.primary_endpoints.blob.rstrip('/').split('.')[2:]
            )
        else:
            for acct in client.storage_accounts.list():
                if acct.name == self.storage_account:
                    self.storage_account_rg = acct.id.split('/')[4]
                    ep = '.'.join(
                        acct.primary_endpoints.blob.rstrip('/').split('.')[2:]
                    )
                    break
        if is_none_or_empty(self.storage_account_rg) or is_none_or_empty(ep):
            raise RuntimeError(
                'storage account {} not found in subscription id {}'.format(
                    self.storage_account, self.sub_id))
        keys = client.storage_accounts.list_keys(
            self.storage_account_rg, self.storage_account)
        return (keys.keys[0].value, ep)


class ServiceProxy():
    def __init__(self, config: Dict[str, Any]) -> None:
        """Ctor for ServiceProxy
        :param config: configuration
        """
        self._config = config
        prefix = config['storage']['entity_prefix']
        self.queue_prefix = '{}fed'.format(prefix)
        self.table_name_global = '{}fedglobal'.format(prefix)
        self.table_name_jobs = '{}fedjobs'.format(prefix)
        self.blob_container_data_prefix = '{}fed'.format(prefix)
        self.blob_container_name_global = '{}fedglobal'.format(prefix)
        self.file_share_logging = '{}fedlogs'.format(prefix)
        self._batch_client_lock = threading.Lock()
        self.batch_clients = {}
        # create credentials
        self.creds = Credentials(config)
        # create clients
        self.compute_client = self._create_compute_client()
        self.blob_client = self._create_blob_client()
        self.table_client = self._create_table_client()
        self.queue_client = self._create_queue_client()
        logger.debug('created storage clients for storage account {}'.format(
            self.creds.storage_account))

    @property
    def batch_shipyard_version(self) -> str:
        return self._config['batch_shipyard']['version']

    @property
    def batch_shipyard_var_path(self) -> pathlib.Path:
        return pathlib.Path(self._config['batch_shipyard']['var_path'])

    @property
    def storage_entity_prefix(self) -> str:
        return self._config['storage']['entity_prefix']

    @property
    def logger_level(self) -> str:
        return self._config['logging']['level']

    @property
    def logger_persist(self) -> bool:
        return self._config['logging']['persistence']

    @property
    def logger_filename(self) -> str:
        return self._config['logging']['filename']

    def log_configuration(self) -> None:
        logger.debug('configuration: {}'.format(
            json.dumps(self._config, sort_keys=True, indent=4)))

    def _modify_client_for_retry_and_user_agent(self, client: Any) -> None:
        """Extend retry policy of clients and add user agent string
        :param client: a client object
        """
        if client is None:
            return
        client.config.retry_policy.max_backoff = 8
        client.config.retry_policy.retries = 100
        client.config.add_user_agent('batch-shipyard/{}'.format(
            self.batch_shipyard_version))

    def _create_table_client(self) -> azure.cosmosdb.table.TableService:
        """Create a table client for the given storage account
        :return: table client
        """
        client = azure.cosmosdb.table.TableService(
            account_name=self.creds.storage_account,
            account_key=self.creds.storage_account_key,
            endpoint_suffix=self.creds.storage_account_ep,
        )
        return client

    def _create_queue_client(self) -> azure.storage.queue.QueueService:
        """Create a queue client for the given storage account
        :return: queue client
        """
        client = azure.storage.queue.QueueService(
            account_name=self.creds.storage_account,
            account_key=self.creds.storage_account_key,
            endpoint_suffix=self.creds.storage_account_ep,
        )
        return client

    def _create_blob_client(self) -> azure.storage.blob.BlockBlobService:
        """Create a blob client for the given storage account
        :return: block blob client
        """
        return azure.storage.blob.BlockBlobService(
            account_name=self.creds.storage_account,
            account_key=self.creds.storage_account_key,
            endpoint_suffix=self.creds.storage_account_ep,
        )

    def _create_compute_client(
            self
    ) -> azure.mgmt.compute.ComputeManagementClient:
        """Create a compute mgmt client
        :return: compute client
        """
        client = azure.mgmt.compute.ComputeManagementClient(
            self.creds.arm_creds, self.creds.sub_id,
            base_url=self.creds.cloud.endpoints.resource_manager)
        return client

    def batch_client(
            self,
            batch_account: str,
            service_url: str
    ) -> azure.batch.BatchServiceClient:
        """Get/create batch client, cached per batch account
        :param batch_account: batch account name
        :param service_url: service url
        :return: batch client
        """
        with self._batch_client_lock:
            try:
                return self.batch_clients[batch_account]
            except KeyError:
                client = azure.batch.BatchServiceClient(
                    self.creds.batch_creds, batch_url=service_url)
                self._modify_client_for_retry_and_user_agent(client)
                self.batch_clients[batch_account] = client
                logger.debug('batch client created for account: {}'.format(
                    batch_account))
                return client


class ComputeServiceHandler():
    def __init__(self, service_proxy: ServiceProxy) -> None:
        """Ctor for Compute Service handler
        :param service_proxy: ServiceProxy
        """
        self.service_proxy = service_proxy
        self._vm_sizes_lock = threading.Lock()
        self._queried_locations = set()
        self._vm_sizes = {}

    def populate_vm_sizes_from_location(self, location: str) -> None:
        """Populate VM sizes for a location
        :param location: location
        """
        location = location.lower()
        with self._vm_sizes_lock:
            if location in self._queried_locations:
                return
        vmsizes = list(
            self.service_proxy.compute_client.virtual_machine_sizes.list(
                location)
        )
        with self._vm_sizes_lock:
            for vmsize in vmsizes:
                name = vmsize.name.lower()
                if name in self._vm_sizes:
                    continue
                self._vm_sizes[name] = vmsize
            self._queried_locations.add(location)

    def get_vm_size(
            self,
            vm_size: str
    ) -> 'azure.mgmt.compute.models.VirtualMachineSize':
        """Get VM Size information
        :param vm_size: name of VM size
        :return: VM size information
        """
        with self._vm_sizes_lock:
            return self._vm_sizes[vm_size.lower()]


class BatchServiceHandler():
    def __init__(self, service_proxy: ServiceProxy) -> None:
        """Ctor for Federation Batch handler
        :param service_proxy: ServiceProxy
        """
        self.service_proxy = service_proxy

    def get_pool_full_update(
            self,
            batch_account: str,
            service_url: str,
            pool_id: str,
    ) -> batchmodels.CloudPool:
        client = self.service_proxy.batch_client(batch_account, service_url)
        try:
            return client.pool.get(pool_id)
        except batchmodels.BatchErrorException:
            pass
        return None

    def get_node_state_counts(
            self,
            batch_account: str,
            service_url: str,
            pool_id: str,
    ) -> batchmodels.PoolNodeCounts:
        client = self.service_proxy.batch_client(batch_account, service_url)
        try:
            node_counts = client.account.list_pool_node_counts(
                account_list_pool_node_counts_options=batchmodels.
                AccountListPoolNodeCountsOptions(
                    filter='poolId eq \'{}\''.format(pool_id)
                )
            )
            nc = list(node_counts)
            if len(nc) == 0:
                logger.error(
                    'no node counts for pool {} (account={} '
                    'service_url={})'.format(
                        pool_id, batch_account, service_url))
                return None
            return nc[0]
        except batchmodels.BatchErrorException:
            logger.error(
                'could not retrieve pool {} node counts (account={} '
                'service_url={})'.format(pool_id, batch_account, service_url))

    def immediately_evaluate_autoscale(
            self,
            batch_account: str,
            service_url: str,
            pool_id: str,
    ) -> None:
        # retrieve current autoscale
        client = self.service_proxy.batch_client(batch_account, service_url)
        try:
            pool = client.pool.get(pool_id)
            if not pool.enable_auto_scale:
                logger.warning(
                    'cannot immediately evaluate autoscale on pool {} as '
                    'autoscale is not enabled (batch_account={} '
                    'service_url={})'.format(
                        pool_id, batch_account, service_url))
                return
            client.pool.enable_auto_scale(
                pool_id=pool.id,
                auto_scale_formula=pool.auto_scale_formula,
                auto_scale_evaluation_interval=pool.
                auto_scale_evaluation_interval,
            )
        except Exception as exc:
            logger.exception(str(exc))
        else:
            logger.debug(
                'autoscale enabled for pool {} interval={} (batch_account={} '
                'service_url={})'.format(
                    pool_id, pool.auto_scale_evaluation_interval,
                    batch_account, service_url))

    def add_job_schedule(
            self,
            batch_account: str,
            service_url: str,
            jobschedule: batchmodels.JobScheduleAddParameter,
    ) -> None:
        client = self.service_proxy.batch_client(batch_account, service_url)
        client.job_schedule.add(jobschedule)

    def get_job(
            self,
            batch_account: str,
            service_url: str,
            job_id: str,
    ) -> batchmodels.CloudJob:
        client = self.service_proxy.batch_client(batch_account, service_url)
        return client.job.get(job_id)

    def add_job(
            self,
            batch_account: str,
            service_url: str,
            job: batchmodels.JobAddParameter,
    ) -> None:
        client = self.service_proxy.batch_client(batch_account, service_url)
        client.job.add(job)

    async def delete_or_terminate_job(
            self,
            batch_account: str,
            service_url: str,
            job_id: str,
            delete: bool,
            is_job_schedule: bool,
            wait: bool = False,
    ) -> None:
        action = 'delete' if delete else 'terminate'
        cstate = (
            batchmodels.JobScheduleState.completed if is_job_schedule else
            batchmodels.JobState.completed
        )
        client = self.service_proxy.batch_client(batch_account, service_url)
        iface = client.job_schedule if is_job_schedule else client.job
        logger.debug('{} {} {} (account={} service_url={})'.format(
            action, 'job schedule' if is_job_schedule else 'job',
            job_id, batch_account, service_url))
        try:
            if delete:
                iface.delete(job_id)
            else:
                iface.terminate(job_id)
        except batchmodels.BatchErrorException as exc:
            if delete:
                if ('does not exist' in exc.message.value or
                        (not wait and
                         'marked for deletion' in exc.message.value)):
                    return
            else:
                if ('completed state' in exc.message.value or
                        'marked for deletion' in exc.message.value):
                    return
        # wait for job to delete/terminate
        if wait:
            while True:
                try:
                    _job = iface.get(job_id)
                    if _job.state == cstate:
                        break
                except batchmodels.BatchErrorException as exc:
                    if 'does not exist' in exc.message.value:
                        break
                    else:
                        raise
                await asyncio.sleep(1)

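    # note: e.g. _format_generic_task_id('task-', 5, 3) -> 'task-00003'
    # ('task-' is an illustrative prefix)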
    def _format_generic_task_id(
            self, prefix: str, padding: int, tasknum: int) -> str:
        """Format a generic task id from a task number
        :param prefix: prefix
        :param padding: zfill task number
        :param tasknum: task number
        :return: generic task id
        """
        return '{}{}'.format(prefix, str(tasknum).zfill(padding))

    def regenerate_next_generic_task_id(
            self,
            batch_account: str,
            service_url: str,
            job_id: str,
            naming: TaskNaming,
            current_task_id: str,
            last_task_id: Optional[str] = None,
            tasklist: Optional[List[str]] = None,
            is_merge_task: Optional[bool] = False
    ) -> Tuple[List[str], str]:
        """Regenerate the next generic task id
        :param batch_account: batch account
        :param service_url: service url
        :param job_id: job id
        :param naming: naming convention
        :param current_task_id: current task id
        :param last_task_id: last task id submitted to the job
        :param tasklist: list of committed and uncommitted tasks in job
        :param is_merge_task: is merge task
        :return: (list of task ids for job, next generic docker task id)
        """
        # get prefix and padding settings
        prefix = naming.prefix
        if is_merge_task:
            prefix = 'merge-{}'.format(prefix)
        if not current_task_id.startswith(prefix):
            return tasklist, current_task_id
        delimiter = prefix if is_not_empty(prefix) else ' '
        client = self.service_proxy.batch_client(batch_account, service_url)
        # get filtered, sorted list of generic docker task ids
        try:
            if tasklist is None:
                tasklist = client.task.list(
                    job_id,
                    task_list_options=batchmodels.TaskListOptions(
                        filter='startswith(id, \'{}\')'.format(prefix)
                        if is_not_empty(prefix) else None,
                        select='id'))
                tasklist = [x.id for x in tasklist]
            tasknum = sorted(
                [int(x.split(delimiter)[-1]) for x in tasklist])[-1] + 1
        except (batchmodels.BatchErrorException, IndexError, TypeError):
            tasknum = 0
        id = self._format_generic_task_id(prefix, naming.padding, tasknum)
        while id in tasklist:
            try:
                if (last_task_id is not None and
                        last_task_id.startswith(prefix)):
                    tasknum = int(last_task_id.split(delimiter)[-1])
                    last_task_id = None
            except Exception:
                last_task_id = None
            tasknum += 1
            id = self._format_generic_task_id(prefix, naming.padding, tasknum)
        return tasklist, id

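    # note: task add collections are capped at 100 tasks per request; on a
    # RequestBodyTooLarge response the slice width is halved (100 -> 50 ->
    # 25 -> 12 -> 6 -> 3 -> 1) until the request fits or a single task
    # still exceeds the limit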
    def _submit_task_sub_collection(
            self,
            client: azure.batch.BatchServiceClient,
            job_id: str,
            start: int,
            end: int,
            slice: int,
            all_tasks: List[str],
            task_map: Dict[str, batchmodels.TaskAddParameter]
    ) -> bool:
        """Submits a sub-collection of tasks, do not call directly
        :param client: batch client
        :param job_id: job to add to
        :param start: start offset, inclusive
        :param end: end offset, exclusive
        :param slice: slice width
        :param all_tasks: list of all tasks to add
        :param task_map: task collection map to add
        :return: True if all tasks were added without client errors
        """
        ret = True
        initial_slice = slice
        while True:
            chunk_end = start + slice
            if chunk_end > end:
                chunk_end = end
            chunk = all_tasks[start:chunk_end]
            logger.debug('submitting {} tasks ({} -> {}) to job {}'.format(
                len(chunk), start, chunk_end - 1, job_id))
            try:
                results = client.task.add_collection(job_id, chunk)
            except batchmodels.BatchErrorException as e:
                if e.error.code == 'RequestBodyTooLarge':
                    # collection contents are too large, reduce and retry
                    if slice == 1:
                        raise
                    slice = slice >> 1
                    if slice < 1:
                        slice = 1
                    logger.error(
                        ('task collection slice was too big, retrying with '
                         'slice={}').format(slice))
                    continue
            else:
                # go through result and retry just failed tasks
                while True:
                    retry = []
                    for result in results.value:
                        if (result.status ==
                                batchmodels.TaskAddStatus.client_error):
                            de = None
                            if result.error.values is not None:
                                de = [
                                    '{}: {}'.format(x.key, x.value)
                                    for x in result.error.values
                                ]
                            logger.error(
                                ('skipping retry of adding task {} as it '
                                 'returned a client error (code={} '
                                 'message={} {}) for job {}').format(
                                     result.task_id, result.error.code,
                                     result.error.message,
                                     ' '.join(de) if de is not None else '',
                                     job_id))
                            ret = False
                        elif (result.status ==
                                batchmodels.TaskAddStatus.server_error):
                            retry.append(task_map[result.task_id])
                    if len(retry) > 0:
                        logger.debug(
                            'retrying adding {} tasks to job {}'.format(
                                len(retry), job_id))
                        results = client.task.add_collection(job_id, retry)
                    else:
                        break
            if chunk_end == end:
                break
            start = chunk_end
            slice = initial_slice
        return ret

    def add_task_collection(
            self,
            batch_account: str,
            service_url: str,
            job_id: str,
            task_map: Dict[str, batchmodels.TaskAddParameter]
    ) -> None:
        """Add a collection of tasks to a job
        :param batch_account: batch account
        :param service_url: service url
        :param job_id: job to add to
        :param task_map: task collection map to add
        """
        client = self.service_proxy.batch_client(batch_account, service_url)
        all_tasks = list(task_map.values())
        num_tasks = len(all_tasks)
        if num_tasks == 0:
            logger.debug(
                'no tasks detected in task_map for job {} for '
                '(batch_account={} service_url={})'.format(
                    job_id, batch_account, service_url))
            return
        slice = 100  # can only submit up to 100 tasks at a time
        task_futures = []
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=_MAX_EXECUTOR_WORKERS) as executor:
            for start in range(0, num_tasks, slice):
                end = start + slice
                if end > num_tasks:
                    end = num_tasks
                task_futures.append(executor.submit(
                    self._submit_task_sub_collection, client, job_id, start,
                    end, end - start, all_tasks, task_map))
        # throw exceptions from any failure
        try:
            errors = any(not x.result() for x in task_futures)
        except Exception as exc:
            logger.exception(str(exc))
            errors = True
        if errors:
            logger.error(
                'failures detected in task submission of {} tasks for '
                'job {} for (batch_account={} service_url={})'.format(
                    num_tasks, job_id, batch_account, service_url))
        else:
            logger.info(
                'submitted all {} tasks to job {} for (batch_account={} '
                'service_url={})'.format(
                    num_tasks, job_id, batch_account, service_url))

    def set_auto_complete_on_job(
            self,
            batch_account: str,
            service_url: str,
            job_id: str
    ) -> None:
        client = self.service_proxy.batch_client(batch_account, service_url)
        client.job.patch(
            job_id=job_id,
            job_patch_parameter=batchmodels.JobPatchParameter(
                on_all_tasks_complete=batchmodels.
                OnAllTasksComplete.terminate_job
            ),
        )
        logger.debug('set auto-completion for job {}'.format(job_id))

    def aggregate_active_task_count_on_pool(
            self,
            batch_account: str,
            service_url: str,
            pool_id: str,
    ) -> int:
        total_active = 0
        client = self.service_proxy.batch_client(batch_account, service_url)
        try:
            jobs = list(client.job.list(
                job_list_options=batchmodels.JobListOptions(
                    filter='(state eq \'active\') and (executionInfo/poolId '
                    'eq \'{}\')'.format(pool_id),
                    select='id',
                ),
            ))
        except batchmodels.BatchErrorException as exc:
            logger.exception(str(exc))
        else:
            if len(jobs) == 0:
                return total_active
            tc_futures = []
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers_for_executor(jobs)) as executor:
                for job in jobs:
                    tc_futures.append(executor.submit(
                        client.job.get_task_counts, job.id))
            for tc in tc_futures:
                try:
                    total_active += tc.result().active
                except Exception as exc:
                    logger.exception(str(exc))
        return total_active


class FederationDataHandler():
    _GLOBAL_LOCK_BLOB = 'global.lock'
    _ALL_FEDERATIONS_PK = '!!FEDERATIONS'
    _FEDERATION_ACTIONS_PREFIX_PK = '!!ACTIONS'
    _BLOCKED_FEDERATION_ACTIONS_PREFIX_PK = '!!ACTIONS.BLOCKED'
    _MAX_SEQUENCE_ID_PROPERTIES = 15
    _MAX_SEQUENCE_IDS_PER_PROPERTY = 975
    _MAX_STR_ENTITY_PROPERTY_LENGTH = 32174

    def __init__(self, service_proxy: ServiceProxy) -> None:
        """Ctor for Federation data handler
        :param service_proxy: ServiceProxy
        """
        self.service_proxy = service_proxy
        self.lease_id = None
        try:
            self.scheduling_blackout = int(
                self.service_proxy._config[
                    'scheduling']['after_success']['blackout_interval'])
        except KeyError:
            self.scheduling_blackout = 15
        try:
            self.scheduling_evaluate_autoscale = self.service_proxy._config[
                'scheduling']['after_success']['evaluate_autoscale']
        except KeyError:
            self.scheduling_evaluate_autoscale = True

    @property
    def has_global_lock(self) -> bool:
        return self.lease_id is not None

    def lease_global_lock(
            self,
            loop: asyncio.BaseEventLoop,
    ) -> None:
        try:
            if self.lease_id is None:
                logger.debug('acquiring blob lease on {}'.format(
                    self._GLOBAL_LOCK_BLOB))
                self.lease_id = \
                    self.service_proxy.blob_client.acquire_blob_lease(
                        self.service_proxy.blob_container_name_global,
                        self._GLOBAL_LOCK_BLOB, lease_duration=15)
                logger.debug('blob lease acquired on {}'.format(
                    self._GLOBAL_LOCK_BLOB))
            else:
                self.lease_id = \
                    self.service_proxy.blob_client.renew_blob_lease(
                        self.service_proxy.blob_container_name_global,
                        self._GLOBAL_LOCK_BLOB, self.lease_id)
        except Exception:
            self.lease_id = None
        if self.lease_id is None:
            logger.error('could not acquire/renew lease on {}'.format(
                self._GLOBAL_LOCK_BLOB))
        loop.call_later(5, self.lease_global_lock, loop)

    def release_global_lock(self) -> None:
        if self.lease_id is not None:
            try:
                self.service_proxy.blob_client.release_blob_lease(
                    self.service_proxy.blob_container_name_global,
                    self._GLOBAL_LOCK_BLOB, self.lease_id)
            except azure.common.AzureConflictHttpError:
                self.lease_id = None

    def mount_file_storage(self) -> Optional[pathlib.Path]:
        if not self.service_proxy.logger_persist:
            logger.warning('logging persistence is disabled')
            return None
        # create logs directory
        log_path = self.service_proxy.batch_shipyard_var_path / 'logs'
        log_path.mkdir(exist_ok=True)
        # mount
        cmd = (
            'mount -t cifs //{sa}.file.{ep}/{share} {hmp} -o '
            'vers=3.0,username={sa},password={sakey},_netdev,serverino'
        ).format(
            sa=self.service_proxy.creds.storage_account,
            ep=self.service_proxy.creds.storage_account_ep,
            share=self.service_proxy.file_share_logging,
            hmp=log_path,
            sakey=self.service_proxy.creds.storage_account_key,
        )
        logger.debug('attempting to mount file share for logging persistence')
        try:
            output = subprocess.check_output(
                cmd, shell=True, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as exc:
            logger.error('subprocess run error: {} exited with {}'.format(
                exc.cmd, exc.returncode))
            logger.error('stderr: {}'.format(exc.stderr))
            logger.error('stdout: {}'.format(exc.stdout))
            raise
        else:
            logger.debug(output)
        return log_path

    def unmount_file_storage(self) -> None:
        if not self.service_proxy.logger_persist:
            return
        log_path = self.service_proxy.batch_shipyard_var_path / 'logs'
        cmd = 'umount {hmp}'.format(hmp=log_path)
        logger.debug(
            'attempting to unmount file share for logging persistence')
        output = subprocess.check_output(
            cmd, shell=True, stderr=subprocess.PIPE)
        logger.debug(output)

    def set_log_configuration(self, log_path: pathlib.Path) -> None:
        global logger
        # remove existing handlers
        handlers = logger.handlers[:]
        for handler in handlers:
            handler.close()
            logger.removeHandler(handler)
        # set level
        if self.service_proxy.logger_level == 'info':
            logger.setLevel(logging.INFO)
        elif self.service_proxy.logger_level == 'warning':
            logger.setLevel(logging.WARNING)
        elif self.service_proxy.logger_level == 'error':
            logger.setLevel(logging.ERROR)
        elif self.service_proxy.logger_level == 'critical':
            logger.setLevel(logging.CRITICAL)
        else:
            logger.setLevel(logging.DEBUG)
        # set formatter
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s %(name)s:%(funcName)s:%(lineno)d '
            '%(message)s')
        formatter.default_msec_format = '%s.%03d'
        # set handlers
        handler_stream = logging.StreamHandler()
        handler_stream.setFormatter(formatter)
        logger.addHandler(handler_stream)
        az_storage_logger = logging.getLogger('azure.storage')
        az_storage_logger.setLevel(logging.WARNING)
        az_storage_logger.addHandler(handler_stream)
        az_cosmosdb_logger = logging.getLogger('azure.cosmosdb')
        az_cosmosdb_logger.setLevel(logging.WARNING)
        az_cosmosdb_logger.addHandler(handler_stream)
        # set log file
        if log_path is None:
            logger.warning('not setting logfile as persistence is disabled')
        else:
            # log to selected log level file
            logfname = pathlib.Path(self.service_proxy.logger_filename)
            logfile = log_path / '{}-{}{}'.format(
                logfname.stem, self.service_proxy.logger_level,
                logfname.suffix)
            logfile.parent.mkdir(exist_ok=True)
            handler_logfile = logging.handlers.RotatingFileHandler(
                str(logfile), maxBytes=33554432, backupCount=20000,
                encoding='utf-8')
            handler_logfile.setFormatter(formatter)
            logger.addHandler(handler_logfile)
            az_storage_logger.addHandler(handler_logfile)
            az_cosmosdb_logger.addHandler(handler_logfile)
            # always log to error file
            if self.service_proxy.logger_level != 'error':
                logfile_err = log_path / '{}-error{}'.format(
                    logfname.stem, logfname.suffix)
                logfile_err.parent.mkdir(exist_ok=True)
                handler_logfile_err = logging.handlers.RotatingFileHandler(
                    str(logfile_err), maxBytes=33554432, backupCount=10000,
                    encoding='utf-8')
                handler_logfile_err.setFormatter(formatter)
                handler_logfile_err.setLevel(logging.ERROR)
                logger.addHandler(handler_logfile_err)

    def get_all_federations(self) -> List[azure.cosmosdb.table.Entity]:
        """Get all federations"""
        return self.service_proxy.table_client.query_entities(
            self.service_proxy.table_name_global,
            filter='PartitionKey eq \'{}\''.format(self._ALL_FEDERATIONS_PK))

    def get_all_pools_for_federation(
            self,
            fedhash: str
    ) -> List[azure.cosmosdb.table.Entity]:
        """Get all pools for a federation
        :param fedhash: federation hash
        """
        return self.service_proxy.table_client.query_entities(
            self.service_proxy.table_name_global,
            filter='PartitionKey eq \'{}\''.format(fedhash))

    def get_pool_for_federation(
            self,
            fedhash: str,
            poolhash: str,
    ) -> Optional[azure.cosmosdb.table.Entity]:
        try:
            return self.service_proxy.table_client.get_entity(
                self.service_proxy.table_name_global, fedhash, poolhash)
        except azure.common.AzureMissingResourceHttpError:
            return None

    def generate_pk_rk_for_job_location_entity(
            self,
            fedhash: str,
            job_id: str,
            pool: 'FederationPool',
    ) -> Tuple[str, str]:
        pk = '{}${}'.format(fedhash, hash_string(job_id))
        rk = hash_string('{}${}'.format(pool.service_url, pool.pool_id))
        return pk, rk

    def get_location_entity_for_job(
            self,
            fedhash: str,
            job_id: str,
            pool: 'FederationPool',
    ) -> Optional[azure.cosmosdb.table.Entity]:
        pk, rk = self.generate_pk_rk_for_job_location_entity(
            fedhash, job_id, pool)
        try:
            return self.service_proxy.table_client.get_entity(
                self.service_proxy.table_name_jobs, pk, rk)
        except azure.common.AzureMissingResourceHttpError:
            return None

    def location_entities_exist_for_job(
            self,
            fedhash: str,
            job_id: str,
    ) -> bool:
        try:
            entities = self.service_proxy.table_client.query_entities(
                self.service_proxy.table_name_jobs,
                filter='PartitionKey eq \'{}${}\''.format(
                    fedhash, hash_string(job_id))
            )
            for ent in entities:
                return True
        except azure.common.AzureMissingResourceHttpError:
            pass
        return False

    def insert_or_update_entity_with_etag_for_job(
            self,
            entity: Dict[str, Any],
    ) -> bool:
        if 'etag' not in entity:
            try:
                self.service_proxy.table_client.insert_entity(
                    self.service_proxy.table_name_jobs, entity=entity)
                return True
            except azure.common.AzureConflictHttpError:
                pass
        else:
            etag = entity['etag']
            entity.pop('etag')
            try:
                self.service_proxy.table_client.update_entity(
                    self.service_proxy.table_name_jobs, entity=entity,
                    if_match=etag)
                return True
            except azure.common.AzureConflictHttpError:
                pass
            except azure.common.AzureHttpError as ex:
                if ex.status_code != 412:
                    raise
        return False

    def delete_location_entity_for_job(
            self,
            entity: Dict[str, Any],
    ) -> None:
        try:
            self.service_proxy.table_client.delete_entity(
                self.service_proxy.table_name_jobs, entity['PartitionKey'],
                entity['RowKey'])
        except azure.common.AzureMissingResourceHttpError:
            pass

    def get_all_location_entities_for_job(
            self,
            fedhash: str,
            job_id: str,
    ) -> Optional[List[azure.cosmosdb.table.Entity]]:
        try:
            return self.service_proxy.table_client.query_entities(
                self.service_proxy.table_name_jobs,
                filter='PartitionKey eq \'{}${}\''.format(
                    fedhash, hash_string(job_id))
            )
        except azure.common.AzureMissingResourceHttpError:
            return None

    def delete_action_entity_for_job(
            self,
            entity: Dict[str, Any],
    ) -> None:
        try:
            self.service_proxy.table_client.delete_entity(
                self.service_proxy.table_name_jobs, entity['PartitionKey'],
                entity['RowKey'], if_match=entity['etag'])
        except azure.common.AzureMissingResourceHttpError:
            pass

    def get_messages_from_federation_queue(
            self,
            fedhash: str
    ) -> List[azure.storage.queue.models.QueueMessage]:
        queue_name = '{}-{}'.format(
            self.service_proxy.queue_prefix, fedhash)
        return self.service_proxy.queue_client.get_messages(
            queue_name, num_messages=32, visibility_timeout=1)

    def _get_sequence_entity_for_job(
            self,
            fedhash: str,
            job_id: str
    ) -> azure.cosmosdb.table.Entity:
        return self.service_proxy.table_client.get_entity(
            self.service_proxy.table_name_jobs,
            '{}${}'.format(self._FEDERATION_ACTIONS_PREFIX_PK, fedhash),
            hash_string(job_id))

    def get_first_sequence_id_for_job(
            self,
            fedhash: str,
            job_id: str
    ) -> str:
        try:
            entity = self._get_sequence_entity_for_job(fedhash, job_id)
        except azure.common.AzureMissingResourceHttpError:
            return None
        else:
            try:
                return entity['Sequence0'].split(',')[0]
            except Exception:
                return None

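    # note: job action sequence ids are packed into up to 15 'Sequence{i}'
    # string properties of at most 975 comma-joined uids each, keeping every
    # property under the table string property size cap
    # (_MAX_STR_ENTITY_PROPERTY_LENGTH); popping the head id repacks the
    # remainder across the properties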
    def pop_and_pack_sequence_ids_for_job(
            self,
            fedhash: str,
            job_id: str,
    ) -> Tuple[azure.cosmosdb.table.Entity, bool]:
        entity = self._get_sequence_entity_for_job(fedhash, job_id)
        seq = []
        for i in range(0, self._MAX_SEQUENCE_ID_PROPERTIES):
            prop = 'Sequence{}'.format(i)
            if prop in entity and is_not_empty(entity[prop]):
                seq.extend(entity[prop].split(','))
        seq.pop(0)
        for i in range(0, self._MAX_SEQUENCE_ID_PROPERTIES):
            prop = 'Sequence{}'.format(i)
            start = i * self._MAX_SEQUENCE_IDS_PER_PROPERTY
            end = start + self._MAX_SEQUENCE_IDS_PER_PROPERTY
            if end > len(seq):
                end = len(seq)
            if start < end:
                entity[prop] = ','.join(seq[start:end])
            else:
                entity[prop] = None
        return entity, len(seq) == 0

    def dequeue_sequence_id_from_federation_sequence(
            self,
            delete_message: bool,
            fedhash: str,
            msg_id: str,
            pop_receipt: str,
            target: str,
    ) -> None:
        # pop first item off table sequence
        if is_not_empty(target):
            while True:
                entity, empty_seq = self.pop_and_pack_sequence_ids_for_job(
                    fedhash, target)
                # see if there are no job location entities
                if (empty_seq and not self.location_entities_exist_for_job(
                        fedhash, target)):
                    # delete entity
                    self.delete_action_entity_for_job(entity)
                    logger.debug(
                        'deleted target {} action entity from '
                        'federation {}'.format(target, fedhash))
                    break
                else:
                    # merge update
                    if self.insert_or_update_entity_with_etag_for_job(
                            entity):
                        logger.debug(
                            'upserted target {} sequence to '
                            'federation {}'.format(target, fedhash))
                        break
                    else:
                        logger.debug(
                            'conflict upserting target {} sequence to '
                            'federation {}'.format(target, fedhash))
        # dequeue message
        if delete_message:
            queue_name = '{}-{}'.format(
                self.service_proxy.queue_prefix, fedhash)
            self.service_proxy.queue_client.delete_message(
                queue_name, msg_id, pop_receipt)

    def add_blocked_action_for_job(
            self,
            fedhash: str,
            target: str,
            unique_id: str,
            num_tasks: int,
            reason: str,
    ) -> None:
        entity = {
            'PartitionKey': '{}${}'.format(
                self._BLOCKED_FEDERATION_ACTIONS_PREFIX_PK, fedhash),
            'RowKey': hash_string(target),
            'UniqueId': unique_id,
            'Id': target,
            'NumTasks': num_tasks,
            'Reason': reason,
        }
        self.service_proxy.table_client.insert_or_replace_entity(
            self.service_proxy.table_name_jobs, entity)

    def remove_blocked_action_for_job(
            self,
            fedhash: str,
            target: str
    ) -> None:
        pk = '{}${}'.format(
            self._BLOCKED_FEDERATION_ACTIONS_PREFIX_PK, fedhash)
        rk = hash_string(target)
        try:
            self.service_proxy.table_client.delete_entity(
                self.service_proxy.table_name_jobs, pk, rk)
        except azure.common.AzureMissingResourceHttpError:
            pass

    def _create_blob_client(self, sa, ep, sas):
        return azure.storage.blob.BlockBlobService(
            account_name=sa,
            sas_token=sas,
            endpoint_suffix=ep
        )

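    # note: produces a url of the form (values illustrative):
    # https://mysa.blob.core.windows.net/shipyardfed-<fedhash>/messages/
    # <unique_id>.pickle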
    def construct_blob_url(
            self,
            fedhash: str,
            unique_id: str
    ) -> str:
        return (
            'https://{sa}.blob.{ep}/{prefix}-{fedhash}/messages/{uid}.pickle'
        ).format(
            sa=self.service_proxy.creds.storage_account,
            ep=self.service_proxy.creds.storage_account_ep,
            prefix=self.service_proxy.blob_container_data_prefix,
            fedhash=fedhash,
            uid=unique_id
        )

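    # note: e.g. (illustrative)
    # 'https://mysa.blob.core.windows.net/cont/messages/uid.pickle?<sas>'
    # is exploded into sa='mysa', ep='core.windows.net', container='cont',
    # blob_name='messages/uid.pickle' and an optional sas token; a
    # sas-bearing url gets its own ad-hoc blob client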
    def retrieve_blob_data(
            self,
            url: str
    ) -> Tuple[azure.storage.blob.BlockBlobService, str, str, bytes]:
        """Retrieve a blob URL
        :param url: Azure Storage url to retrieve
        :return: blob client, container, blob name, data
        """
        # explode url into parts
        tmp = url.split('/')
        host = tmp[2].split('.')
        sa = host[0]
        ep = '.'.join(host[2:])
        del host
        tmp = '/'.join(tmp[3:]).split('?')
        if len(tmp) > 1:
            sas = tmp[1]
        else:
            sas = None
        tmp = tmp[0].split('/')
        container = tmp[0]
        blob_name = '/'.join(tmp[1:])
        del tmp
        if sas is not None:
            blob_client = self._create_blob_client(sa, ep, sas)
        else:
            blob_client = self.service_proxy.blob_client
        data = blob_client.get_blob_to_bytes(container, blob_name)
        return blob_client, container, blob_name, data.content

    def delete_blob(
            self,
            blob_client: azure.storage.blob.BlockBlobService,
            container: str,
            blob_name: str
    ) -> None:
        blob_client.delete_blob(container, blob_name)


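# note: FederationPool caches point-in-time snapshots of the CloudPool, node
# state counts and active task counts; the *_requires_update properties gate
# refreshes by the _MAX_TIMESPAN_* TTLs defined above, and is_valid reports
# False while a post-scheduling blackout window is active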
class FederationPool():
    def __init__(
            self,
            batch_account: str,
            service_url: str,
            location: str,
            pool_id: str,
            cloud_pool: batchmodels.CloudPool,
            vm_size: 'azure.mgmt.compute.models.VirtualMachineSize'
    ) -> None:
        self._vm_size = None  # type: str
        self._native = None  # type: bool
        self._cloud_pool = None  # type: batchmodels.CloudPool
        self._pool_last_update = None  # type: datetime.datetime
        self._node_counts = None  # type: batchmodels.PoolNodeCounts
        self._node_counts_last_update = None  # type: datetime.datetime
        self._blackout_end_time = datetime_utcnow(as_string=False)
        self._active_tasks_count = None  # type: int
        self._active_tasks_count_last_update = None  # type: datetime.datetime
        self.batch_account = batch_account
        self.service_url = service_url
        self.location = location.lower()
        self.pool_id = pool_id
        self.cloud_pool = cloud_pool
        self.vm_props = vm_size
        if self.is_valid:
            self._vm_size = self.cloud_pool.vm_size.lower()

    @property
    def cloud_pool(self) -> batchmodels.CloudPool:
        return self._cloud_pool

    @cloud_pool.setter
    def cloud_pool(self, value: batchmodels.CloudPool) -> None:
        self._cloud_pool = value
        if (self._cloud_pool is not None and
                is_not_empty(self._cloud_pool.metadata)):
            for md in self._cloud_pool.metadata:
                if md.name == _POOL_NATIVE_METADATA_NAME:
                    self.native = md.value == '1'
        self._pool_last_update = datetime_utcnow(as_string=False)

    @property
    def native(self) -> bool:
        return self._native

    @native.setter
    def native(self, value: bool) -> None:
        self._native = value

    @property
    def node_counts(self) -> batchmodels.PoolNodeCounts:
        return self._node_counts

    @node_counts.setter
    def node_counts(self, value: batchmodels.PoolNodeCounts) -> None:
        self._node_counts = value
        self._node_counts_last_update = datetime_utcnow(as_string=False)

    @property
    def active_tasks_count(self) -> int:
        return self._active_tasks_count

    @active_tasks_count.setter
    def active_tasks_count(self, value: int) -> None:
        self._active_tasks_count = value
        self._active_tasks_count_last_update = datetime_utcnow(
            as_string=False)

    @property
    def is_valid(self) -> bool:
        if (self.cloud_pool is not None and self.vm_props is not None and
                datetime_utcnow(as_string=False) > self._blackout_end_time):
            return self.cloud_pool.state == batchmodels.PoolState.active
        return False

    @property
    def pool_requires_update(self) -> bool:
        return (
            not self.is_valid or self._pool_last_update is None or
            (datetime_utcnow() - self._pool_last_update) >
            _MAX_TIMESPAN_POOL_UPDATE
        )

    @property
    def node_counts_requires_update(self) -> bool:
        if not self.is_valid:
            return False
        return (
            self._node_counts_last_update is None or
            (datetime_utcnow() - self._node_counts_last_update) >
            _MAX_TIMESPAN_NODE_COUNTS_UPDATE
        )

    @property
    def active_tasks_count_requires_update(self) -> bool:
        if not self.is_valid:
            return False
        return (
            self._active_tasks_count_last_update is None or
            (datetime_utcnow() - self._active_tasks_count_last_update) >
            _MAX_TIMESPAN_ACTIVE_TASKS_COUNT_UPDATE
        )

    @property
    def schedulable_low_priority_nodes(self) -> Optional[int]:
        if not self.is_valid or self.node_counts is None:
            return None
        return (self.node_counts.low_priority.idle +
                self.node_counts.low_priority.running)

    @property
    def schedulable_dedicated_nodes(self) -> Optional[int]:
        if not self.is_valid or self.node_counts is None:
            return None
        return (self.node_counts.dedicated.idle +
                self.node_counts.dedicated.running)

    @property
    def vm_size(self) -> Optional[str]:
        return self._vm_size

    def has_registry_login(self, registry: str) -> bool:
        if not self.is_valid:
            return None
        if self.native:
            cc = self._cloud_pool.virtual_machine_configuration.\
                container_configuration
            if cc.container_registries is None:
                return None
            for cr in cc.container_registries:
                if is_none_or_empty(cr.registry_server):
                    cmpr = 'dockerhub-{}'.format(cr.user_name)
                else:
                    cmpr = '{}-{}'.format(cr.registry_server, cr.user_name)
                if cmpr == registry:
                    return True
        else:
            if self._cloud_pool.start_task is None:
                return None
            creds = {}
            for ev in self._cloud_pool.start_task.environment_settings:
                if (ev.name.startswith('DOCKER_LOGIN_') and
                        ev.name != 'DOCKER_LOGIN_PASSWORD'):
                    creds[ev.name] = ev.value.split(',')
                if len(creds) == 2:
                    break
            logins = set()
            if len(creds) > 0:
                for i in range(0, len(creds['DOCKER_LOGIN_USERNAME'])):
                    srv = creds['DOCKER_LOGIN_SERVER'][i]
                    if is_none_or_empty(srv):
                        srv = 'dockerhub'
                    logins.add('{}-{}'.format(
                        srv, creds['DOCKER_LOGIN_USERNAME'][i]))
                if registry in logins:
                    return True
        return False

    def on_new_tasks_scheduled(
            self,
            bsh: BatchServiceHandler,
            blackout: int,
            evaluate_as: bool
    ) -> None:
        # invalidate count caches
        self._node_counts_last_update = None
        self._active_tasks_count_last_update = None
        # set scheduling blackout time
        if blackout > 0:
            self._blackout_end_time = datetime_utcnow(
                as_string=False) + datetime.timedelta(seconds=blackout)
            logger.debug(
                'blackout time for pool {} updated to {} (batch_account={} '
                'service_url={})'.format(
                    self.pool_id, self._blackout_end_time, self.batch_account,
                    self.service_url))
        # evaluate autoscale now
        if (evaluate_as and self.cloud_pool is not None and
                self.cloud_pool.enable_auto_scale):
            bsh.immediately_evaluate_autoscale(
                self.batch_account, self.service_url, self.pool_id)


class Federation():
|
|
def __init__(self, fedhash: str, fedid: str) -> None:
|
|
self.lock = threading.Lock()
|
|
self.hash = fedhash
|
|
self.id = fedid
|
|
self.pools = {} # type: Dict[str, FederationPool]
|
|
|
|
def update_pool(
|
|
self,
|
|
csh: ComputeServiceHandler,
|
|
bsh: BatchServiceHandler,
|
|
entity: azure.cosmosdb.table.Entity,
|
|
poolset: set,
|
|
) -> str:
|
|
rk = entity['RowKey']
|
|
exists = False
|
|
with self.lock:
|
|
if rk in self.pools:
|
|
exists = True
|
|
if self.pools[rk].is_valid:
|
|
poolset.add(rk)
|
|
return rk
|
|
batch_account = entity['BatchAccount']
|
|
poolid = entity['PoolId']
|
|
service_url = entity['BatchServiceUrl']
|
|
pool = bsh.get_pool_full_update(
|
|
batch_account, service_url, poolid)
|
|
if exists and pool is not None:
|
|
with self.lock:
|
|
self.pools[rk].cloud_pool = pool
|
|
poolset.add(rk)
|
|
return rk
|
|
location = entity['Location']
|
|
csh.populate_vm_sizes_from_location(location)
|
|
vm_size = None
|
|
if pool is not None:
|
|
vm_size = csh.get_vm_size(pool.vm_size)
|
|
fedpool = FederationPool(
|
|
batch_account, service_url, location, poolid, pool, vm_size
|
|
)
|
|
with self.lock:
|
|
poolset.add(rk)
|
|
self.pools[rk] = fedpool
|
|
if self.pools[rk].is_valid:
|
|
logger.info(
|
|
'valid pool {} id={} to federation {} id={} for '
|
|
'account {} at location {} size={} ppn={} mem={}'.format(
|
|
rk, poolid, self.hash, self.id, batch_account,
|
|
location, fedpool.vm_size,
|
|
fedpool.vm_props.number_of_cores,
|
|
fedpool.vm_props.memory_in_mb))
|
|
elif not exists:
|
|
logger.warning(
|
|
'invalid pool {} id={} to federation {} '
|
|
'id={} (batch_account={} service_url={})'.format(
|
|
rk, poolid, self.hash, self.id, fedpool.batch_account,
|
|
fedpool.service_url))
|
|
return rk
|
|
|
|
def trim_orphaned_pools(self, fedpools: set) -> None:
|
|
with self.lock:
|
|
# do not get symmetric difference
|
|
diff = [x for x in self.pools.keys() if x not in fedpools]
|
|
removed = False
|
|
for rk in diff:
|
|
logger.debug(
|
|
'removing pool {} id={} from federation {} id={}'.format(
|
|
rk, self.pools[rk].pool_id, self.hash, self.id))
|
|
self.pools.pop(rk)
|
|
removed = True
|
|
if removed:
|
|
pool_ids = [self.pools[x].pool_id for x in self.pools]
|
|
logger.info('active pools in federation {} id={}: {}'.format(
|
|
self.hash, self.id, ' '.join(pool_ids)))
|
|
|
|
def check_pool_in_federation(
|
|
self,
|
|
fdh: FederationDataHandler,
|
|
poolhash: str
|
|
) -> bool:
|
|
entity = fdh.get_pool_for_federation(self.hash, poolhash)
|
|
return entity is not None
|
|
|
|
def _log_constraint_failure(
|
|
self,
|
|
unique_id: str,
|
|
pool_id: str,
|
|
constraint_name: str,
|
|
required_value: Any,
|
|
actual_value: Any,
|
|
) -> None:
|
|
logger.debug(
|
|
'constraint failure for uid {} on pool {} for fed id {} '
|
|
'fed hash {}: {} requires {} actual {}'.format(
|
|
unique_id, pool_id, self.id, self.hash, constraint_name,
|
|
required_value, actual_value)
|
|
)
|
|
|
|
def _filter_pool_with_hard_constraints(
|
|
self,
|
|
pool: FederationPool,
|
|
constraints: Constraints,
|
|
unique_id: str,
|
|
) -> bool:
|
|
# constraint order matching
|
|
# 0. pool validity
|
|
# 1. location
|
|
# 2. virtual network arm id
|
|
# 3. custom image arm id
|
|
# 4. windows (implies native)
|
|
# 5. native
|
|
# 6. autoscale disallow
|
|
# 7. autoscale exclusive
|
|
# 8. low priority disallow
|
|
# 9. low priority exclusive
|
|
# 10. exclusive
|
|
# 11. vm_size
|
|
# 12. gpu
|
|
# 13. infiniband
|
|
# 14. cores
|
|
# 15. memory
|
|
# 16. multi instance -> inter node
|
|
# 17. registries
|
|
|
|
cp = pool.cloud_pool
|
|
# pool validity (this function shouldn't be called with invalid
|
|
# pools, but check anyways)
|
|
if not pool.is_valid:
|
|
logger.debug(
|
|
'pool {} is not valid for filtering of uid {} for fed id {} '
|
|
'fed hash {}'.format(cp.id, unique_id, self.id, self.hash))
|
|
return True
|
|
# location
|
|
if (is_not_empty(constraints.pool.location) and
|
|
constraints.pool.location != pool.location):
|
|
self._log_constraint_failure(
|
|
unique_id, cp.id, 'location', constraints.pool.location,
|
|
pool.location)
|
|
return True
|
|
# virtual network
|
|
if (is_not_empty(constraints.pool.virtual_network_arm_id) and
|
|
(cp.network_configuration is None or
|
|
constraints.pool.virtual_network_arm_id !=
|
|
cp.network_configuration.subnet_id.lower())):
|
|
self._log_constraint_failure(
|
|
unique_id, cp.id, 'virtual_network_arm_id',
|
|
constraints.pool.virtual_network_arm_id,
|
|
cp.network_configuration.subnet_id
|
|
if cp.network_configuration is not None else 'none')
|
|
return True
|
|
        # custom image
        if (is_not_empty(constraints.pool.custom_image_arm_id) and
                (cp.virtual_machine_configuration is None or
                 constraints.pool.custom_image_arm_id !=
                 cp.virtual_machine_configuration.image_reference.
                 virtual_machine_image_id.lower())):
            self._log_constraint_failure(
                unique_id, cp.id, 'custom_image_arm_id',
                constraints.pool.custom_image_arm_id,
                cp.virtual_machine_configuration.image_reference.
                virtual_machine_image_id
                if cp.virtual_machine_configuration is not None else 'none')
            return True
        # windows
        if (constraints.pool.windows and
                (cp.virtual_machine_configuration is None or
                 not cp.virtual_machine_configuration.
                 node_agent_sku_id.lower().startswith('batch.node.windows'))):
            self._log_constraint_failure(
                unique_id, cp.id, 'windows',
                constraints.pool.windows,
                cp.virtual_machine_configuration.node_agent_sku_id
                if cp.virtual_machine_configuration is not None else 'none')
            return True
        # native
        if (constraints.pool.native is not None and
                constraints.pool.native != pool.native):
            self._log_constraint_failure(
                unique_id, cp.id, 'native',
                constraints.pool.native, pool.native)
            return True
        # autoscale disallow
        if (constraints.pool.autoscale_allow is not None and
                not constraints.pool.autoscale_allow and
                cp.enable_auto_scale):
            self._log_constraint_failure(
                unique_id, cp.id, 'autoscale_allow',
                constraints.pool.autoscale_allow,
                cp.enable_auto_scale)
            return True
        # autoscale exclusive
        if (constraints.pool.autoscale_exclusive and
                not cp.enable_auto_scale):
            self._log_constraint_failure(
                unique_id, cp.id, 'autoscale_exclusive',
                constraints.pool.autoscale_exclusive,
                cp.enable_auto_scale)
            return True
        # low priority disallow
        if (constraints.pool.low_priority_nodes_allow is not None and
                not constraints.pool.low_priority_nodes_allow and
                cp.target_low_priority_nodes > 0):
            self._log_constraint_failure(
                unique_id, cp.id, 'low_priority_nodes_allow',
                constraints.pool.low_priority_nodes_allow,
                cp.target_low_priority_nodes)
            return True
        # low priority exclusive
        if (constraints.pool.low_priority_nodes_exclusive and
                cp.target_low_priority_nodes == 0 and
                not cp.enable_auto_scale):
            self._log_constraint_failure(
                unique_id, cp.id, 'low_priority_nodes_exclusive',
                constraints.pool.low_priority_nodes_exclusive,
                cp.target_low_priority_nodes)
            return True
        # exclusive
        if constraints.compute_node.exclusive and cp.max_tasks_per_node > 1:
            self._log_constraint_failure(
                unique_id, cp.id, 'exclusive',
                constraints.compute_node.exclusive,
                cp.max_tasks_per_node)
            return True
        # vm size
        if (is_not_empty(constraints.compute_node.vm_size) and
                constraints.compute_node.vm_size != pool.vm_size):
            self._log_constraint_failure(
                unique_id, cp.id, 'vm_size',
                constraints.compute_node.vm_size,
                pool.vm_size)
            return True
        # gpu
        if (constraints.compute_node.gpu is not None and
                constraints.compute_node.gpu != is_gpu_pool(pool.vm_size)):
            self._log_constraint_failure(
                unique_id, cp.id, 'gpu',
                constraints.compute_node.gpu, is_gpu_pool(pool.vm_size))
            return True
        # infiniband
        if (constraints.compute_node.infiniband is not None and
                constraints.compute_node.infiniband != is_rdma_pool(
                    pool.vm_size)):
            self._log_constraint_failure(
                unique_id, cp.id, 'infiniband',
                constraints.compute_node.infiniband,
                is_rdma_pool(pool.vm_size))
            return True
        # cores
        if (constraints.compute_node.cores is not None and
                pool.vm_props is not None):
            # absolute core filtering
            if constraints.compute_node.cores > pool.vm_props.number_of_cores:
                self._log_constraint_failure(
                    unique_id, cp.id, 'cores',
                    constraints.compute_node.cores,
                    pool.vm_props.number_of_cores)
                return True
            # core variance of zero must match the number of cores exactly
            if constraints.compute_node.core_variance == 0:
                if (constraints.compute_node.cores !=
                        pool.vm_props.number_of_cores):
                    self._log_constraint_failure(
                        unique_id, cp.id, 'zero core_variance',
                        constraints.compute_node.cores,
                        pool.vm_props.number_of_cores)
                    return True
            # core variance of None corresponds to no restrictions;
            # a positive core variance implies a maximum core cap
            if (constraints.compute_node.core_variance is not None and
                    constraints.compute_node.core_variance > 0):
                max_cc = constraints.compute_node.cores * (
                    1 + constraints.compute_node.core_variance)
                if pool.vm_props.number_of_cores > max_cc:
                    self._log_constraint_failure(
                        unique_id, cp.id, 'max core_variance',
                        max_cc,
                        pool.vm_props.number_of_cores)
                    return True
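        # Illustrative variance arithmetic (hypothetical values): with
        # cores=8 and core_variance=0.5, max_cc = 8 * (1 + 0.5) = 12, so a
        # 16-core SKU is filtered out while an 8- or 12-core SKU passes.
        # The memory variance check below follows the same pattern in MB.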
        # memory
        if (constraints.compute_node.memory is not None and
                pool.vm_props is not None):
            vm_mem = pool.vm_props.memory_in_mb
            # absolute memory filtering
            if constraints.compute_node.memory > vm_mem:
                self._log_constraint_failure(
                    unique_id, cp.id, 'memory',
                    constraints.compute_node.memory,
                    vm_mem)
                return True
            # memory variance of zero must match the memory amount exactly
            if constraints.compute_node.memory_variance == 0:
                if constraints.compute_node.memory != vm_mem:
                    self._log_constraint_failure(
                        unique_id, cp.id, 'zero memory_variance',
                        constraints.compute_node.memory,
                        vm_mem)
                    return True
            # memory variance of None corresponds to no restrictions;
            # a positive memory variance implies a maximum memory cap
            if (constraints.compute_node.memory_variance is not None and
                    constraints.compute_node.memory_variance > 0):
                max_mem = constraints.compute_node.memory * (
                    1 + constraints.compute_node.memory_variance)
                if vm_mem > max_mem:
                    self._log_constraint_failure(
                        unique_id, cp.id, 'max memory_variance',
                        max_mem,
                        vm_mem)
                    return True
        # multi-instance
        if (constraints.task.has_multi_instance and
                not cp.enable_inter_node_communication):
            self._log_constraint_failure(
                unique_id, cp.id, 'has_multi_instance',
                constraints.task.has_multi_instance,
                cp.enable_inter_node_communication)
            return True
        # registries
        if is_not_empty(constraints.pool.registries):
            for cr in constraints.pool.registries:
                if not pool.has_registry_login(cr):
                    self._log_constraint_failure(
                        unique_id, cp.id, 'registries',
                        cr if is_not_empty(cr) else 'dockerhub',
                        False)
                    return True
        # hard constraint filtering passed
        return False

    def _filter_pool_nodes_with_constraints(
            self,
            pool: FederationPool,
            constraints: Constraints,
            unique_id: str,
    ) -> bool:
        cp = pool.cloud_pool
        # check for dedicated only execution
        if (constraints.pool.low_priority_nodes_allow is not None and
                not constraints.pool.low_priority_nodes_allow):
            # fail if there are no schedulable dedicated nodes and
            # autoscale is either disallowed or not enabled on the pool
            if (pool.schedulable_dedicated_nodes == 0 and
                    (not (constraints.pool.autoscale_allow and
                          cp.enable_auto_scale))):
                self._log_constraint_failure(
                    unique_id, cp.id, 'low_priority_nodes_allow',
                    constraints.pool.low_priority_nodes_allow,
                    pool.schedulable_dedicated_nodes)
                return True
        # check for low priority only execution
        if constraints.pool.low_priority_nodes_exclusive:
            # fail if there are no schedulable low priority nodes and
            # autoscale is either disallowed or not enabled on the pool
            if (pool.schedulable_low_priority_nodes == 0 and
                    (not (constraints.pool.autoscale_allow and
                          cp.enable_auto_scale))):
                self._log_constraint_failure(
                    unique_id, cp.id, 'low_priority_nodes_exclusive',
                    constraints.pool.low_priority_nodes_exclusive,
                    pool.schedulable_low_priority_nodes)
                return True
        # max active task backlog ratio
        if constraints.pool.max_active_task_backlog_ratio is not None:
            schedulable_slots = (
                pool.schedulable_dedicated_nodes +
                pool.schedulable_low_priority_nodes
            ) * cp.max_tasks_per_node
            if schedulable_slots > 0:
                ratio = pool.active_tasks_count / schedulable_slots
            else:
                if (cp.enable_auto_scale and
                        cp.allocation_state ==
                        batchmodels.AllocationState.steady and
                        constraints.pool.
                        max_active_task_backlog_autoscale_exempt):
                    ratio = 0
                else:
                    ratio = None
            if (ratio is None or
                    ratio > constraints.pool.max_active_task_backlog_ratio):
                self._log_constraint_failure(
                    unique_id, cp.id, 'max_active_task_backlog_ratio',
                    constraints.pool.max_active_task_backlog_ratio,
                    ratio)
                return True
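        # Illustrative backlog arithmetic (hypothetical values): a pool with
        # 2 schedulable dedicated nodes, 2 schedulable low priority nodes,
        # and max_tasks_per_node=4 exposes 16 schedulable slots; with 24
        # active tasks the ratio is 24 / 16 = 1.5, so a
        # max_active_task_backlog_ratio of 1.0 filters this pool out.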
        # node constraint filtering passed
        return False

    def _pre_constraint_filter_pool_update(
            self,
            bsh: BatchServiceHandler,
            fdh: FederationDataHandler,
            rk: str,
            active_tasks_count_update: bool,
    ) -> bool:
        pool = self.pools[rk]
        # ensure pool is in federation (pools can be removed between
        # federation updates)
        if self.check_pool_in_federation(fdh, rk):
            # refresh pool
            if pool.pool_requires_update:
                pool.cloud_pool = bsh.get_pool_full_update(
                    pool.batch_account, pool.service_url, pool.pool_id)
            # refresh node state counts
            if pool.node_counts_requires_update:
                pool.node_counts = bsh.get_node_state_counts(
                    pool.batch_account, pool.service_url, pool.pool_id)
            # refresh active task counts
            if (active_tasks_count_update and
                    pool.active_tasks_count_requires_update):
                pool.active_tasks_count = \
                    bsh.aggregate_active_task_count_on_pool(
                        pool.batch_account, pool.service_url, pool.pool_id)
        else:
            logger.warning(
                'pool id {} hash={} not in fed id {} fed hash {}'.format(
                    pool.pool_id, rk, self.id, self.hash))
            return False
        return True

    def _select_pool_for_target_required(
            self,
            unique_id: str,
            using_slots: bool,
            target_required: int,
            allow_autoscale: bool,
            num_pools: Dict[str, int],
            binned: Dict[str, List[str]],
            pool_map: Dict[str, Dict[str, int]],
    ) -> Optional[str]:
        logger.debug(
            'pool selection attempt for uid={} using_slots={} '
            'target_required={} allow_autoscale={} num_pools={} '
            'binned={}'.format(
                unique_id, using_slots, target_required, allow_autoscale,
                num_pools, binned))
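        # The fallback chain below, in order: exact fit against idle
        # capacity first, then exact fit against schedulable ('avail')
        # capacity, then any steady autoscale-enabled pool (if autoscale is
        # allowed), and finally, for slot-based scheduling only, any pool
        # with at least one open slot (accepting task backlog).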
        # try to match against largest idle pool with sufficient capacity
        if num_pools['idle'] > 0:
            for rk in binned['idle']:
                if pool_map['idle'][rk] >= target_required:
                    return rk
        # try to match against largest avail pool with sufficient capacity
        if num_pools['avail'] > 0:
            for rk in binned['avail']:
                if pool_map['avail'][rk] >= target_required:
                    return rk
        # try to match against any autoscale-enabled pool that is steady
        if allow_autoscale:
            for rk in binned['idle']:
                pool = self.pools[rk]
                if (pool.cloud_pool.enable_auto_scale and
                        pool.cloud_pool.allocation_state ==
                        batchmodels.AllocationState.steady):
                    return rk
            for rk in binned['avail']:
                pool = self.pools[rk]
                if (pool.cloud_pool.enable_auto_scale and
                        pool.cloud_pool.allocation_state ==
                        batchmodels.AllocationState.steady):
                    return rk
        # if using slot scheduling, then attempt to schedule with backlog
        if using_slots:
            # try to match against largest idle pool
            if num_pools['idle'] > 0:
                for rk in binned['idle']:
                    if pool_map['idle'][rk] >= 1:
                        return rk
            # try to match against largest avail pool
            if num_pools['avail'] > 0:
                for rk in binned['avail']:
                    if pool_map['avail'][rk] >= 1:
                        return rk
        return None

    def _greedy_best_fit_match_for_job(
            self,
            num_tasks: int,
            constraints: Constraints,
            unique_id: str,
            dedicated_vms: Dict[str, Dict[str, int]],
            dedicated_slots: Dict[str, Dict[str, int]],
            low_priority_vms: Dict[str, Dict[str, int]],
            low_priority_slots: Dict[str, Dict[str, int]],
    ) -> Optional[str]:
        # calculate the number of candidate pools in each category
        num_pools = {
            'vms': {
                'dedicated': {
                    'idle': len(dedicated_vms['idle']),
                    'avail': len(dedicated_vms['avail']),
                },
                'low_priority': {
                    'idle': len(low_priority_vms['idle']),
                    'avail': len(low_priority_vms['avail']),
                }
            },
            'slots': {
                'dedicated': {
                    'idle': len(dedicated_slots['idle']),
                    'avail': len(dedicated_slots['avail']),
                },
                'low_priority': {
                    'idle': len(low_priority_slots['idle']),
                    'avail': len(low_priority_slots['avail']),
                }
            }
        }
        # bin all maps: sort pool keys by capacity, largest first
        binned = {
            'vms': {
                'dedicated': {
                    'idle': sorted(
                        dedicated_vms['idle'],
                        key=dedicated_vms['idle'].get,
                        reverse=True),
                    'avail': sorted(
                        dedicated_vms['avail'],
                        key=dedicated_vms['avail'].get,
                        reverse=True),
                },
                'low_priority': {
                    'idle': sorted(
                        low_priority_vms['idle'],
                        key=low_priority_vms['idle'].get,
                        reverse=True),
                    'avail': sorted(
                        low_priority_vms['avail'],
                        key=low_priority_vms['avail'].get,
                        reverse=True),
                }
            },
            'slots': {
                'dedicated': {
                    'idle': sorted(
                        dedicated_slots['idle'],
                        key=dedicated_slots['idle'].get,
                        reverse=True),
                    'avail': sorted(
                        dedicated_slots['avail'],
                        key=dedicated_slots['avail'].get,
                        reverse=True),
                },
                'low_priority': {
                    'idle': sorted(
                        low_priority_slots['idle'],
                        key=low_priority_slots['idle'].get,
                        reverse=True),
                    'avail': sorted(
                        low_priority_slots['avail'],
                        key=low_priority_slots['avail'].get,
                        reverse=True),
                }
            }
        }
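        # Illustrative binning (hypothetical values): with
        # dedicated_slots['idle'] = {'p1': 8, 'p2': 32, 'p3': 16}, the
        # binned order is ['p2', 'p3', 'p1'], so selection walks candidate
        # pools from largest to smallest capacity.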
        # scheduling is done by slots (regular tasks) or vms (multi-instance)
        if constraints.task.has_multi_instance:
            total_slots_required = None
            vms_required_per_task = constraints.task.instance_counts_max
        else:
            total_slots_required = constraints.task.instance_counts_total
            vms_required_per_task = None
        # greedy best-fit (by vms or slots) matching
        selected = None
        # constraint: dedicated only pools
        if (constraints.pool.low_priority_nodes_allow is not None and
                not constraints.pool.low_priority_nodes_allow):
            if total_slots_required is not None:
                selected = self._select_pool_for_target_required(
                    unique_id, True, total_slots_required,
                    constraints.pool.autoscale_allow,
                    num_pools['slots']['dedicated'],
                    binned['slots']['dedicated'], dedicated_slots)
            else:
                selected = self._select_pool_for_target_required(
                    unique_id, False, vms_required_per_task,
                    constraints.pool.autoscale_allow,
                    num_pools['vms']['dedicated'],
                    binned['vms']['dedicated'], dedicated_vms)
        elif constraints.pool.low_priority_nodes_exclusive:
            # constraint: low priority only pools
            if total_slots_required is not None:
                selected = self._select_pool_for_target_required(
                    unique_id, True, total_slots_required,
                    constraints.pool.autoscale_allow,
                    num_pools['slots']['low_priority'],
                    binned['slots']['low_priority'], low_priority_slots)
            else:
                selected = self._select_pool_for_target_required(
                    unique_id, False, vms_required_per_task,
                    constraints.pool.autoscale_allow,
                    num_pools['vms']['low_priority'],
                    binned['vms']['low_priority'], low_priority_vms)
        else:
            # no constraints, try scheduling on dedicated first, then low pri
            if total_slots_required is not None:
                selected = self._select_pool_for_target_required(
                    unique_id, True, total_slots_required,
                    constraints.pool.autoscale_allow,
                    num_pools['slots']['dedicated'],
                    binned['slots']['dedicated'], dedicated_slots)
                if selected is None:
                    selected = self._select_pool_for_target_required(
                        unique_id, True, total_slots_required,
                        constraints.pool.autoscale_allow,
                        num_pools['slots']['low_priority'],
                        binned['slots']['low_priority'], low_priority_slots)
            else:
                selected = self._select_pool_for_target_required(
                    unique_id, False, vms_required_per_task,
                    constraints.pool.autoscale_allow,
                    num_pools['vms']['dedicated'],
                    binned['vms']['dedicated'], dedicated_vms)
                if selected is None:
                    selected = self._select_pool_for_target_required(
                        unique_id, False, vms_required_per_task,
                        constraints.pool.autoscale_allow,
                        num_pools['vms']['low_priority'],
                        binned['vms']['low_priority'], low_priority_vms)
        return selected

    def find_target_pool_for_job(
            self,
            bsh: BatchServiceHandler,
            fdh: FederationDataHandler,
            num_tasks: int,
            constraints: Constraints,
            blacklist: Set[str],
            unique_id: str,
            target: str,
    ) -> Optional[str]:
        """
        This function should be called with lock already held!
        """
        dedicated_vms = {
            'idle': {},
            'avail': {},
        }
        dedicated_slots = {
            'idle': {},
            'avail': {},
        }
        low_priority_vms = {
            'idle': {},
            'avail': {},
        }
        low_priority_slots = {
            'idle': {},
            'avail': {},
        }
        # check and update pools in parallel
        update_futures = {}
        if len(self.pools) > 0:
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers_for_executor(
                        self.pools)) as executor:
                for rk in self.pools:
                    if rk in blacklist:
                        continue
                    update_futures[rk] = executor.submit(
                        self._pre_constraint_filter_pool_update, bsh, fdh, rk,
                        constraints.pool.max_active_task_backlog_ratio
                        is not None
                    )
        # perform constraint filtering
        # TODO optimization -> fast match against last schedule?
        for rk in self.pools:
            pool = self.pools[rk]
            if rk in blacklist:
                continue
            # check if update was successful for pool
            if not update_futures[rk].result():
                continue
            # ensure pool is valid and node counts exist
            if not pool.is_valid or pool.node_counts is None:
                logger.warning(
                    'skipping invalid pool id {} hash={} node counts '
                    'valid={} in fed id {} fed hash {} uid={} '
                    'target={}'.format(
                        pool.pool_id, rk, pool.node_counts is not None,
                        self.id, self.hash, unique_id, target))
                continue
            # hard constraint filtering
            if self._filter_pool_with_hard_constraints(
                    pool, constraints, unique_id):
                blacklist.add(rk)
                continue
            # further constraint matching for nodes
            if self._filter_pool_nodes_with_constraints(
                    pool, constraints, unique_id):
                continue
            # add counts for pre-sort
            if pool.node_counts.dedicated.idle > 0:
                dedicated_vms['idle'][rk] = pool.node_counts.dedicated.idle
                dedicated_slots['idle'][rk] = (
                    pool.node_counts.dedicated.idle *
                    pool.cloud_pool.max_tasks_per_node
                )
            if pool.node_counts.low_priority.idle > 0:
                low_priority_vms['idle'][rk] = (
                    pool.node_counts.low_priority.idle
                )
                low_priority_slots['idle'][rk] = (
                    pool.node_counts.low_priority.idle *
                    pool.cloud_pool.max_tasks_per_node
                )
            # for availability counts, allow pools to be added to the map
            # even with zero nodes if they can autoscale
            if (pool.schedulable_dedicated_nodes > 0 or
                    pool.cloud_pool.enable_auto_scale):
                dedicated_vms['avail'][rk] = pool.schedulable_dedicated_nodes
                dedicated_slots['avail'][rk] = (
                    pool.schedulable_dedicated_nodes *
                    pool.cloud_pool.max_tasks_per_node
                )
            if (pool.schedulable_low_priority_nodes > 0 or
                    pool.cloud_pool.enable_auto_scale):
                low_priority_vms['avail'][rk] = (
                    pool.schedulable_low_priority_nodes
                )
                low_priority_slots['avail'][rk] = (
                    pool.schedulable_low_priority_nodes *
                    pool.cloud_pool.max_tasks_per_node
                )
        del update_futures
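        # Illustrative slot accounting (hypothetical values): a pool with 3
        # idle dedicated nodes and max_tasks_per_node=2 contributes 3 to
        # dedicated_vms['idle'] and 3 * 2 = 6 to dedicated_slots['idle'].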
        # check for non-availability
        if (len(dedicated_vms['avail']) == 0 and
                len(low_priority_vms['avail']) == 0 and
                not constraints.pool.autoscale_allow):
            logger.error(
                'no available nodes to schedule uid {} target={} in fed {} '
                'fed hash {}'.format(unique_id, target, self.id, self.hash))
            if len(blacklist) == len(self.pools):
                fdh.add_blocked_action_for_job(
                    self.hash, target, unique_id, num_tasks,
                    'Constraint filtering: all pools blacklisted')
            else:
                fdh.add_blocked_action_for_job(
                    self.hash, target, unique_id, num_tasks,
                    'Constraint filtering: no available pools')
            return None
        # perform greedy matching
        schedule = self._greedy_best_fit_match_for_job(
            num_tasks, constraints, unique_id, dedicated_vms, dedicated_slots,
            low_priority_vms, low_priority_slots)
        if schedule is None:
            logger.warning(
                'could not match uid {} target={} in fed {} fed hash {} to '
                'any pool'.format(unique_id, target, self.id, self.hash))
            fdh.add_blocked_action_for_job(
                self.hash, target, unique_id, num_tasks,
                'Pool matching: no available pools or nodes')
        else:
            logger.info(
                'selected pool id {} hash {} for uid {} target={} in fed {} '
                'fed hash {}'.format(
                    self.pools[schedule].pool_id, schedule, unique_id,
                    target, self.id, self.hash))
        return schedule

    async def create_job_schedule(
            self,
            bsh: BatchServiceHandler,
            target_pool: str,
            jobschedule: batchmodels.JobScheduleAddParameter,
            constraints: Constraints,
    ) -> bool:
        """
        This function should be called with lock already held!
        """
        # get pool ref
        pool = self.pools[target_pool]
        # overwrite pool id in job schedule
        jobschedule.job_specification.pool_info.pool_id = pool.pool_id
        # add job schedule
        try:
            logger.info(
                'adding job schedule {} to pool {} (batch_account={} '
                'service_url={})'.format(
                    jobschedule.id, pool.pool_id, pool.batch_account,
                    pool.service_url))
            bsh.add_job_schedule(
                pool.batch_account, pool.service_url, jobschedule)
            success = True
        except batchmodels.BatchErrorException as exc:
            if 'marked for deletion' in exc.message.value:
                logger.error(
                    'cannot reuse job schedule {} being deleted on '
                    'pool {}'.format(jobschedule.id, pool.pool_id))
            elif 'already exists' in exc.message.value:
                logger.error(
                    'cannot reuse existing job schedule {} on '
                    'pool {}'.format(jobschedule.id, pool.pool_id))
            else:
                logger.exception(str(exc))
                await bsh.delete_or_terminate_job(
                    pool.batch_account, pool.service_url, jobschedule.id,
                    True, True, wait=True)
            success = False
        return success

    async def create_job(
            self,
            bsh: BatchServiceHandler,
            target_pool: str,
            job: batchmodels.JobAddParameter,
            constraints: Constraints,
    ) -> bool:
        """
        This function should be called with lock already held!
        """
        # get pool ref
        pool = self.pools[target_pool]
        # overwrite pool id in job
        job.pool_info.pool_id = pool.pool_id
        # fix up job preparation task environment variables
        if (job.job_preparation_task is not None and
                job.job_preparation_task.environment_settings is not None):
            replace_ev = []
            for ev in job.job_preparation_task.environment_settings:
                if ev.name == 'SINGULARITY_CACHEDIR':
                    replace_ev.append(batchmodels.EnvironmentSetting(
                        ev.name,
                        '{}/singularity/cache'.format(
                            get_temp_disk_for_node_agent(
                                pool.cloud_pool.
                                virtual_machine_configuration.
                                node_agent_sku_id.lower()))
                    ))
                else:
                    replace_ev.append(ev)
            job.job_preparation_task.environment_settings = replace_ev
        # add job
        success = False
        del_job = True
        try:
            logger.info(
                'adding job {} to pool {} (batch_account={} '
                'service_url={})'.format(
                    job.id, pool.pool_id, pool.batch_account,
                    pool.service_url))
            bsh.add_job(pool.batch_account, pool.service_url, job)
            success = True
            del_job = False
        except batchmodels.BatchErrorException as exc:
            if 'marked for deletion' in exc.message.value:
                del_job = False
                logger.error(
                    'cannot reuse job {} being deleted on pool {}'.format(
                        job.id, pool.pool_id))
            elif 'already in a completed state' in exc.message.value:
                del_job = False
                logger.error(
                    'cannot reuse completed job {} on pool {}'.format(
                        job.id, pool.pool_id))
            elif 'job already exists' in exc.message.value:
                del_job = False
                success = True
                # cannot re-use an existing job if multi-instance due to
                # job release requirement
                if (constraints.task.has_multi_instance and
                        constraints.task.auto_complete):
                    logger.error(
                        'cannot reuse job {} on pool {} with multi_instance '
                        'and auto_complete'.format(job.id, pool.pool_id))
                    success = False
                else:
                    # retrieve job and check for constraints
                    ej = bsh.get_job(
                        pool.batch_account, pool.service_url, job.id)
                    # ensure the job's pool info matches
                    if ej.pool_info.pool_id != pool.pool_id:
                        logger.error(
                            'existing job {} on pool {} is already assigned '
                            'to a different pool {}'.format(
                                job.id, pool.pool_id, ej.pool_info.pool_id))
                        success = False
                    else:
                        # ensure job prep command line is the same (this will
                        # prevent jobs with mismatched data ingress)
                        ejp = None
                        njp = None
                        if ej.job_preparation_task is not None:
                            ejp = ej.job_preparation_task.command_line
                            if job.job_preparation_task is not None:
                                njp = job.job_preparation_task.command_line
                                if ejp != njp:
                                    success = False
                            else:
                                success = False
                        else:
                            if job.job_preparation_task is not None:
                                njp = job.job_preparation_task.command_line
                                success = False
                        if not success:
                            logger.error(
                                'existing job {} on pool {} has an '
                                'incompatible job prep task: existing={} '
                                'desired={}'.format(
                                    job.id, pool.pool_id, ejp, njp))
                        elif (job.uses_task_dependencies and
                                not ej.uses_task_dependencies):
                            # check for task dependencies
                            logger.error(
                                ('existing job {} on pool {} has an '
                                 'incompatible task dependency setting: '
                                 'existing={} desired={}').format(
                                     job.id, pool.pool_id,
                                     ej.uses_task_dependencies,
                                     job.uses_task_dependencies))
                            success = False
                        elif ej.on_task_failure != job.on_task_failure:
                            # check for job actions
                            logger.error(
                                ('existing job {} on pool {} has an '
                                 'incompatible on_task_failure setting: '
                                 'existing={} desired={}').format(
                                     job.id, pool.pool_id,
                                     ej.on_task_failure.value,
                                     job.on_task_failure.value))
                            success = False
            else:
                logger.exception(str(exc))
        if del_job:
            await bsh.delete_or_terminate_job(
                pool.batch_account, pool.service_url, job.id, True, False,
                wait=True)
        return success

    def track_job(
            self,
            fdh: FederationDataHandler,
            target_pool: str,
            job_id: str,
            is_job_schedule: bool,
            unique_id: Optional[str],
    ) -> None:
        # get pool ref
        pool = self.pools[target_pool]
        # add to jobs table; loop on etag conflicts
        while True:
            entity = fdh.get_location_entity_for_job(self.hash, job_id, pool)
            if entity is None:
                pk, rk = fdh.generate_pk_rk_for_job_location_entity(
                    self.hash, job_id, pool)
                entity = {
                    'PartitionKey': pk,
                    'RowKey': rk,
                    'Kind': 'job_schedule' if is_job_schedule else 'job',
                    'Id': job_id,
                    'PoolId': pool.pool_id,
                    'BatchAccount': pool.batch_account,
                    'ServiceUrl': pool.service_url,
                }
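                # Illustrative entity (hypothetical values): a federated job
                # 'myjob' placed on pool 'mypool' is recorded roughly as
                # {'Kind': 'job', 'Id': 'myjob', 'PoolId': 'mypool',
                #  'BatchAccount': 'myba', 'ServiceUrl': '...'} so that
                # later terminate/delete actions can locate the backing
                # Batch account for the job.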
                if is_not_empty(unique_id):
                    entity['UniqueIds'] = unique_id
                    entity['AdditionTimestamps'] = datetime_utcnow(
                        as_string=True)
            else:
                if is_not_empty(unique_id):
                    try:
                        entity['AdditionTimestamps'] = '{},{}'.format(
                            entity['AdditionTimestamps'], datetime_utcnow(
                                as_string=True))
                    except KeyError:
                        entity['AdditionTimestamps'] = datetime_utcnow(
                            as_string=True)
                    if (len(entity['AdditionTimestamps']) >
                            fdh._MAX_STR_ENTITY_PROPERTY_LENGTH):
                        tmp = entity['AdditionTimestamps'].split(',')
                        entity['AdditionTimestamps'] = ','.join(tmp[-32:])
                        del tmp
                    try:
                        entity['UniqueIds'] = '{},{}'.format(
                            entity['UniqueIds'], unique_id)
                    except KeyError:
                        entity['UniqueIds'] = unique_id
                    if (len(entity['UniqueIds']) >
                            fdh._MAX_STR_ENTITY_PROPERTY_LENGTH):
                        tmp = entity['UniqueIds'].split(',')
                        entity['UniqueIds'] = ','.join(tmp[-32:])
                        del tmp
            if fdh.insert_or_update_entity_with_etag_for_job(entity):
                logger.debug(
                    'upserted location entity for job {} on pool {} uid={} '
                    '(batch_account={} service_url={})'.format(
                        job_id, pool.pool_id, unique_id, pool.batch_account,
                        pool.service_url))
                break
            else:
                logger.debug(
                    'conflict upserting location entity for job {} on '
                    'pool {} uid={} (batch_account={} service_url={})'.format(
                        job_id, pool.pool_id, unique_id, pool.batch_account,
                        pool.service_url))

    def fixup_task_for_mismatch(
            self,
            node_agent: str,
            ib_mismatch: bool,
            task: batchmodels.TaskAddParameter,
            constraints: Constraints,
    ) -> batchmodels.TaskAddParameter:
        # fix up env vars for gpu and/or non-native
        if ((constraints.compute_node.gpu or not constraints.pool.native) and
                task.environment_settings is not None):
            replace_ev = []
            for ev in task.environment_settings:
                if ev.name == 'CUDA_CACHE_PATH':
                    replace_ev.append(batchmodels.EnvironmentSetting(
                        ev.name,
                        '{}/batch/tasks/.nv/ComputeCache'.format(
                            get_temp_disk_for_node_agent(node_agent))
                    ))
                elif ev.name == 'SINGULARITY_CACHEDIR':
                    replace_ev.append(batchmodels.EnvironmentSetting(
                        ev.name,
                        '{}/singularity/cache'.format(
                            get_temp_disk_for_node_agent(node_agent))
                    ))
                else:
                    replace_ev.append(ev)
            task.environment_settings = replace_ev
        # fix up ib rdma mapping in command line
        if ib_mismatch:
            if node_agent.startswith('batch.node.sles'):
                final = (
                    '/etc/dat.conf:/etc/rdma/dat.conf:ro '
                    '--device=/dev/hvnd_rdma'
                )
            else:
                final = '/etc/dat.conf:/etc/rdma/dat.conf:ro'
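            # Illustrative rewrite (hypothetical command line): a bind-mount
            # argument '-v /etc/rdma/dat.conf:/etc/dat.conf:ro' that was
            # generated for a CentOS-style layout is re-mapped below using
            # 'final' so the IB DAT configuration binds correctly for this
            # pool's node agent.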
            if constraints.task.has_multi_instance:
                # fixup coordination command line
                cc = task.multi_instance_settings.coordination_command_line
                cc = cc.replace(
                    '/etc/rdma:/etc/rdma:ro',
                    '/etc/dat.conf:/etc/dat.conf:ro').replace(
                        '/etc/rdma/dat.conf:/etc/dat.conf:ro', final)
                task.multi_instance_settings.coordination_command_line = cc
            # fixup command line
            task.command_line = task.command_line.replace(
                '/etc/rdma:/etc/rdma:ro',
                '/etc/dat.conf:/etc/dat.conf:ro').replace(
                    '/etc/rdma/dat.conf:/etc/dat.conf:ro', final)
        return task

    def schedule_tasks(
            self,
            bsh: BatchServiceHandler,
            fdh: FederationDataHandler,
            target_pool: str,
            job_id: str,
            constraints: Constraints,
            naming: TaskNaming,
            task_map: Dict[str, batchmodels.TaskAddParameter],
    ) -> None:
        """
        This function should be called with lock already held!
        """
        # get pool ref
        pool = self.pools[target_pool]
        na = pool.cloud_pool.virtual_machine_configuration.\
            node_agent_sku_id.lower()
        # check if there is an ib mismatch
        ib_mismatch = (
            is_rdma_pool(pool.vm_size) and
            not na.startswith('batch.node.centos')
        )
        task_ids = sorted(task_map.keys())
        # fixup tasks directly if task dependencies are present
        if constraints.task.has_task_dependencies:
            for tid in task_ids:
                task_map[tid] = self.fixup_task_for_mismatch(
                    na, ib_mismatch, task_map[tid], constraints)
        else:
            # re-assign task ids to current job if no task dependencies
            # 1. sort task map keys
            # 2. re-map task ids to current job
            # 3. re-gather merge task dependencies (shouldn't happen)
            last_tid = None
            tasklist = None
            merge_task_id = None
            for tid in task_ids:
                is_merge_task = tid == constraints.task.merge_task_id
                tasklist, new_tid = bsh.regenerate_next_generic_task_id(
                    pool.batch_account, pool.service_url, job_id, naming, tid,
                    last_task_id=last_tid, tasklist=tasklist,
                    is_merge_task=is_merge_task)
                task = task_map.pop(tid)
                task = self.fixup_task_for_mismatch(
                    na, ib_mismatch, task, constraints)
                task.id = new_tid
                task_map[new_tid] = task
                if is_merge_task:
                    merge_task_id = new_tid
                tasklist.append(new_tid)
                last_tid = new_tid
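            # Illustrative re-mapping (hypothetical ids): tasks submitted as
            # {'0', '1', 'merge'} can be re-assigned to 'task-00003' through
            # 'task-00005' when the job already holds earlier ids; the merge
            # task is then re-pointed at the new ids via TaskDependencies.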
            if merge_task_id is not None:
                merge_task = task_map.pop(merge_task_id)
                merge_task = self.fixup_task_for_mismatch(
                    na, ib_mismatch, merge_task, constraints)
                merge_task.depends_on = batchmodels.TaskDependencies(
                    task_ids=list(task_map.keys()),
                )
                task_map[merge_task_id] = merge_task
        # submit task collection
        bsh.add_task_collection(
            pool.batch_account, pool.service_url, job_id, task_map)
        # set auto complete
        if constraints.task.auto_complete:
            bsh.set_auto_complete_on_job(
                pool.batch_account, pool.service_url, job_id)
        # post scheduling actions
        pool.on_new_tasks_scheduled(
            bsh, fdh.scheduling_blackout, fdh.scheduling_evaluate_autoscale)


class FederationProcessor():
    def __init__(self, config: Dict[str, Any]) -> None:
        """Ctor for FederationProcessor
        :param config: configuration
        """
        self._service_proxy = ServiceProxy(config)
        try:
            self.fed_refresh_interval = int(config['refresh_intervals'].get(
                'federations', 30))
        except KeyError:
            self.fed_refresh_interval = 30
        try:
            self.action_refresh_interval = int(
                config['refresh_intervals'].get('actions', 5))
        except KeyError:
            self.action_refresh_interval = 5
        self.csh = ComputeServiceHandler(self._service_proxy)
        self.bsh = BatchServiceHandler(self._service_proxy)
        self.fdh = FederationDataHandler(self._service_proxy)
        # data structs
        self._federation_lock = threading.Lock()
        self.federations = {}  # type: Dict[str, Federation]
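    # Illustrative configuration snippet for the intervals read above (the
    # full config schema is defined elsewhere by Batch Shipyard):
    # {
    #     "refresh_intervals": {
    #         "federations": 30,
    #         "actions": 5
    #     }
    # }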
    @property
    def federations_available(self) -> bool:
        with self._federation_lock:
            return len(self.federations) > 0

    def _update_federation(self, entity) -> None:
        fedhash = entity['RowKey']
        fedid = entity['FederationId']
        if fedhash not in self.federations:
            logger.debug('adding federation hash {} id: {}'.format(
                fedhash, fedid))
            self.federations[fedhash] = Federation(fedhash, fedid)
        pools = list(self.fdh.get_all_pools_for_federation(fedhash))
        if len(pools) == 0:
            return
        poolset = set()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers_for_executor(pools)) as executor:
            for pool in pools:
                executor.submit(
                    self.federations[fedhash].update_pool,
                    self.csh, self.bsh, pool, poolset)
        self.federations[fedhash].trim_orphaned_pools(poolset)

    def update_federations(self) -> None:
        """Update federations"""
        entities = list(self.fdh.get_all_federations())
        if len(entities) == 0:
            return
        with self._federation_lock:
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers_for_executor(
                        entities)) as executor:
                for entity in entities:
                    executor.submit(self._update_federation, entity)

    async def add_job_v1(
            self,
            fedhash: str,
            job: batchmodels.JobAddParameter,
            constraints: Constraints,
            naming: TaskNaming,
            task_map: Dict[str, batchmodels.TaskAddParameter],
            unique_id: str
    ) -> bool:
        # 1. get the number of tasks in the job
        # 2. try to match the appropriate pool for the tasks in the job
        # 3. add the job to the pool; if the job exists, ensure settings match
        # 4. add the tasks to the job
        # 5. record the mapping in the fedjobs table
        num_tasks = len(task_map)
        logger.debug(
            'attempting to match job {} with {} tasks in fed {} '
            'uid={}'.format(job.id, num_tasks, fedhash, unique_id))
        blacklist = set()
        while True:
            poolrk = self.federations[fedhash].find_target_pool_for_job(
                self.bsh, self.fdh, num_tasks, constraints, blacklist,
                unique_id, job.id)
            if poolrk is not None:
                cj = await self.federations[fedhash].create_job(
                    self.bsh, poolrk, job, constraints)
                if cj:
                    # remove blocked action if any
                    self.fdh.remove_blocked_action_for_job(fedhash, job.id)
                    # track job prior to adding tasks in case task
                    # addition fails
                    self.federations[fedhash].track_job(
                        self.fdh, poolrk, job.id, False, None)
                    # schedule tasks
                    self.federations[fedhash].schedule_tasks(
                        self.bsh, self.fdh, poolrk, job.id, constraints,
                        naming, task_map)
                    # update job tracking
                    self.federations[fedhash].track_job(
                        self.fdh, poolrk, job.id, False, unique_id)
                    break
                else:
                    logger.debug(
                        'blacklisting pool hash={} in fed hash {} '
                        'uid={} for job {}'.format(
                            poolrk, fedhash, unique_id, job.id))
                    blacklist.add(poolrk)
            else:
                return False
        return True

    async def add_job_schedule_v1(
            self,
            fedhash: str,
            job_schedule: batchmodels.JobScheduleAddParameter,
            constraints: Constraints,
            unique_id: str
    ) -> bool:
        # ensure there is no existing job schedule. although this is checked
        # at submission time, a similarly named job schedule can be enqueued
        # multiple times before the action is dequeued
        if self.fdh.location_entities_exist_for_job(fedhash, job_schedule.id):
            logger.error(
                'job schedule {} already exists for fed {} uid={}'.format(
                    job_schedule.id, fedhash, unique_id))
            return True
        num_tasks = constraints.task.tasks_per_recurrence
        logger.debug(
            'attempting to match job schedule {} with {} tasks in fed {} '
            'uid={}'.format(job_schedule.id, num_tasks, fedhash, unique_id))
        blacklist = set()
        while True:
            poolrk = self.federations[fedhash].find_target_pool_for_job(
                self.bsh, self.fdh, num_tasks, constraints, blacklist,
                unique_id, job_schedule.id)
            if poolrk is not None:
                cj = await self.federations[fedhash].create_job_schedule(
                    self.bsh, poolrk, job_schedule, constraints)
                if cj:
                    # remove blocked action if any
                    self.fdh.remove_blocked_action_for_job(
                        fedhash, job_schedule.id)
                    # track job schedule
                    self.federations[fedhash].track_job(
                        self.fdh, poolrk, job_schedule.id, True, unique_id)
                    break
                else:
                    logger.debug(
                        'blacklisting pool hash={} in fed hash {} '
                        'uid={} for job schedule {}'.format(
                            poolrk, fedhash, unique_id, job_schedule.id))
                    blacklist.add(poolrk)
            else:
                return False
        return True

    async def _terminate_job(
            self,
            fedhash: str,
            job_id: str,
            is_job_schedule: bool,
            entity: azure.cosmosdb.table.models.Entity,
    ) -> None:
        if 'TerminateTimestamp' in entity:
            logger.debug(
                '{} {} for fed {} has already been terminated '
                'at {}'.format(
                    'job schedule' if is_job_schedule else 'job',
                    job_id, fedhash, entity['TerminateTimestamp']))
            return
        await self.bsh.delete_or_terminate_job(
            entity['BatchAccount'], entity['ServiceUrl'], job_id, False,
            is_job_schedule, wait=False)
        logger.info(
            'terminated {} {} on pool {} for fed {} (batch_account={} '
            'service_url={})'.format(
                'job schedule' if is_job_schedule else 'job',
                job_id, entity['PoolId'], fedhash, entity['BatchAccount'],
                entity['ServiceUrl']))
        while True:
            entity['TerminateTimestamp'] = datetime_utcnow(as_string=False)
            if self.fdh.insert_or_update_entity_with_etag_for_job(entity):
                break
            else:
                # etag conflict: force update on the next attempt
                entity['etag'] = '*'

    async def _delete_job(
            self,
            fedhash: str,
            job_id: str,
            is_job_schedule: bool,
            entity: azure.cosmosdb.table.models.Entity,
    ) -> None:
        await self.bsh.delete_or_terminate_job(
            entity['BatchAccount'], entity['ServiceUrl'], job_id, True,
            is_job_schedule, wait=False)
        logger.info(
            'deleted {} {} on pool {} for fed {} (batch_account={} '
            'service_url={})'.format(
                'job schedule' if is_job_schedule else 'job',
                job_id, entity['PoolId'], fedhash, entity['BatchAccount'],
                entity['ServiceUrl']))
        self.fdh.delete_location_entity_for_job(entity)

    async def delete_or_terminate_job_v1(
            self,
            delete: bool,
            fedhash: str,
            job_id: str,
            is_job_schedule: bool,
            unique_id: str
    ) -> None:
        # find all jobs across the federation matching the id
        entities = self.fdh.get_all_location_entities_for_job(fedhash, job_id)
        # terminate each pool-level job representing the federation job
        tasks = []
        coro = self._delete_job if delete else self._terminate_job
        for entity in entities:
            tasks.append(
                asyncio.ensure_future(
                    coro(fedhash, job_id, is_job_schedule, entity)))
        if len(tasks) > 0:
            await asyncio.wait(tasks)
        else:
            logger.error(
                'cannot {} {} {} for fed {}, no location entities '
                'exist (uid={})'.format(
                    'delete' if delete else 'terminate',
                    'job schedule' if is_job_schedule else 'job',
                    job_id, fedhash, unique_id))

    async def process_message_action_v1(
            self,
            fedhash: str,
            data: Dict[str, Any],
            unique_id: str
    ) -> bool:
        result = True
        # check proper version
        if is_not_empty(data) and data['version'] != '1':
            logger.error('cannot process job data version {} for {}'.format(
                data['version'], unique_id))
            return result
        # extract data from message
        action = data['action']['method']
        target_type = data['action']['kind']
        target = data[target_type]['id']
        logger.debug(
            'uid {} for fed {} message action={} target_type={} '
            'target={}'.format(
                unique_id, fedhash, action, target_type, target))
        # take action depending upon kind and method
        if target_type == 'job_schedule':
            if action == 'add':
                job_schedule = data[target_type]['data']
                logger.debug(
                    'uid {} for fed {} target_type={} target={} '
                    'constraints={}'.format(
                        unique_id, fedhash, target_type, target,
                        data[target_type]['constraints']))
                constraints = Constraints(data[target_type]['constraints'])
                result = await self.add_job_schedule_v1(
                    fedhash, job_schedule, constraints, unique_id)
            elif action == 'terminate':
                await self.delete_or_terminate_job_v1(
                    False, fedhash, target, True, unique_id)
            elif action == 'delete':
                await self.delete_or_terminate_job_v1(
                    True, fedhash, target, True, unique_id)
            else:
                raise NotImplementedError()
        elif target_type == 'job':
            if action == 'add':
                job = data[target_type]['data']
                logger.debug(
                    'uid {} for fed {} target_type={} target={} '
                    'constraints={}'.format(
                        unique_id, fedhash, target_type, target,
                        data[target_type]['constraints']))
                constraints = Constraints(data[target_type]['constraints'])
                logger.debug(
                    'uid {} for fed {} target_type={} target={} '
                    'naming={}'.format(
                        unique_id, fedhash, target_type, target,
                        data[target_type]['task_naming']))
                naming = TaskNaming(data[target_type]['task_naming'])
                task_map = data['task_map']
                result = await self.add_job_v1(
                    fedhash, job, constraints, naming, task_map, unique_id)
            elif action == 'terminate':
                await self.delete_or_terminate_job_v1(
                    False, fedhash, target, False, unique_id)
            elif action == 'delete':
                await self.delete_or_terminate_job_v1(
                    True, fedhash, target, False, unique_id)
            else:
                raise NotImplementedError()
        else:
            logger.error('unknown target type: {}'.format(target_type))
        return result

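    # Illustrative v1 action payload consumed above (hypothetical values;
    # the actual schema is produced by the federation submission tooling):
    # {
    #     'version': '1',
    #     'action': {'method': 'add', 'kind': 'job'},
    #     'job': {'id': 'myjob', 'data': <JobAddParameter>,
    #             'constraints': {...}, 'task_naming': {...}},
    #     'task_map': {'mytask-0': <TaskAddParameter>}
    # }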
    async def process_queue_message_v1(
            self,
            fedhash: str,
            msg: Dict[str, Any]
    ) -> Tuple[bool, Optional[str]]:
        result = True
        target_fedid = msg['federation_id']
        calc_fedhash = hash_federation_id(target_fedid)
        if calc_fedhash != fedhash:
            logger.error(
                'federation hash mismatch, expected={} actual={} '
                'id={}'.format(fedhash, calc_fedhash, target_fedid))
            return result, None
        target = msg['target']
        unique_id = msg['uuid']
        # get sequence from table
        seq_id = self.fdh.get_first_sequence_id_for_job(fedhash, target)
        if seq_id is None:
            logger.error(
                'sequence length is missing or non-positive for uid={} for '
                'target {} on federation {}'.format(
                    unique_id, target, fedhash))
            # remove blocked action if any
            self.fdh.remove_blocked_action_for_job(fedhash, target)
            return result, None
        # if there is a sequence mismatch, then the queue is no longer FIFO;
        # get the appropriate next sequence id and construct the blob url
        # for the message data
        if seq_id != unique_id:
            logger.warning(
                'queue message for fed {} does not match first '
                'sequence q:{} != t:{} for target {}'.format(
                    fedhash, unique_id, seq_id, target))
            unique_id = seq_id
            blob_url = self.fdh.construct_blob_url(fedhash, unique_id)
        else:
            blob_url = msg['blob_data']
        del seq_id
        # retrieve message data from blob
        job_data = None
        try:
            blob_client, container, blob_name, data = \
                self.fdh.retrieve_blob_data(blob_url)
        except Exception as exc:
            logger.exception(str(exc))
            logger.error(
                'cannot process queue message for sequence id {} for '
                'fed {}'.format(unique_id, fedhash))
            # remove blocked action if any
            self.fdh.remove_blocked_action_for_job(fedhash, target)
            return False, target
        else:
            job_data = pickle.loads(data, fix_imports=True)
            del data
        del blob_url
        # process message
        if job_data is not None:
            result = await self.process_message_action_v1(
                fedhash, job_data, unique_id)
            # cleanup
            if result:
                self.fdh.delete_blob(blob_client, container, blob_name)
        else:
            target = None
        return result, target

    async def process_federation_queue(self, fedhash: str) -> None:
        acquired = self.federations[fedhash].lock.acquire(blocking=False)
        if not acquired:
            logger.debug('could not acquire lock on federation {}'.format(
                fedhash))
            return
        try:
            msgs = self.fdh.get_messages_from_federation_queue(fedhash)
            for msg in msgs:
                if not await self.check_global_lock(backoff=False):
                    logger.error(
                        'global lock lease lost while processing queue for '
                        'fed {}'.format(fedhash))
                    return
                msg_data = json.loads(msg.content, encoding='utf8')
                if msg_data['version'] == '1':
                    del_msg, target = await self.process_queue_message_v1(
                        fedhash, msg_data)
                else:
                    logger.error(
                        'cannot process message version {} for fed {}'.format(
                            msg_data['version'], fedhash))
                    del_msg = True
                    target = None
                # delete message
                self.fdh.dequeue_sequence_id_from_federation_sequence(
                    del_msg, fedhash, msg.id, msg.pop_receipt, target)
        finally:
            self.federations[fedhash].lock.release()

    async def check_global_lock(
            self,
            backoff: bool = True
    ) -> bool:
        if not self.fdh.has_global_lock:
            if backoff:
                await asyncio.sleep(5 + random.randint(0, 5))
            return False
        return True

    async def iterate_and_process_federation_queues(
            self
    ) -> Generator[None, None, None]:
        while True:
            if not await self.check_global_lock():
                continue
            if self.federations_available:
                # TODO process in parallel
                for fedhash in self.federations:
                    try:
                        await self.process_federation_queue(fedhash)
                    except Exception as exc:
                        logger.exception(str(exc))
                    if not await self.check_global_lock(backoff=False):
                        break
            await asyncio.sleep(self.action_refresh_interval)

    async def poll_for_federations(
            self,
            loop: asyncio.BaseEventLoop,
    ) -> Generator[None, None, None]:
        """Poll federations
        :param loop: asyncio loop
        """
        # lease global lock blob
        self.fdh.lease_global_lock(loop)
        # block until global lock acquired
        while not await self.check_global_lock():
            pass
        # mount log storage
        log_path = self.fdh.mount_file_storage()
        # set logging configuration
        self.fdh.set_log_configuration(log_path)
        self._service_proxy.log_configuration()
        logger.debug('polling federation table {} every {} sec'.format(
            self._service_proxy.table_name_global, self.fed_refresh_interval))
        logger.debug('polling action queues {} every {} sec'.format(
            self._service_proxy.table_name_jobs, self.action_refresh_interval))
        # begin message processing
        asyncio.ensure_future(
            self.iterate_and_process_federation_queues(), loop=loop)
        # continuously update federations
        while True:
            if not await self.check_global_lock():
                continue
            try:
                self.update_federations()
            except Exception as exc:
                logger.exception(str(exc))
            await asyncio.sleep(self.fed_refresh_interval)


def main() -> None:
    """Main function"""
    # get command-line args
    args = parseargs()
    # load configuration
    if is_none_or_empty(args.conf):
        raise ValueError('config file not specified')
    with open(args.conf, 'rb') as f:
        config = json.load(f)
    logger.debug('loaded config from {}: {}'.format(args.conf, config))
    del args
    # create federation processor
    fed_processor = FederationProcessor(config)
    # run the poller
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(
            fed_processor.poll_for_federations(loop)
        )
    except Exception as exc:
        logger.exception(str(exc))
    finally:
        # close and detach log handlers prior to unmounting file storage
        handlers = logger.handlers[:]
        for handler in handlers:
            handler.close()
            logger.removeHandler(handler)
        try:
            fed_processor.fdh.unmount_file_storage()
        except Exception as exc:
            logger.exception(str(exc))


def parseargs() -> argparse.Namespace:
    """Parse program arguments
    :return: parsed arguments
    """
    parser = argparse.ArgumentParser(
        description='federation: Azure Batch Shipyard Federation Controller')
    parser.add_argument('--conf', help='configuration file')
    return parser.parse_args()

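# Illustrative invocation (the module filename here is assumed):
#   python3 federation.py --conf /path/to/federation.json
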
if __name__ == '__main__':
    _setup_logger(logger)
    az_logger = logging.getLogger('azure.storage')
    _setup_logger(az_logger)
    az_logger.setLevel(logging.WARNING)
    az_logger = logging.getLogger('azure.cosmosdb')
    _setup_logger(az_logger)
    az_logger.setLevel(logging.WARNING)
    main()