superbenchmark/superbench/runner/ansible.py

135 строки
4.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""SuperBench Ansible Client."""
import tempfile
from pathlib import Path
import ansible_runner
from ansible.parsing.dataloader import DataLoader
from ansible.inventory.manager import InventoryManager
from superbench.common.utils import logger
class AnsibleClient():
"""Ansible Client class."""
def __init__(self, config):
"""Initilize.
Args:
config (DictConfig): Ansible config object.
"""
self._playbook_path = Path(__file__).parent / 'playbooks'
self._config = {
'host_pattern': 'localhost',
'cmdline': '--forks 128',
}
self._head_host = None
self.failure_count = 0
if config:
inventory_file = getattr(config, 'host_file', None)
inventory_list = getattr(config, 'host_list', None)
if inventory_list:
inventory_list = inventory_list.strip(',')
if inventory_file or inventory_list:
self._config['host_pattern'] = 'all'
inventory = InventoryManager(loader=DataLoader(), sources=inventory_file or f'{inventory_list},')
host_list = inventory.get_hosts(pattern='all', order='sorted')
if len(host_list) > 0:
self._config['cmdline'] = '--forks {}'.format(len(host_list))
self._head_host = host_list[0].get_name()
if inventory_list in ['localhost', '127.0.0.1']:
self._config['cmdline'] += ' --connection local'
self._config['cmdline'] += ' --inventory {}'.format(inventory_file or f'{inventory_list},')
username = getattr(config, 'host_username', None)
if username:
self._config['cmdline'] += ' --user {}'.format(username)
password = getattr(config, 'host_password', None)
if password:
self._config['passwords'] = {
'password': password,
'passphrase': password,
}
key_file = getattr(config, 'private_key', None)
if key_file:
self._config['cmdline'] += ' --private-key {}'.format(key_file)
elif password:
self._config['cmdline'] += ' --ask-pass --ask-become-pass'
logger.info(self._config)
def run(self, ansible_config, cancel_callback=None, sudo=False): # pragma: no cover
"""Run Ansible runner.
Args:
ansible_config (dict): Ansible config dict.
cancel_callback (Callable): Ansible runner cancel callback.
sudo (bool): Run as sudo or not. Defaults to False.
Returns:
int: Ansible return code.
"""
if sudo:
logger.info('Run as sudo ...')
ansible_config['cmdline'] += ' --become'
with tempfile.TemporaryDirectory(prefix='ansible') as tmpdir:
r = ansible_runner.run(private_data_dir=tmpdir, cancel_callback=cancel_callback, **ansible_config)
logger.debug(r.stats)
if r.rc == 0:
logger.info('Run succeed, return code {}.'.format(r.rc))
else:
self.failure_count += 1
logger.warning('Run failed, return code {}.'.format(r.rc))
return r.rc
def update_mpi_config(self, ansible_config):
"""Update ansible config for mpi, run on the first host of inventory group.
Args:
ansible_config (dict): Ansible config dict.
Returns:
dict: Updated Ansible config dict.
"""
if not self._head_host:
ansible_config['host_pattern'] += '[0]'
else:
ansible_config['host_pattern'] = self._head_host
return ansible_config
def get_shell_config(self, cmd):
"""Get ansible config for shell module.
Args:
cmd (str): Shell command for config.
Returns:
dict: Ansible config dict.
"""
logger.info('Run {} on remote ...'.format(cmd))
ansible_config = {
**self._config,
'module': 'shell',
'module_args': cmd,
}
return ansible_config
def get_playbook_config(self, playbook, extravars=None):
"""Get ansible config for playbook.
Args:
playbook (str): Playbook file name.
extravars (dict): Extra variables in playbook. Defaults to None.
Returns:
dict: Ansible config dict.
"""
logger.info('Run playbook {} ...'.format(playbook))
ansible_config = {
**self._config,
'extravars': extravars,
'playbook': str(self._playbook_path / playbook),
}
return ansible_config