aztk/aztk_cli/config.py

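"""
Configuration loading for the aztk CLI: reads secrets.yaml, cluster.yaml,
ssh.yaml and job.yaml from the local ./.aztk/ directory and the global
~/.aztk/ directory, with local values taking precedence over global ones.
"""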
import os
from typing import Tuple

import yaml

import aztk.error
import aztk.spark
import aztk.utils.constants
from aztk.models import Toolkit
from aztk.models.plugins.internal import PluginReference
from aztk.spark.models import ClusterConfiguration, SchedulingTarget, SecretsConfiguration


def load_aztk_secrets() -> SecretsConfiguration:
    """
    Loads aztk secrets from the .aztk/secrets.yaml files (local and global)
    """
    secrets = SecretsConfiguration()
    # read the global ~/.aztk/secrets.yaml
    global_config = _load_config_file(os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, ".aztk", "secrets.yaml"))
    # read the current working directory's .aztk/secrets.yaml
    local_config = _load_config_file(aztk.utils.constants.DEFAULT_SECRETS_PATH)

    if not global_config and not local_config:
        raise aztk.error.AztkError("There is no secrets.yaml in either ~/.aztk/secrets.yaml or ./.aztk/secrets.yaml")

    if global_config:  # Global config is optional
        _merge_secrets_dict(secrets, global_config)
    if local_config:
        _merge_secrets_dict(secrets, local_config)

    # Validate and raise an error if anything required is missing
    secrets.validate()
    return secrets
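

# A minimal usage sketch, assuming a secrets.yaml exists in at least one of
# the two locations. Passing a SecretsConfiguration straight to
# aztk.spark.Client follows the signature of older aztk releases; treat the
# exact constructor shape as an assumption, not a guarantee:
#
#     secrets = load_aztk_secrets()
#     client = aztk.spark.Client(secrets)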


def _load_config_file(path: str):
    if not os.path.isfile(path):
        return None

    with open(path, "r", encoding="UTF-8") as stream:
        try:
            # safe_load avoids arbitrary object construction and works on
            # PyYAML >= 5.1, where a bare yaml.load() requires a Loader
            return yaml.safe_load(stream)
        except yaml.YAMLError as err:
            raise aztk.error.AztkError("Error in {0}:\n {1}".format(path, err))


def _merge_secrets_dict(secrets: SecretsConfiguration, secrets_config):
    other = SecretsConfiguration.from_dict(secrets_config)
    secrets.merge(other)


def read_cluster_config(
        path: str = aztk.utils.constants.DEFAULT_CLUSTER_CONFIG_PATH) -> Tuple[ClusterConfiguration, bool]:
    """
    Reads the cluster config file in the .aztk/ directory (.aztk/cluster.yaml)
    and returns the parsed ClusterConfiguration along with the "wait" flag.
    """
    config_dict = _load_config_file(path)
    return cluster_config_from_dict(config_dict)


def cluster_config_from_dict(config: dict) -> Tuple[ClusterConfiguration, bool]:
    wait = False
    # Expand plugin references into full plugin definitions
    if config.get("plugins") not in [[None], None]:
        plugins = []
        for plugin in config["plugins"]:
            ref = PluginReference.from_dict(plugin)
            plugins.append(ref.get_plugin())
        config["plugins"] = plugins

    # A top-level "username" is shorthand for a user_configuration block
    if config.get("username") is not None:
        config["user_configuration"] = dict(username=config.pop("username"))

    # "wait" is CLI-only and must not reach ClusterConfiguration.from_dict
    if config.get("wait") is not None:
        wait = config.pop("wait")

    return ClusterConfiguration.from_dict(config), wait
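

# An illustrative cluster.yaml fragment this function accepts (values are
# hypothetical; keys other than "plugins", "username" and "wait" are passed
# straight through to ClusterConfiguration.from_dict):
#
#     id: my_cluster
#     vm_size: standard_a2
#     size: 2
#     username: spark
#     wait: true
#     plugins:
#       - name: jupyter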


class SshConfig:
    def __init__(self):
        self.username = None
        self.cluster_id = None
        self.host = False
        self.connect = True
        self.internal = False

        # Set up ports with default values
        self.job_ui_port = "4040"
        self.job_history_ui_port = "18080"
        self.web_ui_port = "8080"

    def _read_config_file(self, path: str = aztk.utils.constants.DEFAULT_SSH_CONFIG_PATH):
        """
        Reads the config file in the .aztk/ directory (.aztk/ssh.yaml)
        """
        if not os.path.isfile(path):
            return

        with open(path, "r", encoding="UTF-8") as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as err:
                raise aztk.error.AztkError("Error in ssh.yaml: {0}".format(err))

        if config is None:
            return

        self._merge_dict(config)

    def _merge_dict(self, config):
        if config.get("username") is not None:
            self.username = config["username"]
        if config.get("cluster_id") is not None:
            self.cluster_id = config["cluster_id"]
        if config.get("job_ui_port") is not None:
            self.job_ui_port = config["job_ui_port"]
        if config.get("job_history_ui_port") is not None:
            self.job_history_ui_port = config["job_history_ui_port"]
        if config.get("web_ui_port") is not None:
            self.web_ui_port = config["web_ui_port"]
        if config.get("host") is not None:
            self.host = config["host"]
        if config.get("connect") is not None:
            self.connect = config["connect"]
        if config.get("internal") is not None:
            self.internal = config["internal"]

    def merge(self, cluster_id, username, job_ui_port, job_history_ui_port, web_ui_port, host, connect, internal):
        """
        Merges the config files with the CLI arguments; later sources win:
        global ~/.aztk/ssh.yaml, then local ./.aztk/ssh.yaml, then arguments.
        """
        self._read_config_file(os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, ".aztk", "ssh.yaml"))
        self._read_config_file()
        self._merge_dict(
            dict(
                cluster_id=cluster_id,
                username=username,
                job_ui_port=job_ui_port,
                job_history_ui_port=job_history_ui_port,
                web_ui_port=web_ui_port,
                host=host,
                connect=connect,
                internal=internal,
            ))

        if self.cluster_id is None:
            raise aztk.error.AztkError("Please supply an id for the cluster either in the ssh.yaml configuration file "
                                       "or with a parameter (--id)")

        if self.username is None:
            raise aztk.error.AztkError(
                "Please supply a username either in the ssh.yaml configuration file or with a parameter (--username)")


def _convert_to_path(path: str):
    if path:
        abs_path = os.path.abspath(os.path.expanduser(path))
        if not os.path.exists(abs_path):
            raise aztk.error.AztkError("Could not find file: {0}\nCheck your configuration file".format(path))
        return abs_path
    return None


class JobConfig:
    def __init__(self):
        self.id = None
        self.applications = []
        self.spark_configuration = None
        self.vm_size = None
        self.toolkit = None
        self.max_dedicated_nodes = 0
        self.max_low_pri_nodes = 0
        self.spark_defaults_conf = None
        self.spark_env_sh = None
        self.core_site_xml = None
        self.subnet_id = None
        self.worker_on_master = None
        self.scheduling_target = None
        self.jars = []

    def _merge_dict(self, config):
        config = config.get("job")

        if config.get("id") is not None:
            self.id = config["id"]

        cluster_configuration = config.get("cluster_configuration")
        if cluster_configuration:
            self.vm_size = cluster_configuration.get("vm_size")
            self.toolkit = Toolkit.from_dict(cluster_configuration.get("toolkit"))
            if cluster_configuration.get("size") is not None:
                self.max_dedicated_nodes = cluster_configuration.get("size")
            if cluster_configuration.get("size_low_priority") is not None:
                self.max_low_pri_nodes = cluster_configuration.get("size_low_priority")
            self.subnet_id = cluster_configuration.get("subnet_id")
            self.worker_on_master = cluster_configuration.get("worker_on_master")
            scheduling_target = cluster_configuration.get("scheduling_target")
            if scheduling_target:
                self.scheduling_target = SchedulingTarget(scheduling_target)

        applications = config.get("applications")
        if applications:
            self.applications = []
            for application in applications:
                self.applications.append(
                    aztk.spark.models.ApplicationConfiguration(
                        name=application.get("name"),
                        application=application.get("application"),
                        application_args=application.get("application_args"),
                        main_class=application.get("main_class"),
                        jars=application.get("jars"),
                        py_files=application.get("py_files"),
                        files=application.get("files"),
                        driver_java_options=application.get("driver_java_options"),
                        driver_library_path=application.get("driver_library_path"),
                        driver_class_path=application.get("driver_class_path"),
                        driver_memory=application.get("driver_memory"),
                        executor_memory=application.get("executor_memory"),
                        driver_cores=application.get("driver_cores"),
                        executor_cores=application.get("executor_cores"),
                    ))

        spark_configuration = config.get("spark_configuration")
        if spark_configuration:
            self.spark_defaults_conf = _convert_to_path(spark_configuration.get("spark_defaults_conf"))
            self.spark_env_sh = _convert_to_path(spark_configuration.get("spark_env_sh"))
            self.core_site_xml = _convert_to_path(spark_configuration.get("core_site_xml"))
            self.jars = [_convert_to_path(jar) for jar in spark_configuration.get("jars") or []]
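
    # An illustrative job.yaml fragment matching the keys read above (values
    # are hypothetical):
    #
    #     job:
    #       id: my_job
    #       cluster_configuration:
    #         vm_size: standard_a2
    #         size: 2
    #         toolkit:
    #           software: spark
    #           version: "2.3"
    #       applications:
    #         - name: my_app
    #           application: /path/to/app.py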

    def _read_config_file(self, path: str = aztk.utils.constants.DEFAULT_SPARK_JOB_CONFIG):
        """
        Reads the Job config file in the .aztk/ directory (.aztk/job.yaml)
        """
        if not path or not os.path.isfile(path):
            return

        with open(path, "r", encoding="UTF-8") as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as err:
                raise aztk.error.AztkError("Error in job.yaml: {0}".format(err))

        if config is None:
            return

        self._merge_dict(config)

    def merge(self, id, job_config_yaml=None):
        # Layer the config sources; later reads override earlier ones:
        # global job.yaml, then local .aztk/job.yaml, then an explicit file
        self._read_config_file(aztk.utils.constants.GLOBAL_SPARK_JOB_CONFIG)
        self._read_config_file(aztk.utils.constants.DEFAULT_SPARK_JOB_CONFIG)
        self._read_config_file(job_config_yaml)

        if id:
            self.id = id

        for entry in self.applications:
            if entry.name is None:
                raise aztk.error.AztkError(
                    "Application specified with no name. Please verify your configuration in job.yaml")
            if entry.application is None:
                raise aztk.error.AztkError("No path to application specified for {} in job.yaml".format(entry.name))


def get_file_if_exists(file):
    local_conf_file = os.path.join(aztk.utils.constants.DEFAULT_SPARK_CONF_SOURCE, file)
    global_conf_file = os.path.join(aztk.utils.constants.GLOBAL_CONFIG_PATH, file)

    # Prefer the local .aztk/ copy over the global one
    if os.path.exists(local_conf_file):
        return local_conf_file
    if os.path.exists(global_conf_file):
        return global_conf_file
    return None


def load_aztk_spark_config():
    return aztk.spark.models.SparkConfiguration(
        spark_defaults_conf=get_file_if_exists("spark-defaults.conf"),
        jars=load_jars(),
        spark_env_sh=get_file_if_exists("spark-env.sh"),
        core_site_xml=get_file_if_exists("core-site.xml"),
    )


def load_jars():
    jars = None

    # try to load the global jars directory
    try:
        jars_src = os.path.join(aztk.utils.constants.GLOBAL_CONFIG_PATH, "jars")
        jars = [os.path.join(jars_src, jar) for jar in os.listdir(jars_src)]
    except FileNotFoundError:
        pass

    # try to load the local jars directory, overwriting the global list if found
    try:
        jars_src = os.path.join(aztk.utils.constants.DEFAULT_SPARK_CONF_SOURCE, "jars")
        jars = [os.path.join(jars_src, jar) for jar in os.listdir(jars_src)]
    except FileNotFoundError:
        pass

    return jars
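

# Behavior sketch for load_jars under an assumed layout (the actual
# directories come from GLOBAL_CONFIG_PATH and DEFAULT_SPARK_CONF_SOURCE):
#
#     ~/.aztk/jars/a.jar      -> global list: [".../a.jar"]
#     ./.aztk/jars/b.jar      -> local list replaces it: [".../b.jar"]
#
# The two lists are never merged; a local jars/ directory wins outright.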