#!/usr/bin/env python3 # Copyright (c) Microsoft Corporation # # All rights reserved. # # MIT License # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Software"), # to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, # and/or sell copies of the Software, and to permit persons to whom the # Software is furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. # stdlib imports import argparse import collections import concurrent.futures import datetime import enum import hashlib import json import logging import logging.handlers import multiprocessing import pathlib import random import re import subprocess import sys import threading import time from typing import ( Any, Dict, List, Optional, Tuple, ) # non-stdlib imports import azure.batch import azure.batch.models as batchmodels import azure.common import azure.cosmosdb.table import azure.mgmt.resource import azure.mgmt.storage import azure.storage.queue import dateutil.tz import msrestazure.azure_active_directory import msrestazure.azure_cloud # create logger logger = logging.getLogger(__name__) # global defines # TODO allow these maximums to be configurable _MAX_EXECUTOR_WORKERS = min((multiprocessing.cpu_count() * 4, 32)) _MAX_AUTH_FAILURE_RETRIES = 10 _MAX_RESUME_FAILURE_ATTEMPTS = 10 class Actions(enum.IntEnum): Suspend = 0, Resume = 1, ResumeFailed = 2, WaitForResume = 3, class HostStates(enum.IntEnum): Up = 0, Resuming = 1, ProvisionInterrupt = 2, Provisioned = 3, Suspended = 4, def setup_logger(log) -> None: """Set up logger""" log.setLevel(logging.DEBUG) handler = logging.StreamHandler() formatter = logging.Formatter( '%(asctime)s %(process)d %(levelname)s ' '%(name)s:%(funcName)s:%(lineno)d %(message)s' ) formatter.default_msec_format = '%s.%03d' handler.setFormatter(formatter) log.addHandler(handler) def max_workers_for_executor(iterable: Any) -> int: """Get max number of workers for executor given an iterable :param iterable: an iterable :return: number of workers for executor """ return min((len(iterable), _MAX_EXECUTOR_WORKERS)) def is_none_or_empty(obj: Any) -> bool: """Determine if object is None or empty :param obj: object :return: if object is None or empty """ return obj is None or len(obj) == 0 def is_not_empty(obj: Any) -> bool: """Determine if object is not None and is length is > 0 :param obj: object :return: if object is not None and length is > 0 """ return obj is not None and len(obj) > 0 def datetime_utcnow(as_string: bool = False) -> datetime.datetime: """Returns a datetime now with UTC timezone :param as_string: return as ISO8601 extended string :return: datetime object representing now with UTC timezone """ dt = datetime.datetime.now(dateutil.tz.tzutc()) if as_string: return dt.strftime('%Y%m%dT%H%M%S.%f')[:-3] + 'Z' else: return dt def hash_string(strdata: str) -> str: """Hash a string :param strdata: string data to hash :return: hexdigest """ return hashlib.sha1(strdata.encode('utf8')).hexdigest() def random_blocking_sleep(min: int, max: int) -> None: time.sleep(random.randint(min, max)) class Credentials(): def __init__(self, config: Dict[str, Any]) -> None: """Ctor for Credentials :param config: configuration """ # set attr from config self.storage_account = config['storage']['account'] try: self.storage_account_key = config['storage']['account_key'] self.storage_account_ep = config['storage']['endpoint'] self.storage_account_rg = None self.cloud = None self.arm_creds = None self.batch_creds = None self.sub_id = None logger.debug('storage account {} ep: {}'.format( self.storage_account, self.storage_account_ep)) except KeyError: self.storage_account_rg = config['storage']['resource_group'] # get cloud object self.cloud = Credentials.convert_cloud_type(config['aad_cloud']) # get aad creds self.arm_creds = self.create_msi_credentials() self.batch_creds = self.create_msi_credentials( resource_id=self.cloud.endpoints.batch_resource_id) # get subscription id self.sub_id = self.get_subscription_id() logger.debug('created msi auth for sub id: {}'.format(self.sub_id)) # get storage account key and endpoint self.storage_account_key, self.storage_account_ep = \ self.get_storage_account_key() logger.debug('storage account {} -> rg: {} ep: {}'.format( self.storage_account, self.storage_account_rg, self.storage_account_ep)) @staticmethod def convert_cloud_type(cloud_type: str) -> msrestazure.azure_cloud.Cloud: """Convert clout type string to object :param cloud_type: cloud type to convert :return: cloud object """ if cloud_type == 'public': cloud = msrestazure.azure_cloud.AZURE_PUBLIC_CLOUD elif cloud_type == 'china': cloud = msrestazure.azure_cloud.AZURE_CHINA_CLOUD elif cloud_type == 'germany': cloud = msrestazure.azure_cloud.AZURE_GERMAN_CLOUD elif cloud_type == 'usgov': cloud = msrestazure.azure_cloud.AZURE_US_GOV_CLOUD else: raise ValueError('unknown cloud_type: {}'.format(cloud_type)) return cloud def get_subscription_id(self) -> str: """Get subscription id for ARM creds :param arm_creds: ARM creds :return: subscription id """ client = azure.mgmt.resource.SubscriptionClient(self.arm_creds) return next(client.subscriptions.list()).subscription_id def create_msi_credentials( self, resource_id: str = None ) -> msrestazure.azure_active_directory.MSIAuthentication: """Create MSI credentials :param resource_id: resource id to auth against :return: MSI auth object """ if is_not_empty(resource_id): creds = msrestazure.azure_active_directory.MSIAuthentication( cloud_environment=self.cloud, resource=resource_id, ) else: creds = msrestazure.azure_active_directory.MSIAuthentication( cloud_environment=self.cloud, ) return creds def get_storage_account_key(self) -> Tuple[str, str]: """Retrieve the storage account key and endpoint :return: tuple of key, endpoint """ client = azure.mgmt.storage.StorageManagementClient( self.arm_creds, self.sub_id, base_url=self.cloud.endpoints.resource_manager) ep = None if is_not_empty(self.storage_account_rg): acct = client.storage_accounts.get_properties( self.storage_account_rg, self.storage_account) ep = '.'.join( acct.primary_endpoints.blob.rstrip('/').split('.')[2:] ) else: for acct in client.storage_accounts.list(): if acct.name == self.storage_account: self.storage_account_rg = acct.id.split('/')[4] ep = '.'.join( acct.primary_endpoints.blob.rstrip('/').split('.')[2:] ) break if is_none_or_empty(self.storage_account_rg) or is_none_or_empty(ep): raise RuntimeError( 'storage account {} not found in subscription id {}'.format( self.storage_account, self.sub_id)) keys = client.storage_accounts.list_keys( self.storage_account_rg, self.storage_account) return (keys.keys[0].value, ep) class ServiceProxy(): def __init__(self, config: Dict[str, Any]) -> None: """Ctor for ServiceProxy :param config: configuration """ self._config = config self._resume_timeout = None self._suspend_timeout = None prefix = config['storage']['entity_prefix'] self.cluster_id = config['cluster_id'] self.logging_id = config['logging_id'] self.table_name = '{}slurm'.format(prefix) try: self.queue_assign = config['storage']['queues']['assign'] except KeyError: self.queue_assign = None try: self.queue_action = config['storage']['queues']['action'] except KeyError: self.queue_action = None try: self.node_id = config['batch']['node_id'] self.pool_id = config['batch']['pool_id'] self.ip_address = config['ip_address'] except KeyError: self.node_id = None self.pool_id = None self.ip_address = None self.file_share_hmp = pathlib.Path( config['storage']['azfile_mount_dir']) / config['cluster_id'] self._batch_client_lock = threading.Lock() self.batch_clients = {} # create credentials self.creds = Credentials(config) # create clients self.table_client = self._create_table_client() self.queue_client = self._create_queue_client() logger.debug('created storage clients for storage account {}'.format( self.creds.storage_account)) @property def batch_shipyard_version(self) -> str: return self._config['batch_shipyard']['version'] @property def batch_shipyard_var_path(self) -> pathlib.Path: return pathlib.Path(self._config['batch_shipyard']['var_path']) @property def storage_entity_prefix(self) -> str: return self._config['storage']['entity_prefix'] @property def resume_timeout(self) -> int: if self._resume_timeout is None: # subtract off 5 seconds for fudge val = self._config['timeouts']['resume'] - 5 if val < 5: val = 5 self._resume_timeout = val return self._resume_timeout @property def suspend_timeout(self) -> int: if self._suspend_timeout is None: # subtract off 5 seconds for fudge val = self._config['timeouts']['suspend'] - 5 if val < 5: val = 5 self._suspend_timeout = val return self._suspend_timeout def log_configuration(self) -> None: logger.debug('configuration: {}'.format( json.dumps(self._config, sort_keys=True, indent=4))) def _modify_client_for_retry_and_user_agent(self, client: Any) -> None: """Extend retry policy of clients and add user agent string :param client: a client object """ if client is None: return client.config.retry_policy.max_backoff = 8 client.config.retry_policy.retries = 100 client.config.add_user_agent('batch-shipyard/{}'.format( self.batch_shipyard_version)) def _create_table_client(self) -> azure.cosmosdb.table.TableService: """Create a table client for the given storage account :return: table client """ client = azure.cosmosdb.table.TableService( account_name=self.creds.storage_account, account_key=self.creds.storage_account_key, endpoint_suffix=self.creds.storage_account_ep, ) return client def _create_queue_client(self) -> azure.storage.queue.QueueService: """Create a queue client for the given storage account :return: queue client """ client = azure.storage.queue.QueueService( account_name=self.creds.storage_account, account_key=self.creds.storage_account_key, endpoint_suffix=self.creds.storage_account_ep, ) return client def batch_client( self, service_url: str ) -> azure.batch.BatchServiceClient: """Get/create batch client :param service_url: service url :return: batch client """ with self._batch_client_lock: try: return self.batch_clients[service_url] except KeyError: client = azure.batch.BatchServiceClient( self.creds.batch_creds, batch_url=service_url) self._modify_client_for_retry_and_user_agent(client) self.batch_clients[service_url] = client logger.debug('batch client created for account: {}'.format( service_url)) return client def reset_batch_creds_and_client( self, service_url: str ) -> None: logger.warning('resetting batch creds and client for {}'.format( service_url)) with self._batch_client_lock: self.creds.batch_creds = self.create_msi_credentials( resource_id=self.cloud.endpoints.batch_resource_id) self.batch_clients.pop(service_url, None) class StorageServiceHandler(): _PARTITIONS_PREFIX = 'PARTITIONS' _HOSTS_PREFIX = 'HOSTS' def __init__(self, service_proxy: ServiceProxy) -> None: """Ctor for Storage handler :param service_proxy: ServiceProxy """ self.service_proxy = service_proxy def get_storage_account_key(self) -> Tuple[str, str]: return (self.service_proxy.creds.storage_account_key, self.service_proxy.creds.storage_account_ep) def list_partitions(self) -> List[azure.cosmosdb.table.Entity]: return self.service_proxy.table_client.query_entities( self.service_proxy.table_name, filter='PartitionKey eq \'{}${}\''.format( self._PARTITIONS_PREFIX, self.service_proxy.cluster_id)) def get_host_assignment_entity( self, host: str, ) -> azure.cosmosdb.table.Entity: return self.service_proxy.table_client.get_entity( self.service_proxy.table_name, '{}${}'.format(self._HOSTS_PREFIX, self.service_proxy.cluster_id), host) def delete_node_assignment_entity( self, entity: azure.cosmosdb.table.Entity, ) -> None: try: self.service_proxy.table_client.delete_entity( self.service_proxy.table_name, entity['PartitionKey'], entity['RowKey']) except azure.common.AzureMissingResourceHttpError: pass def insert_queue_assignment_msg(self, rowkey: str, host: str) -> None: rkparts = rowkey.split('$') qname = '{}-{}'.format(self.service_proxy.cluster_id, rkparts[1]) logger.debug('inserting host {} assignment token to queue {}'.format( host, qname)) msg = { 'cluster_id': self.service_proxy.cluster_id, 'host': host, } msg_data = json.dumps(msg, ensure_ascii=True, sort_keys=True) self.service_proxy.queue_client.put_message( qname, msg_data, time_to_live=-1) def get_queue_assignment_msg(self) -> None: logger.debug('getting queue assignment from {}'.format( self.service_proxy.queue_assign)) host = None while host is None: msgs = self.service_proxy.queue_client.get_messages( self.service_proxy.queue_assign, num_messages=1, visibility_timeout=150) for msg in msgs: msg_data = json.loads(msg.content, encoding='utf8') logger.debug( 'got message {}: {}'.format(msg.id, msg_data)) host = msg_data['host'] outfile = pathlib.Path( self.service_proxy.batch_shipyard_var_path) / 'slurm_host' with outfile.open('wt') as f: f.write(host) self.service_proxy.queue_client.delete_message( self.service_proxy.queue_assign, msg.id, msg.pop_receipt) break random_blocking_sleep(1, 3) logger.info('got host assignment: {}'.format(host)) def insert_queue_action_msg( self, action: Actions, hosts: List[str], retry_count: Optional[int] = None, visibility_timeout: Optional[int] = None, ) -> None: msg = { 'cluster_id': self.service_proxy.cluster_id, 'action': action.value, 'hosts': hosts, } if retry_count is not None: msg['retry_count'] = retry_count logger.debug('inserting queue {} message (vt={}): {}'.format( self.service_proxy.queue_action, visibility_timeout, msg)) msg_data = json.dumps(msg, ensure_ascii=True, sort_keys=True) self.service_proxy.queue_client.put_message( self.service_proxy.queue_action, msg_data, visibility_timeout=visibility_timeout, time_to_live=-1) def get_queue_action_msg( self ) -> Optional[Tuple[Dict[str, Any], str, str]]: msgs = self.service_proxy.queue_client.get_messages( self.service_proxy.queue_action, num_messages=1, visibility_timeout=self.service_proxy.resume_timeout) for msg in msgs: msg_data = json.loads(msg.content, encoding='utf8') logger.debug( 'got message {} from queue {}: {}'.format( msg.id, self.service_proxy.queue_action, msg_data)) return (msg_data, msg.id, msg.pop_receipt) return None def update_queue_action_msg( self, id: str, pop_receipt: str, ) -> None: logger.debug( 'updating queue {} message id {} pop receipt {}'.format( self.service_proxy.queue_action, id, pop_receipt)) self.service_proxy.queue_client.update_message( self.service_proxy.queue_action, id, pop_receipt, 20) def delete_queue_action_msg( self, id: str, pop_receipt: str, ) -> None: logger.debug( 'deleting queue {} message id {} pop receipt {}'.format( self.service_proxy.queue_action, id, pop_receipt)) self.service_proxy.queue_client.delete_message( self.service_proxy.queue_action, id, pop_receipt) def insert_host_assignment_entity( self, host: str, partition_name: str, service_url: str, pool_id: str ) -> None: entity = { 'PartitionKey': '{}${}'.format( self._HOSTS_PREFIX, self.service_proxy.cluster_id), 'RowKey': host, 'Partition': partition_name, 'State': HostStates.Resuming.value, 'BatchServiceUrl': service_url, 'BatchPoolId': pool_id, 'BatchShipyardSlurmVersion': 1, } self.service_proxy.table_client.insert_or_replace_entity( self.service_proxy.table_name, entity) def merge_host_assignment_entity_for_compute_node( self, host: str, state: HostStates, state_only: bool, retry_on_conflict: Optional[bool] = None, ) -> None: entity = { 'PartitionKey': '{}${}'.format( self._HOSTS_PREFIX, self.service_proxy.cluster_id), 'RowKey': host, 'State': state.value, } if not state_only: entity['BatchNodeId'] = self.service_proxy.node_id logger.debug( 'merging host {} ip={} assignment entity in table {}: {}'.format( host, self.service_proxy.ip_address, self.service_proxy.table_name, entity)) if retry_on_conflict is None: retry_on_conflict = True while True: try: self.service_proxy.table_client.merge_entity( self.service_proxy.table_name, entity) break except azure.common.AzureConflictHttpError: if retry_on_conflict: random_blocking_sleep(1, 3) else: raise def update_host_assignment_entity_as_provisioned( self, host: str, entity: azure.cosmosdb.table.Entity, ) -> None: entity['State'] = HostStates.Provisioned.value entity['IpAddress'] = self.service_proxy.ip_address entity['BatchNodeId'] = self.service_proxy.node_id logger.debug( 'updating host {} ip={} assignment entity in table {}: {}'.format( host, self.service_proxy.ip_address, self.service_proxy.table_name, entity)) # this must process sucessfully with no etag collision self.service_proxy.table_client.update_entity( self.service_proxy.table_name, entity) def wait_for_host_assignment_entities( self, start_time: datetime.datetime, hosts: List[str], timeout: Optional[int] = None, set_idle_state: Optional[bool] = None, ) -> None: if timeout is None: timeout = self.service_proxy.resume_timeout logger.info('waiting for {} hosts to spin up in {} sec'.format( len(hosts), timeout)) host_queue = collections.deque(hosts) i = 0 while len(host_queue) > 0: host = host_queue.popleft() try: entity = self.get_host_assignment_entity(host) ip = entity['IpAddress'] state = HostStates(entity['State']) if state != HostStates.Provisioned: logger.error('unexpected state for host {}: {}'.format( host, state)) raise KeyError() except (azure.common.AzureMissingResourceHttpError, KeyError): host_queue.append(host) else: logger.debug( 'updating host {} with ip {} node id {} pool id {}'.format( host, ip, entity['BatchNodeId'], entity['BatchPoolId'])) cmd = ['scontrol', 'update', 'NodeName={}'.format(host), 'NodeAddr={}'.format(ip), 'NodeHostname={}'.format(host)] if set_idle_state: cmd.append('State=Idle') logger.debug('command: {}'.format(' '.join(cmd))) subprocess.check_call(cmd) # update entity state m_entity = { 'PartitionKey': entity['PartitionKey'], 'RowKey': entity['RowKey'], 'State': HostStates.Up.value, } self.service_proxy.table_client.merge_entity( self.service_proxy.table_name, m_entity) continue i += 1 if i % 6 == 0: i = 0 logger.debug('still waiting for {} hosts'.format( len(host_queue))) diff = datetime_utcnow() - start_time if diff.total_seconds() > timeout: return host_queue random_blocking_sleep(5, 10) logger.info('{} host spin up completed'.format(len(hosts))) return None class BatchServiceHandler(): def __init__(self, service_proxy: ServiceProxy) -> None: """Ctor for Batch handler :param service_proxy: ServiceProxy """ self.service_proxy = service_proxy def get_node_state_counts( self, service_url: str, pool_id: str, ) -> batchmodels.PoolNodeCounts: client = self.service_proxy.batch_client(service_url) try: node_counts = client.account.list_pool_node_counts( account_list_pool_node_counts_options=batchmodels. AccountListPoolNodeCountsOptions( filter='poolId eq \'{}\''.format(pool_id) ) ) nc = list(node_counts) if len(nc) == 0: logger.error( 'no node counts for pool {} (service_url={})'.format( pool_id, service_url)) return nc[0] except batchmodels.BatchErrorException: logger.error( 'could not retrieve pool {} node counts ' '(service_url={})'.format(pool_id, service_url)) def get_node_info( self, service_url: str, pool_id: str, node_id: str, ) -> batchmodels.ComputeNode: auth_attempts = 0 client = self.service_proxy.batch_client(service_url) # get pool and add delta to current target counts while True: try: return client.compute_node.get(pool_id, node_id) except batchmodels.BatchErrorException as ex: if 'failed to authorize the request' in ex.message.value: if auth_attempts > _MAX_AUTH_FAILURE_RETRIES: raise logger.warning( 'authorization failed for {}, retrying'.format( service_url)) self.service_proxy.reset_batch_creds_and_client( service_url) random_blocking_sleep(1, 3) auth_attempts += 1 client = self.service_proxy.batch_client(service_url) else: return None def add_nodes_to_pool( self, service_url: str, pool_id: str, compute_node_type: str, num_hosts: int, ) -> None: auth_attempts = 0 client = self.service_proxy.batch_client(service_url) # get pool and add delta to current target counts while True: try: pool = client.pool.get(pool_id) if (pool.allocation_state == batchmodels.AllocationState.resizing): logger.debug( 'cannot add nodes to pool {} as it is resizing, ' 'will retry'.format(pool_id)) random_blocking_sleep(5, 10) continue if compute_node_type == 'dedicated': target_dedicated = pool.target_dedicated_nodes + num_hosts target_low_priority = 0 logger.debug( 'adding dedicated nodes to pool {}: {} -> {} ' '(service_url={})'.format( pool_id, pool.target_dedicated_nodes, target_dedicated, service_url)) else: target_dedicated = 0 target_low_priority = ( pool.target_low_priority_nodes + num_hosts ) logger.debug( 'adding low priority nodes to pool {}: {} -> {} ' '(service_url={})'.format( pool_id, pool.target_low_priority_nodes, target_low_priority, service_url)) client.pool.resize( pool_id, pool_resize_parameter=batchmodels.PoolResizeParameter( target_dedicated_nodes=target_dedicated, target_low_priority_nodes=target_low_priority, ), ) logger.info('added nodes to pool {}'.format(pool_id)) break except batchmodels.BatchErrorException as ex: if 'ongoing resize operation' in ex.message.value: logger.debug('pool {} is resizing, will retry'.format( pool_id)) random_blocking_sleep(5, 10) elif 'failed to authorize the request' in ex.message.value: if auth_attempts > _MAX_AUTH_FAILURE_RETRIES: raise logger.warning( 'authorization failed for {}, retrying'.format( service_url)) self.service_proxy.reset_batch_creds_and_client( service_url) random_blocking_sleep(1, 3) auth_attempts += 1 client = self.service_proxy.batch_client(service_url) else: logger.exception( 'could not add nodes to pool {} ' '(service_url={})'.format(pool_id, service_url)) raise def remove_nodes_from_pool( self, service_url: str, pool_id: str, nodes: List[str], ) -> None: auth_attempts = 0 client = self.service_proxy.batch_client(service_url) while True: try: pool = client.pool.get(pool_id) if (pool.allocation_state == batchmodels.AllocationState.resizing): logger.debug( 'cannot remove nodes to pool {} as it is resizing, ' 'will retry'.format(pool_id)) random_blocking_sleep(5, 10) continue client.pool.remove_nodes( pool_id, node_remove_parameter=batchmodels.NodeRemoveParameter( node_list=nodes, ), ) logger.info('removed {} nodes from pool {}'.format( len(nodes), pool_id)) break except batchmodels.BatchErrorException as ex: if 'ongoing resize operation' in ex.message.value: logger.debug('pool {} has ongoing resize operation'.format( pool_id)) random_blocking_sleep(5, 10) elif 'failed to authorize the request' in ex.message.value: if auth_attempts > _MAX_AUTH_FAILURE_RETRIES: # TODO need better recovery - requeue suspend action? raise logger.warning( 'authorization failed for {}, retrying'.format( service_url)) self.service_proxy.reset_batch_creds_and_client( service_url) random_blocking_sleep(1, 3) auth_attempts += 1 client = self.service_proxy.batch_client(service_url) else: logger.error( 'could not remove nodes from pool {} ' '(service_url={})'.format(pool_id, service_url)) # TODO need better recovery - requeue suspend action? break # delete log files for node in nodes: file = pathlib.Path( self.service_proxy.file_share_hmp ) / 'slurm' / 'logs' / 'slurm-helper-debug-{}.log'.format(node) try: file.unlink() except OSError: pass def clean_pool( self, service_url: str, pool_id: str, ) -> None: auth_attempts = 0 node_filter = [ '(state eq \'starttaskfailed\')', '(state eq \'unusable\')', '(state eq \'preempted\')', ] client = self.service_proxy.batch_client(service_url) while True: try: nodes = client.compute_node.list( pool_id=pool_id, compute_node_list_options=batchmodels. ComputeNodeListOptions(filter=' or '.join(node_filter)), ) node_ids = [node.id for node in nodes] if is_none_or_empty(node_ids): logger.debug('no nodes to clean from pool: {}'.format( pool_id)) return logger.info('removing nodes {} from pool {}'.format( node_ids, pool_id)) client.pool.remove_nodes( pool_id=pool_id, node_remove_parameter=batchmodels.NodeRemoveParameter( node_list=node_ids, ) ) break except batchmodels.BatchErrorException as ex: if 'ongoing resize operation' in ex.message.value: logger.debug('pool {} has ongoing resize operation'.format( pool_id)) random_blocking_sleep(5, 10) elif 'failed to authorize the request' in ex.message.value: if auth_attempts > _MAX_AUTH_FAILURE_RETRIES: # TODO need better recovery - requeue suspend action? raise logger.warning( 'authorization failed for {}, retrying'.format( service_url)) self.service_proxy.reset_batch_creds_and_client( service_url) random_blocking_sleep(1, 3) auth_attempts += 1 client = self.service_proxy.batch_client(service_url) else: logger.error( 'could not clean pool {} (service_url={})'.format( pool_id, service_url)) # TODO need better recovery - requeue resume fail action? break class CommandProcessor(): def __init__(self, config: Dict[str, Any]) -> None: """Ctor for CommandProcessor :param config: configuration """ self._service_proxy = ServiceProxy(config) self._partitions = None self.ssh = StorageServiceHandler(self._service_proxy) self.bsh = BatchServiceHandler(self._service_proxy) @property def slurm_partitions(self) -> List[azure.cosmosdb.table.Entity]: if self._partitions is None: self._partitions = list(self.ssh.list_partitions()) for entity in self._partitions: entity['HostList'] = re.compile(entity['HostList']) return self._partitions def set_log_configuration(self) -> None: global logger # remove existing handlers handlers = logger.handlers[:] for handler in handlers: handler.close() logger.removeHandler(handler) # set level logger.setLevel(logging.DEBUG) # set formatter formatter = logging.Formatter( '%(asctime)s %(levelname)s %(name)s:%(funcName)s:%(lineno)d ' '%(message)s') formatter.default_msec_format = '%s.%03d' # set handlers handler_stream = logging.StreamHandler() handler_stream.setFormatter(formatter) logger.addHandler(handler_stream) az_storage_logger = logging.getLogger('azure.storage') az_storage_logger.setLevel(logging.WARNING) az_storage_logger.addHandler(handler_stream) az_cosmosdb_logger = logging.getLogger('azure.cosmosdb') az_cosmosdb_logger.setLevel(logging.WARNING) az_cosmosdb_logger.addHandler(handler_stream) # log to selected log level file logfname = pathlib.Path('slurm-helper.log') logfile = ( self._service_proxy.file_share_hmp / 'slurm' / 'logs' / ('{}-{}-{}{}').format( logfname.stem, 'debug', self._service_proxy.logging_id, logfname.suffix) ) logfile.parent.mkdir(exist_ok=True) handler_logfile = logging.handlers.RotatingFileHandler( str(logfile), maxBytes=33554432, backupCount=20000, encoding='utf-8') handler_logfile.setFormatter(formatter) logger.addHandler(handler_logfile) az_storage_logger.addHandler(handler_logfile) az_cosmosdb_logger.addHandler(handler_logfile) # dump configuration self._service_proxy.log_configuration() def process_resume_action(self, hosts: List[str]) -> None: if len(hosts) == 0: logger.error('host list is empty for resume') return logger.debug( 'pulled action resume for hosts: {}'.format(', '.join(hosts))) # first check if hosts have already resumed as this action can # be called multiple times for the same set of hosts due to # controller failover hosts_modified = [] for he in hosts: host, partname = he.split() try: entity = self.ssh.get_host_assignment_entity(host) state = HostStates(entity['State']) except azure.common.AzureMissingResourceHttpError: hosts_modified.append(he) else: logger.debug( 'host entry {} found for partition {} but state ' 'is {}'.format(host, partname, state)) if state != HostStates.Resuming and state != HostStates.Up: hosts_modified.append(he) hosts = hosts_modified del hosts_modified logger.debug('resuming hosts: {}'.format(', '.join(hosts))) if len(hosts) == 0: logger.error('modified host list is empty for resume') return # collate hosts into a map of batch pools -> host config pool_map = {} partitions = self.slurm_partitions for entity in partitions: key = '{}${}'.format( entity['BatchServiceUrl'], entity['BatchPoolId']) pool_map[key] = { 'num_hosts': 0, 'compute_node_type': entity['ComputeNodeType'], } # insert host assignment entity and message total_hosts = 0 for he in hosts: host, partname = he.split() for entity in partitions: if not entity['RowKey'].startswith(partname): continue if entity['HostList'].fullmatch(host): key = '{}${}'.format( entity['BatchServiceUrl'], entity['BatchPoolId']) pool_map[key]['num_hosts'] += 1 self.ssh.insert_host_assignment_entity( host, partname, entity['BatchServiceUrl'], entity['BatchPoolId']) self.ssh.insert_queue_assignment_msg( entity['RowKey'], host) total_hosts += 1 break if total_hosts != len(hosts): logger.error( 'total host {} to number of hosts to resume ' '{} mismatch'.format(total_hosts, len(hosts))) del total_hosts # resize batch pools to specified number of hosts resize_futures = {} with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers_for_executor(pool_map)) as executor: for key in pool_map: service_url, pool_id = key.split('$') resize_futures[key] = executor.submit( self.bsh.add_nodes_to_pool( service_url, pool_id, pool_map[key]['compute_node_type'], pool_map[key]['num_hosts']) ) def process_suspend_action(self, hosts: List[str]) -> bool: if len(hosts) == 0: logger.error('host list is empty for suspend') return True logger.debug('suspending hosts: {}'.format(', '.join(hosts))) # find pool/account mapping for node suspend_retry = set() suspended = [] pool_map = {} entities = [] for host in hosts: try: entity = self.ssh.get_host_assignment_entity(host) except azure.common.AzureMissingResourceHttpError: logger.error('host {} entity not found'.format(host)) continue logger.debug('found host {} mapping: {}'.format(host, entity)) if HostStates(entity['State']) == HostStates.Suspended: logger.error('host {} is already suspended'.format(host)) continue try: node_id = entity['BatchNodeId'] except KeyError: logger.error( 'host {} does not have a batch node id assigned'.format( host)) suspend_retry.add(host) continue key = '{}${}'.format( entity['BatchServiceUrl'], entity['BatchPoolId']) if key not in pool_map: pool_map[key] = [] pool_map[key].append(node_id) suspended.append(host) entities.append(entity) # resize batch pools down, deleting specified hosts if len(entities) == 0: logger.info('no hosts to suspend after analyzing host entities') else: # remove nodes with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers_for_executor( pool_map)) as executor: for key in pool_map: service_url, pool_id = key.split('$') executor.submit( self.bsh.remove_nodes_from_pool( service_url, pool_id, pool_map[key]) ) # mark entities suspended with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers_for_executor( suspended)) as executor: for host in suspended: executor.submit( self.ssh. merge_host_assignment_entity_for_compute_node( host, HostStates.Suspended, True, retry_on_conflict=True) ) # delete log files for host in suspended: file = pathlib.Path( self._service_proxy.file_share_hmp ) / 'slurm' / 'logs' / 'slurmd-{}.log'.format(host) try: file.unlink() except OSError: pass # re-enqueue suspend retry entries if len(suspend_retry) > 0: if set(hosts) == suspend_retry: logger.debug('host suspend list in is the same as retry') return False logger.debug('adding suspend action for {} hosts to retry'.format( len(suspend_retry))) self.ssh.insert_queue_action_msg( Actions.Suspend, list(suspend_retry)) return True def _query_node_state(self, host, entity): service_url = entity['BatchServiceUrl'] pool_id = entity['BatchPoolId'] node_id = None try: node_id = entity['BatchNodeId'] except KeyError: logger.debug('batch node id not present for host {}'.format( host)) return node_id, None node = self.bsh.get_node_info(service_url, pool_id, node_id) if node is None: logger.error( 'host {} compute node {} on pool {} (service_url={}) ' 'does not exist'.format( host, node_id, pool_id, service_url)) return node_id, None logger.debug( 'host node {} on pool {} is {} (service_url={})'.format( host, node_id, pool_id, node.state, service_url)) return node_id, node.state def process_resume_failed_action( self, hosts: List[str], retry_count: int, ) -> bool: if len(hosts) == 0: logger.error('host list is empty for resume failed') return True hosts_retry = set() hosts_update = set() hosts_check = {} clean_pools = set() # create pool map from partitions pool_map = {} partitions = self.slurm_partitions for entity in partitions: key = '{}${}'.format( entity['BatchServiceUrl'], entity['BatchPoolId']) pool_map[key] = { 'compute_node_type': entity['ComputeNodeType'], 'hosts_recover': set(), 'nodes_recover': set(), } # check host state for he in hosts: host, partname = he.split() try: entity = self.ssh.get_host_assignment_entity(host) host_state = HostStates(entity['State']) except azure.common.AzureMissingResourceHttpError: # TODO what else can we do here? logger.error('host {} entity not found'.format(host)) continue logger.debug('host {} state is {}'.format(host, host_state)) if host_state == HostStates.Up: # TODO if up, verify sinfo state # sinfo -h -n host -o "%t" hosts_update.add(host) else: hosts_check[he] = entity # check pool for each host to mark cleanup for he in hosts: host, partname = he.split() entity = hosts_check[he] service_url = entity['BatchServiceUrl'] pool_id = entity['BatchPoolId'] node_counts = self.bsh.get_node_state_counts(service_url, pool_id) num_bad_nodes = ( node_counts.dedicated.unusable + node_counts.dedicated.start_task_failed + node_counts.low_priority.unusable + node_counts.low_priority.start_task_failed ) # mark pool for cleanup for any unusable/start task failed if num_bad_nodes > 0: key = '{}${}'.format(service_url, pool_id) logger.debug('{} bad nodes found on {}'.format( num_bad_nodes, key)) clean_pools.add(key) # check each host on the Batch service for he in hosts: host, partname = he.split() entity = hosts_check[he] service_url = entity['BatchServiceUrl'] pool_id = entity['BatchPoolId'] key = '{}${}'.format(service_url, pool_id) node_id, node_state = self._query_node_state(host, entity) if (node_state == batchmodels.ComputeNodeState.idle or node_state == batchmodels.ComputeNodeState.offline or node_state == batchmodels.ComputeNodeState.running): hosts_update.add(host) elif (node_state == batchmodels.ComputeNodeState.start_task_failed or node_state == batchmodels.ComputeNodeState.unusable or node_state == batchmodels.ComputeNodeState.preempted or node_state == batchmodels.ComputeNodeState.leaving_pool): clean_pools.add(key) pool_map[key]['nodes_recover'].add(node_id) pool_map[key]['hosts_recover'].add(he) else: if retry_count >= _MAX_RESUME_FAILURE_ATTEMPTS: logger.debug( '{} on partition {} exceeded max resume failure retry ' 'attempts, recovering instead'.format(host, partname)) if node_id is not None: pool_map[key]['nodes_recover'].add(node_id) pool_map[key]['hosts_recover'].add(he) else: hosts_retry.add(he) del hosts_check # update hosts if len(hosts_update) > 0: self.ssh.insert_queue_action_msg( Actions.WaitForResume, list(hosts_update), retry_count=0) del hosts_update # clean pools for key in clean_pools: service_url, pool_id = key.split('$') self.bsh.clean_pool(service_url, pool_id) del clean_pools # recover hosts for key in pool_map: hosts_recover = pool_map[key]['hosts_recover'] hrlen = len(hosts_recover) if hrlen == 0: continue for he in hosts_recover: host, partname = he.split() self.ssh.merge_host_assignment_entity_for_compute_node( host, HostStates.ProvisionInterrupt, True, retry_on_conflict=True) nodes_recover = pool_map[key]['nodes_recover'] if len(nodes_recover) > 0: service_url, pool_id = key.split('$') self.bsh.remove_nodes_from_pool( service_url, pool_id, list(nodes_recover)) host_list = list(hosts_recover) self.ssh.insert_queue_action_msg(Actions.Resume, host_list) self.ssh.insert_queue_action_msg( Actions.WaitForResume, host_list, retry_count=0, visibility_timeout=60) # re-enqueue failed resume hosts if len(hosts_retry) > 0: logger.debug( 'adding resume failed action for {} hosts to retry'.format( len(hosts_retry))) self.ssh.insert_queue_action_msg( Actions.ResumeFailed, list(hosts_retry), retry_count=retry_count + 1, visibility_timeout=60) return True def process_wait_for_resume_action( self, hosts: List[str], retry_count: int, ) -> bool: if len(hosts) == 0: logger.error('host list is empty for resume failed') return True start_time = datetime_utcnow() remain_hosts = self.ssh.wait_for_host_assignment_entities( start_time, hosts, timeout=10, set_idle_state=True) if remain_hosts is not None: if retry_count > self._service_proxy.resume_timeout / 30: logger.error( 'not retrying host spin up completion for: {}'.format( remain_hosts)) self.ssh.insert_queue_action_msg( Actions.ResumeFailed, list(remain_hosts), retry_count=0, visibility_timeout=5) else: logger.warning( 'host spin up not completed: {}'.format(remain_hosts)) self.ssh.insert_queue_action_msg( Actions.WaitForResume, list(remain_hosts), retry_count=retry_count + 1, visibility_timeout=30) return True def resume_hosts(self, hosts: List[str]) -> None: # insert into action queue start_time = datetime_utcnow() logger.debug('received resume hosts: {}'.format(', '.join(hosts))) self.ssh.insert_queue_action_msg(Actions.Resume, hosts) # process resume completions and translate into scontrol bare_hosts = [he.split()[0] for he in hosts] remain_hosts = self.ssh.wait_for_host_assignment_entities( start_time, bare_hosts) if remain_hosts is not None: raise RuntimeError( 'exceeded resume timeout waiting for hosts to ' 'spin up: {}'.format(remain_hosts)) def resume_hosts_failed(self, hosts: List[str]) -> None: # insert into action queue logger.info('received resume failed hosts: {}'.format( ', '.join(hosts))) self.ssh.insert_queue_action_msg( Actions.ResumeFailed, hosts, retry_count=0, visibility_timeout=5) def suspend_hosts(self, hosts: List[str]) -> None: # insert into action queue logger.debug('received suspend hosts: {}'.format(', '.join(hosts))) self.ssh.insert_queue_action_msg(Actions.Suspend, hosts) def check_provisioning_status(self, host: str) -> None: logger.debug( 'checking for provisioning status for host {}'.format(host)) try: entity = self.ssh.get_host_assignment_entity(host) state = HostStates(entity['State']) except (azure.common.AzureMissingResourceHttpError, KeyError): # this should not happen, but fail in case it does logger.error('host assignment entity does not exist for {}'.format( host)) sys.exit(1) logger.info('host {} state property is {}'.format(host, state)) if state != HostStates.Resuming: logger.error( 'unexpected state, state is not {} for host {}'.format( HostStates.Resuming, host)) # update host entity assignment self.ssh.merge_host_assignment_entity_for_compute_node( host, HostStates.ProvisionInterrupt, False, retry_on_conflict=False) sys.exit(1) return entity def daemon_processor(self) -> None: # set logging config for daemon processor self.set_log_configuration() logger.info('daemon processor starting') while True: msg = self.ssh.get_queue_action_msg() if msg is None: random_blocking_sleep(1, 3) else: del_msg = True action = msg[0]['action'] hosts = msg[0]['hosts'] msg_id = msg[1] pop_receipt = msg[2] if action == Actions.Suspend: del_msg = self.process_suspend_action(hosts) elif action == Actions.Resume: self.process_resume_action(hosts) elif action == Actions.ResumeFailed: del_msg = self.process_resume_failed_action( hosts, msg[0]['retry_count']) elif action == Actions.WaitForResume: del_msg = self.process_wait_for_resume_action( hosts, msg[0]['retry_count']) else: logger.error('unknown action {} for hosts {}'.format( action, ', '.join(hosts))) if del_msg: self.ssh.delete_queue_action_msg(msg_id, pop_receipt) else: self.ssh.update_queue_action_msg(msg_id, pop_receipt) def execute( self, action: str, hosts: Optional[List[str]], host: Optional[str], ) -> None: """Execute action :param action: action to execute """ # process actions if action == 'daemon': self.daemon_processor() elif action == 'sakey': sakey = self.ssh.get_storage_account_key() print(sakey[0], sakey[1]) elif action == 'resume': self.resume_hosts(hosts) elif action == 'resume-fail': self.resume_hosts_failed(hosts) elif action == 'suspend': self.suspend_hosts(hosts) elif action == 'check-provisioning-status': self.check_provisioning_status(host) elif action == 'get-node-assignment': self.ssh.get_queue_assignment_msg() elif action == 'complete-node-assignment': entity = self.check_provisioning_status(host) self.ssh.update_host_assignment_entity_as_provisioned(host, entity) else: raise ValueError('unknown action to execute: {}'.format(action)) def main() -> None: """Main function""" # get command-line args args = parseargs() if is_none_or_empty(args.action): raise ValueError('action is invalid') # load configuration if is_none_or_empty(args.conf): raise ValueError('config file not specified') with open(args.conf, 'rb') as f: config = json.load(f) logger.debug('loaded config from {}: {}'.format(args.conf, config)) # parse hostfile if args.hostfile is not None: with open(args.hostfile, 'r') as f: hosts = [line.rstrip() for line in f] else: hosts = None try: # create command processor cmd_processor = CommandProcessor(config) # execute action cmd_processor.execute(args.action, hosts, args.host) except Exception: logger.exception('error executing {}'.format(args.action)) finally: handlers = logger.handlers[:] for handler in handlers: handler.close() logger.removeHandler(handler) def parseargs() -> argparse.Namespace: """Parse program arguments :return: parsed arguments """ parser = argparse.ArgumentParser( description='slurm: Azure Batch Shipyard Slurm Helper') parser.add_argument( 'action', choices=[ 'daemon', 'sakey', 'resume', 'resume-fail', 'suspend', 'check-provisioning-status', 'get-node-assignment', 'complete-node-assignment', ] ) parser.add_argument('--conf', help='configuration file') parser.add_argument('--hostfile', help='host file') parser.add_argument('--host', help='host') return parser.parse_args() if __name__ == '__main__': # set up log formatting and default handlers setup_logger(logger) main()