#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation
#
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# stdlib imports
import argparse
import collections
import concurrent.futures
import datetime
import enum
import hashlib
import json
import logging
import logging.handlers
import multiprocessing
import pathlib
import random
import re
import subprocess
import sys
import threading
import time
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
)
# non-stdlib imports
import azure.batch
import azure.batch.models as batchmodels
import azure.common
import azure.cosmosdb.table
import azure.mgmt.resource
import azure.mgmt.storage
import azure.storage.queue
import dateutil.tz
import msrestazure.azure_active_directory
import msrestazure.azure_cloud
# create logger
logger = logging.getLogger(__name__)
# global defines
# TODO allow these maximums to be configurable
_MAX_EXECUTOR_WORKERS = min((multiprocessing.cpu_count() * 4, 32))
_MAX_AUTH_FAILURE_RETRIES = 10
_MAX_RESUME_FAILURE_ATTEMPTS = 10
class Actions(enum.IntEnum):
    Suspend = 0
    Resume = 1
    ResumeFailed = 2
    WaitForResume = 3
class HostStates(enum.IntEnum):
    Up = 0
    Resuming = 1
    ProvisionInterrupt = 2
    Provisioned = 3
    Suspended = 4
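# Host state lifecycle, as driven by the handlers below: a resume action
# inserts the host assignment entity in state Resuming; the
# complete-node-assignment action marks it Provisioned (or
# ProvisionInterrupt if provisioning is interrupted or fails); the resume
# waiter marks it Up after updating the Slurm node record via scontrol; a
# suspend action marks it Suspended.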
def setup_logger(log) -> None:
"""Set up logger"""
log.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s %(process)d %(levelname)s '
'%(name)s:%(funcName)s:%(lineno)d %(message)s'
)
formatter.default_msec_format = '%s.%03d'
handler.setFormatter(formatter)
log.addHandler(handler)
def max_workers_for_executor(iterable: Any) -> int:
"""Get max number of workers for executor given an iterable
:param iterable: an iterable
:return: number of workers for executor
"""
return min((len(iterable), _MAX_EXECUTOR_WORKERS))
def is_none_or_empty(obj: Any) -> bool:
"""Determine if object is None or empty
:param obj: object
:return: if object is None or empty
"""
return obj is None or len(obj) == 0
def is_not_empty(obj: Any) -> bool:
"""Determine if object is not None and is length is > 0
:param obj: object
:return: if object is not None and length is > 0
"""
return obj is not None and len(obj) > 0
def datetime_utcnow(as_string: bool = False) -> datetime.datetime:
"""Returns a datetime now with UTC timezone
:param as_string: return as ISO8601 extended string
:return: datetime object representing now with UTC timezone
"""
dt = datetime.datetime.now(dateutil.tz.tzutc())
if as_string:
return dt.strftime('%Y%m%dT%H%M%S.%f')[:-3] + 'Z'
else:
return dt
def hash_string(strdata: str) -> str:
"""Hash a string
:param strdata: string data to hash
:return: hexdigest
"""
return hashlib.sha1(strdata.encode('utf8')).hexdigest()
def random_blocking_sleep(min: int, max: int) -> None:
    """Sleep for a random integral duration between min and max seconds
    :param min: minimum sleep time in seconds
    :param max: maximum sleep time in seconds
    """
    time.sleep(random.randint(min, max))
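# Credentials supports two modes, selected by the presence of the shared
# storage key settings. Illustrative (not authoritative) config shapes
# consumed by the ctor below:
#   shared key: {"storage": {"account": "...", "account_key": "...",
#                            "endpoint": "core.windows.net"}}
#   MSI/AAD:    {"storage": {"account": "...", "resource_group": "..."},
#                "aad_cloud": "public"}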
class Credentials():
def __init__(self, config: Dict[str, Any]) -> None:
"""Ctor for Credentials
:param config: configuration
"""
# set attr from config
self.storage_account = config['storage']['account']
try:
self.storage_account_key = config['storage']['account_key']
self.storage_account_ep = config['storage']['endpoint']
self.storage_account_rg = None
self.cloud = None
self.arm_creds = None
self.batch_creds = None
self.sub_id = None
logger.debug('storage account {} ep: {}'.format(
self.storage_account, self.storage_account_ep))
except KeyError:
self.storage_account_rg = config['storage']['resource_group']
# get cloud object
self.cloud = Credentials.convert_cloud_type(config['aad_cloud'])
# get aad creds
self.arm_creds = self.create_msi_credentials()
self.batch_creds = self.create_msi_credentials(
resource_id=self.cloud.endpoints.batch_resource_id)
# get subscription id
self.sub_id = self.get_subscription_id()
logger.debug('created msi auth for sub id: {}'.format(self.sub_id))
# get storage account key and endpoint
self.storage_account_key, self.storage_account_ep = \
self.get_storage_account_key()
logger.debug('storage account {} -> rg: {} ep: {}'.format(
self.storage_account, self.storage_account_rg,
self.storage_account_ep))
@staticmethod
def convert_cloud_type(cloud_type: str) -> msrestazure.azure_cloud.Cloud:
"""Convert clout type string to object
:param cloud_type: cloud type to convert
:return: cloud object
"""
if cloud_type == 'public':
cloud = msrestazure.azure_cloud.AZURE_PUBLIC_CLOUD
elif cloud_type == 'china':
cloud = msrestazure.azure_cloud.AZURE_CHINA_CLOUD
elif cloud_type == 'germany':
cloud = msrestazure.azure_cloud.AZURE_GERMAN_CLOUD
elif cloud_type == 'usgov':
cloud = msrestazure.azure_cloud.AZURE_US_GOV_CLOUD
else:
raise ValueError('unknown cloud_type: {}'.format(cloud_type))
return cloud
def get_subscription_id(self) -> str:
"""Get subscription id for ARM creds
:param arm_creds: ARM creds
:return: subscription id
"""
client = azure.mgmt.resource.SubscriptionClient(self.arm_creds)
return next(client.subscriptions.list()).subscription_id
def create_msi_credentials(
self,
resource_id: str = None
) -> msrestazure.azure_active_directory.MSIAuthentication:
"""Create MSI credentials
:param resource_id: resource id to auth against
:return: MSI auth object
"""
if is_not_empty(resource_id):
creds = msrestazure.azure_active_directory.MSIAuthentication(
cloud_environment=self.cloud,
resource=resource_id,
)
else:
creds = msrestazure.azure_active_directory.MSIAuthentication(
cloud_environment=self.cloud,
)
return creds
def get_storage_account_key(self) -> Tuple[str, str]:
"""Retrieve the storage account key and endpoint
:return: tuple of key, endpoint
"""
client = azure.mgmt.storage.StorageManagementClient(
self.arm_creds, self.sub_id,
base_url=self.cloud.endpoints.resource_manager)
ep = None
if is_not_empty(self.storage_account_rg):
acct = client.storage_accounts.get_properties(
self.storage_account_rg, self.storage_account)
ep = '.'.join(
acct.primary_endpoints.blob.rstrip('/').split('.')[2:]
)
else:
for acct in client.storage_accounts.list():
if acct.name == self.storage_account:
self.storage_account_rg = acct.id.split('/')[4]
ep = '.'.join(
acct.primary_endpoints.blob.rstrip('/').split('.')[2:]
)
break
if is_none_or_empty(self.storage_account_rg) or is_none_or_empty(ep):
raise RuntimeError(
'storage account {} not found in subscription id {}'.format(
self.storage_account, self.sub_id))
keys = client.storage_accounts.list_keys(
self.storage_account_rg, self.storage_account)
return (keys.keys[0].value, ep)
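# ServiceProxy centralizes configuration access and Azure client creation.
# Per the ctor and properties below it reads: storage.entity_prefix,
# storage.queues.assign/action, storage.azfile_mount_dir, cluster_id,
# logging_id, batch.node_id, batch.pool_id, ip_address,
# batch_shipyard.version, batch_shipyard.var_path and
# timeouts.resume/suspend; the queue and batch keys are optional and
# default to None when absent.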
class ServiceProxy():
def __init__(self, config: Dict[str, Any]) -> None:
"""Ctor for ServiceProxy
:param config: configuration
"""
self._config = config
self._resume_timeout = None
self._suspend_timeout = None
prefix = config['storage']['entity_prefix']
self.cluster_id = config['cluster_id']
self.logging_id = config['logging_id']
self.table_name = '{}slurm'.format(prefix)
try:
self.queue_assign = config['storage']['queues']['assign']
except KeyError:
self.queue_assign = None
try:
self.queue_action = config['storage']['queues']['action']
except KeyError:
self.queue_action = None
try:
self.node_id = config['batch']['node_id']
self.pool_id = config['batch']['pool_id']
self.ip_address = config['ip_address']
except KeyError:
self.node_id = None
self.pool_id = None
self.ip_address = None
self.file_share_hmp = pathlib.Path(
config['storage']['azfile_mount_dir']) / config['cluster_id']
self._batch_client_lock = threading.Lock()
self.batch_clients = {}
# create credentials
self.creds = Credentials(config)
# create clients
self.table_client = self._create_table_client()
self.queue_client = self._create_queue_client()
logger.debug('created storage clients for storage account {}'.format(
self.creds.storage_account))
@property
def batch_shipyard_version(self) -> str:
return self._config['batch_shipyard']['version']
@property
def batch_shipyard_var_path(self) -> pathlib.Path:
return pathlib.Path(self._config['batch_shipyard']['var_path'])
@property
def storage_entity_prefix(self) -> str:
return self._config['storage']['entity_prefix']
@property
def resume_timeout(self) -> int:
if self._resume_timeout is None:
# subtract off 5 seconds for fudge
val = self._config['timeouts']['resume'] - 5
if val < 5:
val = 5
self._resume_timeout = val
return self._resume_timeout
@property
def suspend_timeout(self) -> int:
if self._suspend_timeout is None:
# subtract off 5 seconds for fudge
val = self._config['timeouts']['suspend'] - 5
if val < 5:
val = 5
self._suspend_timeout = val
return self._suspend_timeout
def log_configuration(self) -> None:
logger.debug('configuration: {}'.format(
json.dumps(self._config, sort_keys=True, indent=4)))
def _modify_client_for_retry_and_user_agent(self, client: Any) -> None:
"""Extend retry policy of clients and add user agent string
:param client: a client object
"""
if client is None:
return
client.config.retry_policy.max_backoff = 8
client.config.retry_policy.retries = 100
client.config.add_user_agent('batch-shipyard/{}'.format(
self.batch_shipyard_version))
def _create_table_client(self) -> azure.cosmosdb.table.TableService:
"""Create a table client for the given storage account
:return: table client
"""
client = azure.cosmosdb.table.TableService(
account_name=self.creds.storage_account,
account_key=self.creds.storage_account_key,
endpoint_suffix=self.creds.storage_account_ep,
)
return client
def _create_queue_client(self) -> azure.storage.queue.QueueService:
"""Create a queue client for the given storage account
:return: queue client
"""
client = azure.storage.queue.QueueService(
account_name=self.creds.storage_account,
account_key=self.creds.storage_account_key,
endpoint_suffix=self.creds.storage_account_ep,
)
return client
def batch_client(
self,
service_url: str
) -> azure.batch.BatchServiceClient:
"""Get/create batch client
:param service_url: service url
:return: batch client
"""
with self._batch_client_lock:
try:
return self.batch_clients[service_url]
except KeyError:
client = azure.batch.BatchServiceClient(
self.creds.batch_creds, batch_url=service_url)
self._modify_client_for_retry_and_user_agent(client)
self.batch_clients[service_url] = client
logger.debug('batch client created for account: {}'.format(
service_url))
return client
def reset_batch_creds_and_client(
self,
service_url: str
) -> None:
logger.warning('resetting batch creds and client for {}'.format(
service_url))
with self._batch_client_lock:
            self.creds.batch_creds = self.creds.create_msi_credentials(
                resource_id=self.creds.cloud.endpoints.batch_resource_id)
self.batch_clients.pop(service_url, None)
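# StorageServiceHandler owns the table and queue plumbing. A single table
# ('<entity_prefix>slurm') holds both partition entities (PartitionKey
# 'PARTITIONS$<cluster_id>') and host assignment entities (PartitionKey
# 'HOSTS$<cluster_id>', RowKey = Slurm host name). Assignment tokens are
# pushed to per-partition queues named '<cluster_id>-<second $-separated
# component of the partition RowKey>', while actions flow through the
# single action queue.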
class StorageServiceHandler():
_PARTITIONS_PREFIX = 'PARTITIONS'
_HOSTS_PREFIX = 'HOSTS'
def __init__(self, service_proxy: ServiceProxy) -> None:
"""Ctor for Storage handler
:param service_proxy: ServiceProxy
"""
self.service_proxy = service_proxy
def get_storage_account_key(self) -> Tuple[str, str]:
return (self.service_proxy.creds.storage_account_key,
self.service_proxy.creds.storage_account_ep)
def list_partitions(self) -> List[azure.cosmosdb.table.Entity]:
return self.service_proxy.table_client.query_entities(
self.service_proxy.table_name,
filter='PartitionKey eq \'{}${}\''.format(
self._PARTITIONS_PREFIX, self.service_proxy.cluster_id))
def get_host_assignment_entity(
self,
host: str,
) -> azure.cosmosdb.table.Entity:
return self.service_proxy.table_client.get_entity(
self.service_proxy.table_name,
'{}${}'.format(self._HOSTS_PREFIX, self.service_proxy.cluster_id),
host)
def delete_node_assignment_entity(
self,
entity: azure.cosmosdb.table.Entity,
) -> None:
try:
self.service_proxy.table_client.delete_entity(
self.service_proxy.table_name,
entity['PartitionKey'],
entity['RowKey'])
except azure.common.AzureMissingResourceHttpError:
pass
def insert_queue_assignment_msg(self, rowkey: str, host: str) -> None:
rkparts = rowkey.split('$')
qname = '{}-{}'.format(self.service_proxy.cluster_id, rkparts[1])
logger.debug('inserting host {} assignment token to queue {}'.format(
host, qname))
msg = {
'cluster_id': self.service_proxy.cluster_id,
'host': host,
}
msg_data = json.dumps(msg, ensure_ascii=True, sort_keys=True)
self.service_proxy.queue_client.put_message(
qname, msg_data, time_to_live=-1)
def get_queue_assignment_msg(self) -> None:
logger.debug('getting queue assignment from {}'.format(
self.service_proxy.queue_assign))
host = None
while host is None:
msgs = self.service_proxy.queue_client.get_messages(
self.service_proxy.queue_assign, num_messages=1,
visibility_timeout=150)
for msg in msgs:
                msg_data = json.loads(msg.content)
logger.debug(
'got message {}: {}'.format(msg.id, msg_data))
host = msg_data['host']
outfile = pathlib.Path(
self.service_proxy.batch_shipyard_var_path) / 'slurm_host'
with outfile.open('wt') as f:
f.write(host)
self.service_proxy.queue_client.delete_message(
self.service_proxy.queue_assign, msg.id,
msg.pop_receipt)
break
random_blocking_sleep(1, 3)
logger.info('got host assignment: {}'.format(host))
def insert_queue_action_msg(
self,
action: Actions,
hosts: List[str],
retry_count: Optional[int] = None,
visibility_timeout: Optional[int] = None,
) -> None:
msg = {
'cluster_id': self.service_proxy.cluster_id,
'action': action.value,
'hosts': hosts,
}
if retry_count is not None:
msg['retry_count'] = retry_count
logger.debug('inserting queue {} message (vt={}): {}'.format(
self.service_proxy.queue_action, visibility_timeout, msg))
msg_data = json.dumps(msg, ensure_ascii=True, sort_keys=True)
self.service_proxy.queue_client.put_message(
self.service_proxy.queue_action, msg_data,
visibility_timeout=visibility_timeout, time_to_live=-1)
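    # Illustrative action message payload as produced above (cluster id and
    # host name are hypothetical), e.g. for a suspend action:
    #   {"action": 0, "cluster_id": "mycluster",
    #    "hosts": ["mycluster-compute-0"]}
    # "action" is the integer value of an Actions member; "retry_count" is
    # included only when supplied by the caller.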
def get_queue_action_msg(
self
) -> Optional[Tuple[Dict[str, Any], str, str]]:
msgs = self.service_proxy.queue_client.get_messages(
self.service_proxy.queue_action, num_messages=1,
visibility_timeout=self.service_proxy.resume_timeout)
for msg in msgs:
            msg_data = json.loads(msg.content)
logger.debug(
'got message {} from queue {}: {}'.format(
msg.id, self.service_proxy.queue_action, msg_data))
return (msg_data, msg.id, msg.pop_receipt)
return None
def update_queue_action_msg(
self,
id: str,
pop_receipt: str,
) -> None:
logger.debug(
'updating queue {} message id {} pop receipt {}'.format(
self.service_proxy.queue_action, id, pop_receipt))
self.service_proxy.queue_client.update_message(
self.service_proxy.queue_action, id, pop_receipt, 20)
def delete_queue_action_msg(
self,
id: str,
pop_receipt: str,
) -> None:
logger.debug(
'deleting queue {} message id {} pop receipt {}'.format(
self.service_proxy.queue_action, id, pop_receipt))
self.service_proxy.queue_client.delete_message(
self.service_proxy.queue_action, id, pop_receipt)
def insert_host_assignment_entity(
self,
host: str,
partition_name: str,
service_url: str,
pool_id: str
) -> None:
entity = {
'PartitionKey': '{}${}'.format(
self._HOSTS_PREFIX, self.service_proxy.cluster_id),
'RowKey': host,
'Partition': partition_name,
'State': HostStates.Resuming.value,
'BatchServiceUrl': service_url,
'BatchPoolId': pool_id,
'BatchShipyardSlurmVersion': 1,
}
self.service_proxy.table_client.insert_or_replace_entity(
self.service_proxy.table_name, entity)
def merge_host_assignment_entity_for_compute_node(
self,
host: str,
state: HostStates,
state_only: bool,
retry_on_conflict: Optional[bool] = None,
) -> None:
entity = {
'PartitionKey': '{}${}'.format(
self._HOSTS_PREFIX, self.service_proxy.cluster_id),
'RowKey': host,
'State': state.value,
}
if not state_only:
entity['BatchNodeId'] = self.service_proxy.node_id
logger.debug(
'merging host {} ip={} assignment entity in table {}: {}'.format(
host, self.service_proxy.ip_address,
self.service_proxy.table_name, entity))
if retry_on_conflict is None:
retry_on_conflict = True
while True:
try:
self.service_proxy.table_client.merge_entity(
self.service_proxy.table_name, entity)
break
except azure.common.AzureConflictHttpError:
if retry_on_conflict:
random_blocking_sleep(1, 3)
else:
raise
def update_host_assignment_entity_as_provisioned(
self,
host: str,
entity: azure.cosmosdb.table.Entity,
) -> None:
entity['State'] = HostStates.Provisioned.value
entity['IpAddress'] = self.service_proxy.ip_address
entity['BatchNodeId'] = self.service_proxy.node_id
logger.debug(
'updating host {} ip={} assignment entity in table {}: {}'.format(
host, self.service_proxy.ip_address,
self.service_proxy.table_name, entity))
        # this must process successfully with no etag collision
self.service_proxy.table_client.update_entity(
self.service_proxy.table_name, entity)
def wait_for_host_assignment_entities(
self,
start_time: datetime.datetime,
hosts: List[str],
timeout: Optional[int] = None,
set_idle_state: Optional[bool] = None,
    ) -> Optional[collections.deque]:
if timeout is None:
timeout = self.service_proxy.resume_timeout
logger.info('waiting for {} hosts to spin up in {} sec'.format(
len(hosts), timeout))
host_queue = collections.deque(hosts)
i = 0
while len(host_queue) > 0:
host = host_queue.popleft()
try:
entity = self.get_host_assignment_entity(host)
ip = entity['IpAddress']
state = HostStates(entity['State'])
if state != HostStates.Provisioned:
logger.error('unexpected state for host {}: {}'.format(
host, state))
raise KeyError()
except (azure.common.AzureMissingResourceHttpError, KeyError):
host_queue.append(host)
else:
logger.debug(
'updating host {} with ip {} node id {} pool id {}'.format(
host, ip, entity['BatchNodeId'],
entity['BatchPoolId']))
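                # point the Slurm node record at the provisioned Batch
                # node: elastic/cloud nodes need NodeAddr and NodeHostname
                # patched via scontrol once the real IP address is known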
cmd = ['scontrol', 'update', 'NodeName={}'.format(host),
'NodeAddr={}'.format(ip),
'NodeHostname={}'.format(host)]
if set_idle_state:
cmd.append('State=Idle')
logger.debug('command: {}'.format(' '.join(cmd)))
subprocess.check_call(cmd)
# update entity state
m_entity = {
'PartitionKey': entity['PartitionKey'],
'RowKey': entity['RowKey'],
'State': HostStates.Up.value,
}
self.service_proxy.table_client.merge_entity(
self.service_proxy.table_name, m_entity)
continue
i += 1
if i % 6 == 0:
i = 0
logger.debug('still waiting for {} hosts'.format(
len(host_queue)))
diff = datetime_utcnow() - start_time
if diff.total_seconds() > timeout:
return host_queue
random_blocking_sleep(5, 10)
logger.info('{} host spin up completed'.format(len(hosts)))
return None
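# BatchServiceHandler wraps the Azure Batch side: querying node state
# counts and node info, resizing pools up and down, and cleaning nodes in
# starttaskfailed/unusable/preempted states. Operations are retried while a
# pool resize is already in flight, and authorization failures trigger a
# credential and client reset via reset_batch_creds_and_client.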
class BatchServiceHandler():
def __init__(self, service_proxy: ServiceProxy) -> None:
"""Ctor for Batch handler
:param service_proxy: ServiceProxy
"""
self.service_proxy = service_proxy
def get_node_state_counts(
self,
service_url: str,
pool_id: str,
    ) -> Optional[batchmodels.PoolNodeCounts]:
client = self.service_proxy.batch_client(service_url)
try:
node_counts = client.account.list_pool_node_counts(
account_list_pool_node_counts_options=batchmodels.
AccountListPoolNodeCountsOptions(
filter='poolId eq \'{}\''.format(pool_id)
)
)
            nc = list(node_counts)
            if len(nc) == 0:
                logger.error(
                    'no node counts for pool {} (service_url={})'.format(
                        pool_id, service_url))
                return None
            return nc[0]
except batchmodels.BatchErrorException:
logger.error(
'could not retrieve pool {} node counts '
'(service_url={})'.format(pool_id, service_url))
def get_node_info(
self,
service_url: str,
pool_id: str,
node_id: str,
) -> batchmodels.ComputeNode:
auth_attempts = 0
client = self.service_proxy.batch_client(service_url)
# get pool and add delta to current target counts
while True:
try:
return client.compute_node.get(pool_id, node_id)
except batchmodels.BatchErrorException as ex:
if 'failed to authorize the request' in ex.message.value:
if auth_attempts > _MAX_AUTH_FAILURE_RETRIES:
raise
logger.warning(
'authorization failed for {}, retrying'.format(
service_url))
self.service_proxy.reset_batch_creds_and_client(
service_url)
random_blocking_sleep(1, 3)
auth_attempts += 1
client = self.service_proxy.batch_client(service_url)
else:
return None
def add_nodes_to_pool(
self,
service_url: str,
pool_id: str,
compute_node_type: str,
num_hosts: int,
) -> None:
auth_attempts = 0
client = self.service_proxy.batch_client(service_url)
# get pool and add delta to current target counts
while True:
try:
pool = client.pool.get(pool_id)
if (pool.allocation_state ==
batchmodels.AllocationState.resizing):
logger.debug(
'cannot add nodes to pool {} as it is resizing, '
'will retry'.format(pool_id))
random_blocking_sleep(5, 10)
continue
if compute_node_type == 'dedicated':
target_dedicated = pool.target_dedicated_nodes + num_hosts
target_low_priority = 0
logger.debug(
'adding dedicated nodes to pool {}: {} -> {} '
'(service_url={})'.format(
pool_id, pool.target_dedicated_nodes,
target_dedicated, service_url))
else:
target_dedicated = 0
target_low_priority = (
pool.target_low_priority_nodes + num_hosts
)
logger.debug(
'adding low priority nodes to pool {}: {} -> {} '
'(service_url={})'.format(
pool_id, pool.target_low_priority_nodes,
target_low_priority, service_url))
client.pool.resize(
pool_id,
pool_resize_parameter=batchmodels.PoolResizeParameter(
target_dedicated_nodes=target_dedicated,
target_low_priority_nodes=target_low_priority,
),
)
logger.info('added nodes to pool {}'.format(pool_id))
break
except batchmodels.BatchErrorException as ex:
if 'ongoing resize operation' in ex.message.value:
logger.debug('pool {} is resizing, will retry'.format(
pool_id))
random_blocking_sleep(5, 10)
elif 'failed to authorize the request' in ex.message.value:
if auth_attempts > _MAX_AUTH_FAILURE_RETRIES:
raise
logger.warning(
'authorization failed for {}, retrying'.format(
service_url))
self.service_proxy.reset_batch_creds_and_client(
service_url)
random_blocking_sleep(1, 3)
auth_attempts += 1
client = self.service_proxy.batch_client(service_url)
else:
logger.exception(
'could not add nodes to pool {} '
'(service_url={})'.format(pool_id, service_url))
raise
def remove_nodes_from_pool(
self,
service_url: str,
pool_id: str,
nodes: List[str],
) -> None:
auth_attempts = 0
client = self.service_proxy.batch_client(service_url)
while True:
try:
pool = client.pool.get(pool_id)
if (pool.allocation_state ==
batchmodels.AllocationState.resizing):
logger.debug(
                        'cannot remove nodes from pool {} as it is resizing, '
'will retry'.format(pool_id))
random_blocking_sleep(5, 10)
continue
client.pool.remove_nodes(
pool_id,
node_remove_parameter=batchmodels.NodeRemoveParameter(
node_list=nodes,
),
)
logger.info('removed {} nodes from pool {}'.format(
len(nodes), pool_id))
break
except batchmodels.BatchErrorException as ex:
if 'ongoing resize operation' in ex.message.value:
logger.debug('pool {} has ongoing resize operation'.format(
pool_id))
random_blocking_sleep(5, 10)
elif 'failed to authorize the request' in ex.message.value:
if auth_attempts > _MAX_AUTH_FAILURE_RETRIES:
# TODO need better recovery - requeue suspend action?
raise
logger.warning(
'authorization failed for {}, retrying'.format(
service_url))
self.service_proxy.reset_batch_creds_and_client(
service_url)
random_blocking_sleep(1, 3)
auth_attempts += 1
client = self.service_proxy.batch_client(service_url)
else:
logger.error(
'could not remove nodes from pool {} '
'(service_url={})'.format(pool_id, service_url))
# TODO need better recovery - requeue suspend action?
break
# delete log files
for node in nodes:
file = pathlib.Path(
self.service_proxy.file_share_hmp
) / 'slurm' / 'logs' / 'slurm-helper-debug-{}.log'.format(node)
try:
file.unlink()
except OSError:
pass
def clean_pool(
self,
service_url: str,
pool_id: str,
) -> None:
auth_attempts = 0
node_filter = [
'(state eq \'starttaskfailed\')',
'(state eq \'unusable\')',
'(state eq \'preempted\')',
]
client = self.service_proxy.batch_client(service_url)
while True:
try:
nodes = client.compute_node.list(
pool_id=pool_id,
compute_node_list_options=batchmodels.
ComputeNodeListOptions(filter=' or '.join(node_filter)),
)
node_ids = [node.id for node in nodes]
if is_none_or_empty(node_ids):
logger.debug('no nodes to clean from pool: {}'.format(
pool_id))
return
logger.info('removing nodes {} from pool {}'.format(
node_ids, pool_id))
client.pool.remove_nodes(
pool_id=pool_id,
node_remove_parameter=batchmodels.NodeRemoveParameter(
node_list=node_ids,
)
)
break
except batchmodels.BatchErrorException as ex:
if 'ongoing resize operation' in ex.message.value:
logger.debug('pool {} has ongoing resize operation'.format(
pool_id))
random_blocking_sleep(5, 10)
elif 'failed to authorize the request' in ex.message.value:
if auth_attempts > _MAX_AUTH_FAILURE_RETRIES:
# TODO need better recovery - requeue suspend action?
raise
logger.warning(
'authorization failed for {}, retrying'.format(
service_url))
self.service_proxy.reset_batch_creds_and_client(
service_url)
random_blocking_sleep(1, 3)
auth_attempts += 1
client = self.service_proxy.batch_client(service_url)
else:
logger.error(
'could not clean pool {} (service_url={})'.format(
pool_id, service_url))
# TODO need better recovery - requeue resume fail action?
break
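# CommandProcessor dispatches the command line actions below. The resume,
# resume-fail and suspend actions are presumably wired into Slurm's
# ResumeProgram, ResumeFailProgram and SuspendProgram power saving hooks
# (they read a host file written by the caller), the daemon action services
# the action queue on the controller, and get-node-assignment /
# complete-node-assignment are intended to run during compute node
# provisioning; this mapping is inferred from the handlers below.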
class CommandProcessor():
def __init__(self, config: Dict[str, Any]) -> None:
"""Ctor for CommandProcessor
:param config: configuration
"""
self._service_proxy = ServiceProxy(config)
self._partitions = None
self.ssh = StorageServiceHandler(self._service_proxy)
self.bsh = BatchServiceHandler(self._service_proxy)
@property
def slurm_partitions(self) -> List[azure.cosmosdb.table.Entity]:
if self._partitions is None:
self._partitions = list(self.ssh.list_partitions())
for entity in self._partitions:
entity['HostList'] = re.compile(entity['HostList'])
return self._partitions
def set_log_configuration(self) -> None:
global logger
# remove existing handlers
handlers = logger.handlers[:]
for handler in handlers:
handler.close()
logger.removeHandler(handler)
# set level
logger.setLevel(logging.DEBUG)
# set formatter
formatter = logging.Formatter(
'%(asctime)s %(levelname)s %(name)s:%(funcName)s:%(lineno)d '
'%(message)s')
formatter.default_msec_format = '%s.%03d'
# set handlers
handler_stream = logging.StreamHandler()
handler_stream.setFormatter(formatter)
logger.addHandler(handler_stream)
az_storage_logger = logging.getLogger('azure.storage')
az_storage_logger.setLevel(logging.WARNING)
az_storage_logger.addHandler(handler_stream)
az_cosmosdb_logger = logging.getLogger('azure.cosmosdb')
az_cosmosdb_logger.setLevel(logging.WARNING)
az_cosmosdb_logger.addHandler(handler_stream)
# log to selected log level file
logfname = pathlib.Path('slurm-helper.log')
logfile = (
self._service_proxy.file_share_hmp / 'slurm' / 'logs' /
('{}-{}-{}{}').format(
logfname.stem, 'debug', self._service_proxy.logging_id,
logfname.suffix)
)
        logfile.parent.mkdir(parents=True, exist_ok=True)
handler_logfile = logging.handlers.RotatingFileHandler(
str(logfile), maxBytes=33554432, backupCount=20000,
encoding='utf-8')
handler_logfile.setFormatter(formatter)
logger.addHandler(handler_logfile)
az_storage_logger.addHandler(handler_logfile)
az_cosmosdb_logger.addHandler(handler_logfile)
# dump configuration
self._service_proxy.log_configuration()
def process_resume_action(self, hosts: List[str]) -> None:
if len(hosts) == 0:
logger.error('host list is empty for resume')
return
logger.debug(
'pulled action resume for hosts: {}'.format(', '.join(hosts)))
# first check if hosts have already resumed as this action can
# be called multiple times for the same set of hosts due to
# controller failover
hosts_modified = []
for he in hosts:
host, partname = he.split()
try:
entity = self.ssh.get_host_assignment_entity(host)
state = HostStates(entity['State'])
except azure.common.AzureMissingResourceHttpError:
hosts_modified.append(he)
else:
logger.debug(
'host entry {} found for partition {} but state '
'is {}'.format(host, partname, state))
if state != HostStates.Resuming and state != HostStates.Up:
hosts_modified.append(he)
hosts = hosts_modified
del hosts_modified
logger.debug('resuming hosts: {}'.format(', '.join(hosts)))
if len(hosts) == 0:
logger.error('modified host list is empty for resume')
return
# collate hosts into a map of batch pools -> host config
pool_map = {}
partitions = self.slurm_partitions
for entity in partitions:
key = '{}${}'.format(
entity['BatchServiceUrl'], entity['BatchPoolId'])
pool_map[key] = {
'num_hosts': 0,
'compute_node_type': entity['ComputeNodeType'],
}
# insert host assignment entity and message
total_hosts = 0
for he in hosts:
host, partname = he.split()
for entity in partitions:
if not entity['RowKey'].startswith(partname):
continue
if entity['HostList'].fullmatch(host):
key = '{}${}'.format(
entity['BatchServiceUrl'], entity['BatchPoolId'])
pool_map[key]['num_hosts'] += 1
self.ssh.insert_host_assignment_entity(
host, partname, entity['BatchServiceUrl'],
entity['BatchPoolId'])
self.ssh.insert_queue_assignment_msg(
entity['RowKey'], host)
total_hosts += 1
break
        if total_hosts != len(hosts):
            logger.error(
                'number of matched hosts {} does not equal number of '
                'hosts to resume {}'.format(total_hosts, len(hosts)))
del total_hosts
# resize batch pools to specified number of hosts
        resize_futures = {}
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers_for_executor(pool_map)) as executor:
            for key in pool_map:
                service_url, pool_id = key.split('$')
                # submit the callable and its args so the resizes run on
                # the executor worker threads
                resize_futures[key] = executor.submit(
                    self.bsh.add_nodes_to_pool,
                    service_url, pool_id,
                    pool_map[key]['compute_node_type'],
                    pool_map[key]['num_hosts'])
        # propagate any resize failures
        for future in resize_futures.values():
            future.result()
def process_suspend_action(self, hosts: List[str]) -> bool:
if len(hosts) == 0:
logger.error('host list is empty for suspend')
return True
logger.debug('suspending hosts: {}'.format(', '.join(hosts)))
# find pool/account mapping for node
suspend_retry = set()
suspended = []
pool_map = {}
entities = []
for host in hosts:
try:
entity = self.ssh.get_host_assignment_entity(host)
except azure.common.AzureMissingResourceHttpError:
logger.error('host {} entity not found'.format(host))
continue
logger.debug('found host {} mapping: {}'.format(host, entity))
if HostStates(entity['State']) == HostStates.Suspended:
logger.error('host {} is already suspended'.format(host))
continue
try:
node_id = entity['BatchNodeId']
except KeyError:
logger.error(
'host {} does not have a batch node id assigned'.format(
host))
suspend_retry.add(host)
continue
key = '{}${}'.format(
entity['BatchServiceUrl'], entity['BatchPoolId'])
if key not in pool_map:
pool_map[key] = []
pool_map[key].append(node_id)
suspended.append(host)
entities.append(entity)
# resize batch pools down, deleting specified hosts
if len(entities) == 0:
logger.info('no hosts to suspend after analyzing host entities')
else:
            # remove nodes
            remove_futures = []
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers_for_executor(
                        pool_map)) as executor:
                for key in pool_map:
                    service_url, pool_id = key.split('$')
                    # submit the callable and its args to the executor
                    remove_futures.append(executor.submit(
                        self.bsh.remove_nodes_from_pool,
                        service_url, pool_id, pool_map[key]))
            # mark entities suspended
            merge_futures = []
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers_for_executor(
                        suspended)) as executor:
                for host in suspended:
                    merge_futures.append(executor.submit(
                        self.ssh.
                        merge_host_assignment_entity_for_compute_node,
                        host, HostStates.Suspended, True,
                        retry_on_conflict=True))
            # propagate any failures from the parallel operations
            for future in remove_futures + merge_futures:
                future.result()
# delete log files
for host in suspended:
file = pathlib.Path(
self._service_proxy.file_share_hmp
) / 'slurm' / 'logs' / 'slurmd-{}.log'.format(host)
try:
file.unlink()
except OSError:
pass
# re-enqueue suspend retry entries
if len(suspend_retry) > 0:
if set(hosts) == suspend_retry:
                logger.debug('host suspend list is the same as the retry set')
return False
logger.debug('adding suspend action for {} hosts to retry'.format(
len(suspend_retry)))
self.ssh.insert_queue_action_msg(
Actions.Suspend, list(suspend_retry))
return True
def _query_node_state(self, host, entity):
service_url = entity['BatchServiceUrl']
pool_id = entity['BatchPoolId']
node_id = None
try:
node_id = entity['BatchNodeId']
except KeyError:
logger.debug('batch node id not present for host {}'.format(
host))
return node_id, None
node = self.bsh.get_node_info(service_url, pool_id, node_id)
if node is None:
logger.error(
'host {} compute node {} on pool {} (service_url={}) '
'does not exist'.format(
host, node_id, pool_id, service_url))
return node_id, None
        logger.debug(
            'host {} node {} on pool {} is {} (service_url={})'.format(
                host, node_id, pool_id, node.state, service_url))
return node_id, node.state
def process_resume_failed_action(
self,
hosts: List[str],
retry_count: int,
) -> bool:
if len(hosts) == 0:
logger.error('host list is empty for resume failed')
return True
hosts_retry = set()
hosts_update = set()
hosts_check = {}
clean_pools = set()
# create pool map from partitions
pool_map = {}
partitions = self.slurm_partitions
for entity in partitions:
key = '{}${}'.format(
entity['BatchServiceUrl'], entity['BatchPoolId'])
pool_map[key] = {
'compute_node_type': entity['ComputeNodeType'],
'hosts_recover': set(),
'nodes_recover': set(),
}
# check host state
for he in hosts:
host, partname = he.split()
try:
entity = self.ssh.get_host_assignment_entity(host)
host_state = HostStates(entity['State'])
except azure.common.AzureMissingResourceHttpError:
# TODO what else can we do here?
logger.error('host {} entity not found'.format(host))
continue
logger.debug('host {} state is {}'.format(host, host_state))
if host_state == HostStates.Up:
# TODO if up, verify sinfo state
# sinfo -h -n host -o "%t"
hosts_update.add(host)
else:
hosts_check[he] = entity
        # check pool for each host still needing attention to mark cleanup
        for he in hosts_check:
host, partname = he.split()
entity = hosts_check[he]
service_url = entity['BatchServiceUrl']
pool_id = entity['BatchPoolId']
node_counts = self.bsh.get_node_state_counts(service_url, pool_id)
num_bad_nodes = (
node_counts.dedicated.unusable +
node_counts.dedicated.start_task_failed +
node_counts.low_priority.unusable +
node_counts.low_priority.start_task_failed
)
# mark pool for cleanup for any unusable/start task failed
if num_bad_nodes > 0:
key = '{}${}'.format(service_url, pool_id)
logger.debug('{} bad nodes found on {}'.format(
num_bad_nodes, key))
clean_pools.add(key)
        # check each host still needing attention on the Batch service
        for he in hosts_check:
host, partname = he.split()
entity = hosts_check[he]
service_url = entity['BatchServiceUrl']
pool_id = entity['BatchPoolId']
key = '{}${}'.format(service_url, pool_id)
node_id, node_state = self._query_node_state(host, entity)
if (node_state == batchmodels.ComputeNodeState.idle or
node_state == batchmodels.ComputeNodeState.offline or
node_state == batchmodels.ComputeNodeState.running):
hosts_update.add(host)
elif (node_state ==
batchmodels.ComputeNodeState.start_task_failed or
node_state == batchmodels.ComputeNodeState.unusable or
node_state == batchmodels.ComputeNodeState.preempted or
node_state == batchmodels.ComputeNodeState.leaving_pool):
clean_pools.add(key)
pool_map[key]['nodes_recover'].add(node_id)
pool_map[key]['hosts_recover'].add(he)
else:
if retry_count >= _MAX_RESUME_FAILURE_ATTEMPTS:
logger.debug(
'{} on partition {} exceeded max resume failure retry '
'attempts, recovering instead'.format(host, partname))
if node_id is not None:
pool_map[key]['nodes_recover'].add(node_id)
pool_map[key]['hosts_recover'].add(he)
else:
hosts_retry.add(he)
del hosts_check
# update hosts
if len(hosts_update) > 0:
self.ssh.insert_queue_action_msg(
Actions.WaitForResume, list(hosts_update), retry_count=0)
del hosts_update
# clean pools
for key in clean_pools:
service_url, pool_id = key.split('$')
self.bsh.clean_pool(service_url, pool_id)
del clean_pools
# recover hosts
for key in pool_map:
hosts_recover = pool_map[key]['hosts_recover']
hrlen = len(hosts_recover)
if hrlen == 0:
continue
for he in hosts_recover:
host, partname = he.split()
self.ssh.merge_host_assignment_entity_for_compute_node(
host, HostStates.ProvisionInterrupt, True,
retry_on_conflict=True)
nodes_recover = pool_map[key]['nodes_recover']
if len(nodes_recover) > 0:
service_url, pool_id = key.split('$')
self.bsh.remove_nodes_from_pool(
service_url, pool_id, list(nodes_recover))
host_list = list(hosts_recover)
self.ssh.insert_queue_action_msg(Actions.Resume, host_list)
self.ssh.insert_queue_action_msg(
Actions.WaitForResume, host_list, retry_count=0,
visibility_timeout=60)
# re-enqueue failed resume hosts
if len(hosts_retry) > 0:
logger.debug(
'adding resume failed action for {} hosts to retry'.format(
len(hosts_retry)))
self.ssh.insert_queue_action_msg(
Actions.ResumeFailed, list(hosts_retry),
retry_count=retry_count + 1, visibility_timeout=60)
return True
def process_wait_for_resume_action(
self,
hosts: List[str],
retry_count: int,
) -> bool:
if len(hosts) == 0:
            logger.error('host list is empty for wait for resume')
return True
start_time = datetime_utcnow()
remain_hosts = self.ssh.wait_for_host_assignment_entities(
start_time, hosts, timeout=10, set_idle_state=True)
if remain_hosts is not None:
if retry_count > self._service_proxy.resume_timeout / 30:
logger.error(
'not retrying host spin up completion for: {}'.format(
remain_hosts))
self.ssh.insert_queue_action_msg(
Actions.ResumeFailed, list(remain_hosts), retry_count=0,
visibility_timeout=5)
else:
logger.warning(
'host spin up not completed: {}'.format(remain_hosts))
self.ssh.insert_queue_action_msg(
Actions.WaitForResume, list(remain_hosts),
retry_count=retry_count + 1, visibility_timeout=30)
return True
def resume_hosts(self, hosts: List[str]) -> None:
# insert into action queue
start_time = datetime_utcnow()
logger.debug('received resume hosts: {}'.format(', '.join(hosts)))
self.ssh.insert_queue_action_msg(Actions.Resume, hosts)
# process resume completions and translate into scontrol
bare_hosts = [he.split()[0] for he in hosts]
remain_hosts = self.ssh.wait_for_host_assignment_entities(
start_time, bare_hosts)
if remain_hosts is not None:
raise RuntimeError(
'exceeded resume timeout waiting for hosts to '
'spin up: {}'.format(remain_hosts))
def resume_hosts_failed(self, hosts: List[str]) -> None:
# insert into action queue
logger.info('received resume failed hosts: {}'.format(
', '.join(hosts)))
self.ssh.insert_queue_action_msg(
Actions.ResumeFailed, hosts, retry_count=0, visibility_timeout=5)
def suspend_hosts(self, hosts: List[str]) -> None:
# insert into action queue
logger.debug('received suspend hosts: {}'.format(', '.join(hosts)))
self.ssh.insert_queue_action_msg(Actions.Suspend, hosts)
    def check_provisioning_status(
            self, host: str) -> azure.cosmosdb.table.Entity:
logger.debug(
'checking for provisioning status for host {}'.format(host))
try:
entity = self.ssh.get_host_assignment_entity(host)
state = HostStates(entity['State'])
except (azure.common.AzureMissingResourceHttpError, KeyError):
# this should not happen, but fail in case it does
logger.error('host assignment entity does not exist for {}'.format(
host))
sys.exit(1)
logger.info('host {} state property is {}'.format(host, state))
if state != HostStates.Resuming:
logger.error(
'unexpected state, state is not {} for host {}'.format(
HostStates.Resuming, host))
# update host entity assignment
self.ssh.merge_host_assignment_entity_for_compute_node(
host, HostStates.ProvisionInterrupt, False,
retry_on_conflict=False)
sys.exit(1)
return entity
def daemon_processor(self) -> None:
# set logging config for daemon processor
self.set_log_configuration()
logger.info('daemon processor starting')
while True:
msg = self.ssh.get_queue_action_msg()
if msg is None:
random_blocking_sleep(1, 3)
else:
del_msg = True
action = msg[0]['action']
hosts = msg[0]['hosts']
msg_id = msg[1]
pop_receipt = msg[2]
if action == Actions.Suspend:
del_msg = self.process_suspend_action(hosts)
elif action == Actions.Resume:
self.process_resume_action(hosts)
elif action == Actions.ResumeFailed:
del_msg = self.process_resume_failed_action(
hosts, msg[0]['retry_count'])
elif action == Actions.WaitForResume:
del_msg = self.process_wait_for_resume_action(
hosts, msg[0]['retry_count'])
else:
logger.error('unknown action {} for hosts {}'.format(
action, ', '.join(hosts)))
if del_msg:
self.ssh.delete_queue_action_msg(msg_id, pop_receipt)
else:
self.ssh.update_queue_action_msg(msg_id, pop_receipt)
def execute(
self,
action: str,
hosts: Optional[List[str]],
host: Optional[str],
) -> None:
"""Execute action
:param action: action to execute
"""
# process actions
if action == 'daemon':
self.daemon_processor()
elif action == 'sakey':
sakey = self.ssh.get_storage_account_key()
print(sakey[0], sakey[1])
elif action == 'resume':
self.resume_hosts(hosts)
elif action == 'resume-fail':
self.resume_hosts_failed(hosts)
elif action == 'suspend':
self.suspend_hosts(hosts)
elif action == 'check-provisioning-status':
self.check_provisioning_status(host)
elif action == 'get-node-assignment':
self.ssh.get_queue_assignment_msg()
elif action == 'complete-node-assignment':
entity = self.check_provisioning_status(host)
self.ssh.update_host_assignment_entity_as_provisioned(host, entity)
else:
raise ValueError('unknown action to execute: {}'.format(action))
def main() -> None:
"""Main function"""
# get command-line args
args = parseargs()
if is_none_or_empty(args.action):
raise ValueError('action is invalid')
# load configuration
if is_none_or_empty(args.conf):
raise ValueError('config file not specified')
with open(args.conf, 'rb') as f:
config = json.load(f)
logger.debug('loaded config from {}: {}'.format(args.conf, config))
# parse hostfile
if args.hostfile is not None:
with open(args.hostfile, 'r') as f:
hosts = [line.rstrip() for line in f]
else:
hosts = None
try:
# create command processor
cmd_processor = CommandProcessor(config)
# execute action
cmd_processor.execute(args.action, hosts, args.host)
except Exception:
logger.exception('error executing {}'.format(args.action))
finally:
handlers = logger.handlers[:]
for handler in handlers:
handler.close()
logger.removeHandler(handler)
def parseargs() -> argparse.Namespace:
"""Parse program arguments
:return: parsed arguments
"""
parser = argparse.ArgumentParser(
description='slurm: Azure Batch Shipyard Slurm Helper')
parser.add_argument(
'action',
choices=[
'daemon', 'sakey', 'resume', 'resume-fail', 'suspend',
'check-provisioning-status', 'get-node-assignment',
'complete-node-assignment',
]
)
parser.add_argument('--conf', help='configuration file')
parser.add_argument('--hostfile', help='host file')
parser.add_argument('--host', help='host')
return parser.parse_args()
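# Illustrative invocations (config, hostfile and host names are
# hypothetical):
#   slurm.py daemon --conf slurm.json
#   slurm.py resume --conf slurm.json --hostfile /tmp/resume_hosts
#   slurm.py suspend --conf slurm.json --hostfile /tmp/suspend_hosts
#   slurm.py complete-node-assignment --conf slurm.json --host <hostname>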
if __name__ == '__main__':
# set up log formatting and default handlers
setup_logger(logger)
main()