# aztk/aztk_cli/config.py


import os

import yaml

# explicit imports for submodules referenced below (aztk.error, aztk.utils.constants)
import aztk.error
import aztk.spark
import aztk.utils.constants
from aztk.spark.models import (
    SecretsConfiguration,
    ClusterConfiguration,
    SchedulingTarget,
)
from aztk.utils import deprecate
from aztk.models import Toolkit
from aztk.models.plugins.internal import PluginReference


def load_aztk_secrets() -> SecretsConfiguration:
    """
    Loads aztk secrets from the .aztk/secrets.yaml files (local and global)
    """
    secrets = SecretsConfiguration()
    # read the global ~/.aztk/secrets.yaml
    global_config = _load_config_file(
        os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, '.aztk', 'secrets.yaml'))
    # read the current working directory's .aztk/secrets.yaml
    local_config = _load_config_file(aztk.utils.constants.DEFAULT_SECRETS_PATH)

    if not global_config and not local_config:
        raise aztk.error.AztkError(
            "There is no secrets.yaml in either ~/.aztk/secrets.yaml or ./.aztk/secrets.yaml")

    if global_config:  # Global config is optional
        _merge_secrets_dict(secrets, global_config)
    if local_config:
        _merge_secrets_dict(secrets, local_config)

    # Validate and raise an error if anything is missing
    secrets.validate()
    return secrets
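

# Usage sketch (hedged): aztk.spark.Client takes a SecretsConfiguration, so callers
# typically do the following; values from the local ./.aztk/secrets.yaml override
# the global ~/.aztk/secrets.yaml because the local file is merged last.
#
#     secrets = load_aztk_secrets()
#     client = aztk.spark.Client(secrets)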


def _load_config_file(path: str):
    if not os.path.isfile(path):
        return None

    with open(path, 'r', encoding='UTF-8') as stream:
        try:
            # safe_load avoids constructing arbitrary Python objects from YAML
            return yaml.safe_load(stream)
        except yaml.YAMLError as err:
            raise aztk.error.AztkError("Error in {0}:\n {1}".format(path, err))


def _merge_secrets_dict(secrets: SecretsConfiguration, secrets_config):
    if 'default' in secrets_config:
        deprecate("default key in secrets.yaml is deprecated.",
                  "Place all child parameters directly at the root")
        secrets_config = dict(**secrets_config, **secrets_config.pop('default'))

    other = SecretsConfiguration.from_dict(secrets_config)
    secrets.merge(other)
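

# Illustration (assumed YAML shape, hypothetical field names): a deprecated
# secrets.yaml such as
#
#     default:
#         shared_key:
#             batch_account_name: ...
#
# is flattened so the keys under `default` move to the root before
# SecretsConfiguration.from_dict is called. A key present both at the root and
# under `default` raises a TypeError from the dict(**, **) merge.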


def read_cluster_config(path: str = aztk.utils.constants.DEFAULT_CLUSTER_CONFIG_PATH):
    """
    Reads the config file in the .aztk/ directory (.aztk/cluster.yaml)
    and returns the (ClusterConfiguration, wait) tuple produced by
    cluster_config_from_dict.
    """
    config_dict = _load_config_file(path)
    return cluster_config_from_dict(config_dict)


def cluster_config_from_dict(config: dict):
    wait = False
    if config.get('plugins') not in [[None], None]:
        plugins = []
        for plugin in config['plugins']:
            ref = PluginReference.from_dict(plugin)
            plugins.append(ref.get_plugin())
        config["plugins"] = plugins

    if config.get('username') is not None:
        config['user_configuration'] = dict(username=config.pop('username'))

    if config.get('wait') is not None:
        wait = config.pop('wait')

    return ClusterConfiguration.from_dict(config), wait
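

# Example (hedged, hypothetical values) of the dict shape this function accepts,
# mirroring .aztk/cluster.yaml; `username` and `wait` are lifted out before
# ClusterConfiguration.from_dict sees the rest:
#
#     config, wait = cluster_config_from_dict({
#         'id': 'my-cluster',
#         'vm_size': 'standard_a2',
#         'size': 2,
#         'username': 'spark',
#         'wait': True,
#     })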


class SshConfig:
    def __init__(self):
        self.username = None
        self.cluster_id = None
        self.host = False
        self.connect = True
        self.internal = False

        # Set up ports with default values
        self.job_ui_port = '4040'
        self.job_history_ui_port = '18080'
        self.web_ui_port = '8080'

    def _read_config_file(self, path: str = aztk.utils.constants.DEFAULT_SSH_CONFIG_PATH):
        """
        Reads the config file in the .aztk/ directory (.aztk/ssh.yaml)
        """
        if not os.path.isfile(path):
            return

        with open(path, 'r', encoding='UTF-8') as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as err:
                raise aztk.error.AztkError("Error in ssh.yaml: {0}".format(err))

        if config is None:
            return

        self._merge_dict(config)

    def _merge_dict(self, config):
        if config.get('username') is not None:
            self.username = config['username']
        if config.get('cluster_id') is not None:
            self.cluster_id = config['cluster_id']
        if config.get('job_ui_port') is not None:
            self.job_ui_port = config['job_ui_port']
        if config.get('job_history_ui_port') is not None:
            self.job_history_ui_port = config['job_history_ui_port']
        if config.get('web_ui_port') is not None:
            self.web_ui_port = config['web_ui_port']
        if config.get('host') is not None:
            self.host = config['host']
        if config.get('connect') is not None:
            self.connect = config['connect']
        if config.get('internal') is not None:
            self.internal = config['internal']

    def merge(self, cluster_id, username, job_ui_port, job_history_ui_port,
              web_ui_port, host, connect, internal):
        """
        Merges config files and explicit arguments; later sources win
        (global ssh.yaml, then local ssh.yaml, then the arguments).
        """
        self._read_config_file(
            os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, '.aztk', 'ssh.yaml'))
        self._read_config_file()
        self._merge_dict(
            dict(
                cluster_id=cluster_id,
                username=username,
                job_ui_port=job_ui_port,
                job_history_ui_port=job_history_ui_port,
                web_ui_port=web_ui_port,
                host=host,
                connect=connect,
                internal=internal))

        if self.cluster_id is None:
            raise aztk.error.AztkError(
                "Please supply an id for the cluster either in the ssh.yaml configuration file or with a parameter (--id)"
            )

        if self.username is None:
            raise aztk.error.AztkError(
                "Please supply a username either in the ssh.yaml configuration file or with a parameter (--username)"
            )
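

# Precedence sketch: SshConfig.merge reads the global ~/.aztk/ssh.yaml, then the
# local ./.aztk/ssh.yaml, then the explicit arguments, so CLI flags override both
# files. For example (hypothetical values):
#
#     ssh_config = SshConfig()
#     ssh_config.merge(cluster_id='my-cluster', username='spark',
#                      job_ui_port=None, job_history_ui_port=None,
#                      web_ui_port=None, host=False, connect=True, internal=False)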


class JobConfig:
    def __init__(self):
        self.id = None
        self.applications = []
        self.custom_scripts = None
        self.spark_configuration = None
        self.vm_size = None
        self.toolkit = None
        self.max_dedicated_nodes = 0
        self.max_low_pri_nodes = 0
        self.spark_defaults_conf = None
        self.spark_env_sh = None
        self.core_site_xml = None
        self.subnet_id = None
        self.worker_on_master = None
        self.scheduling_target = None
        # default so the attribute exists even when no spark_configuration is given
        self.jars = []

    def _merge_dict(self, config):
        config = config.get('job')
        if config is None:
            # job.yaml is expected to have a single 'job' key at its root
            raise aztk.error.AztkError("Missing 'job' key at the root of the job configuration")

        if config.get('id') is not None:
            self.id = config['id']

        cluster_configuration = config.get('cluster_configuration')
        if cluster_configuration:
            self.vm_size = cluster_configuration.get('vm_size')
            self.toolkit = Toolkit.from_dict(cluster_configuration.get('toolkit'))
            if cluster_configuration.get('size') is not None:
                self.max_dedicated_nodes = cluster_configuration.get('size')
            if cluster_configuration.get('size_low_priority') is not None:
                self.max_low_pri_nodes = cluster_configuration.get('size_low_priority')
            self.custom_scripts = cluster_configuration.get('custom_scripts')
            self.subnet_id = cluster_configuration.get('subnet_id')
            self.worker_on_master = cluster_configuration.get("worker_on_master")
            scheduling_target = cluster_configuration.get("scheduling_target")
            if scheduling_target:
                self.scheduling_target = SchedulingTarget(scheduling_target)

        applications = config.get('applications')
        if applications:
            self.applications = []
            for application in applications:
                self.applications.append(
                    aztk.spark.models.ApplicationConfiguration(
                        name=application.get('name'),
                        application=application.get('application'),
                        application_args=application.get('application_args'),
                        main_class=application.get('main_class'),
                        jars=application.get('jars'),
                        py_files=application.get('py_files'),
                        files=application.get('files'),
                        driver_java_options=application.get('driver_java_options'),
                        driver_library_path=application.get('driver_library_path'),
                        driver_class_path=application.get('driver_class_path'),
                        driver_memory=application.get('driver_memory'),
                        executor_memory=application.get('executor_memory'),
                        driver_cores=application.get('driver_cores'),
                        executor_cores=application.get('executor_cores')))

        spark_configuration = config.get('spark_configuration')
        if spark_configuration:
            self.spark_defaults_conf = self.__convert_to_path(spark_configuration.get('spark_defaults_conf'))
            self.spark_env_sh = self.__convert_to_path(spark_configuration.get('spark_env_sh'))
            self.core_site_xml = self.__convert_to_path(spark_configuration.get('core_site_xml'))
            self.jars = [self.__convert_to_path(jar) for jar in spark_configuration.get('jars') or []]

    def __convert_to_path(self, str_path):
        if not str_path:
            return None

        abs_path = os.path.abspath(os.path.expanduser(str_path))
        if not os.path.exists(abs_path):
            raise aztk.error.AztkError(
                "Could not find file: {0}\nCheck your configuration file".format(str_path))
        return abs_path

    def _read_config_file(self, path: str = aztk.utils.constants.DEFAULT_SPARK_JOB_CONFIG):
        """
        Reads the Job config file in the .aztk/ directory (.aztk/job.yaml)
        """
        if not path or not os.path.isfile(path):
            return

        with open(path, 'r', encoding='UTF-8') as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as err:
                raise aztk.error.AztkError("Error in job.yaml: {0}".format(err))

        if config is None:
            return

        self._merge_dict(config)

    def merge(self, id, job_config_yaml=None):
        self._read_config_file(aztk.utils.constants.GLOBAL_SPARK_JOB_CONFIG)
        self._read_config_file(aztk.utils.constants.DEFAULT_SPARK_JOB_CONFIG)
        self._read_config_file(job_config_yaml)
        if id:
            self.id = id

        for entry in self.applications:
            if entry.name is None:
                raise aztk.error.AztkError(
                    "Application specified with no name. Please verify your configuration in job.yaml")
            if entry.application is None:
                raise aztk.error.AztkError(
                    "No path to application specified for {} in job.yaml".format(entry.name))


def get_file_if_exists(file):
    local_conf_file = os.path.join(aztk.utils.constants.DEFAULT_SPARK_CONF_SOURCE, file)
    global_conf_file = os.path.join(aztk.utils.constants.GLOBAL_CONFIG_PATH, file)

    if os.path.exists(local_conf_file):
        return local_conf_file
    if os.path.exists(global_conf_file):
        return global_conf_file
    return None


def load_aztk_spark_config():
    return aztk.spark.models.SparkConfiguration(
        spark_defaults_conf=get_file_if_exists('spark-defaults.conf'),
        jars=load_jars(),
        spark_env_sh=get_file_if_exists('spark-env.sh'),
        core_site_xml=get_file_if_exists('core-site.xml'))
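

# Each file is resolved independently (local first, then global), so a local
# spark-defaults.conf can coexist with a global core-site.xml.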


def load_jars():
    jars = None

    # try to load the global jars directory
    try:
        jars_src = os.path.join(aztk.utils.constants.GLOBAL_CONFIG_PATH, 'jars')
        jars = [os.path.join(jars_src, jar) for jar in os.listdir(jars_src)]
    except FileNotFoundError:
        pass

    # try to load the local jars directory, overwriting the global list if found
    try:
        jars_src = os.path.join(aztk.utils.constants.DEFAULT_SPARK_CONF_SOURCE, 'jars')
        jars = [os.path.join(jars_src, jar) for jar in os.listdir(jars_src)]
    except FileNotFoundError:
        pass

    return jars
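

# Design note: the local jars directory replaces, rather than extends, the global
# one, and load_jars returns None (not an empty list) when neither directory exists.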