Feature: New Models design with auto validation, default and merging (#543)

This commit is contained in:
Timothee Guerin 2018-05-30 09:07:09 -07:00 коммит произвёл GitHub
Родитель f6735cc6dd
Коммит 02f336b0a0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
50 изменённых файлов: 1360 добавлений и 734 удалений

3
.gitignore поставляемый
Просмотреть файл

@ -49,3 +49,6 @@ tmp/
# Built docs
docs/_build/
# PyCharm
.idea/

Просмотреть файл

@ -3,4 +3,4 @@ based_on_style = pep8
spaces_before_comment = 4
split_before_logical_operator = true
indent_width = 4
column_limit = 120
column_limit = 140

3
.vscode/settings.json поставляемый
Просмотреть файл

@ -16,5 +16,6 @@
"--style=.style.yapf"
],
"python.venvPath": "${workspaceFolder}/.venv/",
"python.pythonPath": "${workspaceFolder}/.venv/Scripts/python.exe"
"python.pythonPath": "${workspaceFolder}/.venv/Scripts/python.exe",
"python.unitTest.pyTestEnabled": true
}

Просмотреть файл

@ -90,7 +90,7 @@ class Client:
network_conf = batch_models.NetworkConfiguration(
subnet_id=cluster_conf.subnet_id)
auto_scale_formula = "$TargetDedicatedNodes={0}; $TargetLowPriorityNodes={1}".format(
cluster_conf.vm_count, cluster_conf.vm_low_pri_count)
cluster_conf.size, cluster_conf.size_low_priority)
# Configure the pool
pool = batch_models.PoolAddParameter(
@ -110,7 +110,7 @@ class Client:
batch_models.MetadataItem(
name=constants.AZTK_SOFTWARE_METADATA_KEY, value=software_metadata_key),
batch_models.MetadataItem(
name=constants.AZTK_MODE_METADATA_KEY, value=constants.AZTK_CLUSTER_MODE_METADATA)
name=constants.AZTK_MODE_METADATA_KEY, value=constants.AZTK_CLUSTER_MODE_METADATA)
])
# Create the pool + create user for the pool

0
aztk/core/__init__.py Normal file
Просмотреть файл

Просмотреть файл

@ -0,0 +1,2 @@
from .model import Model
from .fields import String, Integer, Boolean, Float, List, ModelMergeStrategy, ListMergeStrategy

241
aztk/core/models/fields.py Normal file
Просмотреть файл

@ -0,0 +1,241 @@
import collections
import enum
from aztk.error import InvalidModelFieldError
from . import validators as aztk_validators
class ModelMergeStrategy(enum.Enum):
    """How a nested model field is combined when two models are merged."""

    # Replace the current value with the incoming one
    Override = 1
    # Recursively merge the nested model values
    Merge = 2
class ListMergeStrategy(enum.Enum):
    """How a list field is combined when two models are merged."""

    # Replace the current list with the incoming one
    Replace = 1
    # Concatenate the incoming list onto the current one
    Append = 2
# pylint: disable=W0212
class Field:
    """
    Base class for all model fields.

    A Field is a data descriptor declared on a Model class; per-instance
    values live in the instance's `_data` dict, keyed by the field object
    itself (lazily computed defaults are cached in `_defaults`).

    Args:
        *validators: extra validator callables run by `validate`.
        **kwargs:
            default: default value, or a callable producing it. If the key
                is absent the field is required (a default of None still
                counts as "has a default").
            choices: optional collection of allowed values.
    """

    def __init__(self, *validators, **kwargs):
        self.default = kwargs.get('default')
        # Required is keyed on the *presence* of 'default', not its value.
        self.required = 'default' not in kwargs
        self.validators = []

        if self.required:
            self.validators.append(aztk_validators.Required())

        self.validators.extend(validators)

        choices = kwargs.get('choices')
        if choices:
            self.validators.append(aztk_validators.In(choices))

    def validate(self, value):
        """Run every configured validator against `value` (raises on failure)."""
        for validator in self.validators:
            validator(value)

    def __get__(self, instance, owner):
        if instance is not None:
            value = instance._data.get(self)
            if value is None:
                # Lazily resolve and cache the default per instance so a
                # callable default is invoked at most once per instance.
                return instance._defaults.setdefault(self, self._default(instance))
            return value

        # Accessed on the class itself: return the descriptor object.
        return self

    def __set__(self, instance, value):
        instance._data[self] = value

    def merge(self, instance, value):
        """
        Method called when merging 2 model together.
        This is overridden in some of the fields where merge can be handled differently
        """
        if value is not None:
            instance._data[self] = value

    def serialize(self, instance):
        # Default serialization: the resolved value (defaults included).
        return self.__get__(instance, None)

    def _default(self, model):
        # Resolve the default: call it if callable, otherwise use as-is.
        if callable(self.default):
            return self.__call_default(model)

        return self.default

    def __call_default(self, *args):
        # Try the default as a zero-arg callable first; fall back to passing
        # the model instance. If both fail, re-raise the original TypeError.
        try:
            return self.default()
        except TypeError as error:
            try:
                return self.default(*args)
            except TypeError:
                raise error
class String(Field):
    """Field whose value must be a string."""

    def __init__(self, *args, **kwargs):
        # Prepend the string-type validator to any user-supplied ones.
        validators = (aztk_validators.String(),) + args
        super().__init__(*validators, **kwargs)
class Integer(Field):
    """Field whose value must be an integer."""

    def __init__(self, *args, **kwargs):
        # Prepend the integer-type validator to any user-supplied ones.
        validators = (aztk_validators.Integer(),) + args
        super().__init__(*validators, **kwargs)
class Float(Field):
    """Field whose value must be a float."""

    def __init__(self, *args, **kwargs):
        # Prepend the float-type validator to any user-supplied ones.
        validators = (aztk_validators.Float(),) + args
        super().__init__(*validators, **kwargs)
class Boolean(Field):
    """Field whose value must be a boolean."""

    def __init__(self, *args, **kwargs):
        # Prepend the boolean-type validator to any user-supplied ones.
        validators = (aztk_validators.Boolean(),) + args
        super().__init__(*validators, **kwargs)
class List(Field):
    """
    Field that should be a list.

    Args:
        model: optional Model class; mapping items are promoted to this type.
        **kwargs:
            merge_strategy (ListMergeStrategy): how lists combine on merge
                (default: Append).
            skip_none (bool): drop None items on assignment (default: True).
            inner_validators: validators applied to each item.
    """

    def __init__(self, model=None, **kwargs):
        self.model = model
        kwargs.setdefault('default', list)
        self.merge_strategy = kwargs.get('merge_strategy', ListMergeStrategy.Append)
        self.skip_none = kwargs.get('skip_none', True)

        super().__init__(
            aztk_validators.List(*kwargs.get('inner_validators', [])), **kwargs)

    def __set__(self, instance, value):
        # NOTE: collections.MutableSequence/MutableMapping aliases were
        # removed in Python 3.10; collections.abc is the supported home.
        if isinstance(value, collections.abc.MutableSequence):
            value = self._resolve(value)
        if value is None:
            value = []
        super().__set__(instance, value)

    def _resolve(self, value):
        # Normalize raw items: drop Nones (if configured) and promote
        # mappings to the configured model type.
        result = []
        for item in value:
            if item is None and self.skip_none:  # Skip none values
                continue

            if self.model and isinstance(item, collections.abc.MutableMapping):
                item = self.model(**item)
            result.append(item)
        return result

    def merge(self, instance, value):
        if value is None:
            value = []

        if self.merge_strategy == ListMergeStrategy.Append:
            current = instance._data.get(self)
            if current is None:
                current = []
            value = current + value

        instance._data[self] = value

    def serialize(self, instance):
        # Serialize each item via its to_dict() when available.
        items = super().serialize(instance)
        output = []
        if items is not None:
            for item in items:
                if hasattr(item, 'to_dict'):
                    output.append(item.to_dict())
                else:
                    output.append(item)
        return output
class Model(Field):
    """
    Field is another model

    Args:
        model (aztk.core.models.Model): Model class that the field value should be an instance of
        merge_strategy (ModelMergeStrategy): When merging models how should the nested model be merged.
            Default: `ModelMergeStrategy.Merge`
    """

    def __init__(self, model, *args, **kwargs):
        super().__init__(aztk_validators.Model(model), *args, **kwargs)
        self.model = model
        self.merge_strategy = kwargs.get('merge_strategy', ModelMergeStrategy.Merge)

    def __set__(self, instance, value):
        # Promote a plain mapping to the model type on assignment.
        # NOTE: collections.MutableMapping alias was removed in Python 3.10;
        # collections.abc is the supported home.
        if isinstance(value, collections.abc.MutableMapping):
            value = self.model(**value)

        super().__set__(instance, value)

    def merge(self, instance, value):
        # Merge strategy: either recursively merge into the current nested
        # model, or simply override it with the incoming value.
        if self.merge_strategy == ModelMergeStrategy.Merge:
            current = instance._data.get(self)
            if current is not None:
                current.merge(value)
                value = current

        instance._data[self] = value

    def serialize(self, instance):
        val = super().serialize(instance)
        if val is not None:
            return val.to_dict()
        return None
class Enum(Field):
    """
    Field whose value must be a member of the given enum type.

    Raw values are coerced to the enum member on assignment; serialization
    emits the member's `.value`.
    """

    def __init__(self, model, *args, **kwargs):
        super().__init__(aztk_validators.InstanceOf(model), *args, **kwargs)
        self.model = model

    def __set__(self, instance, value):
        if value is not None and not isinstance(value, self.model):
            try:
                value = self.model(value)
            except ValueError as excp:
                available = [e.value for e in self.model]
                # Chain the original ValueError so debugging keeps the cause.
                # (Binding the exception as `excp` also avoids shadowing the
                # `e` loop variable above.)
                raise InvalidModelFieldError(
                    "{0} is not a valid option. Use one of {1}".format(value, available)) from excp
        super().__set__(instance, value)

    def serialize(self, instance):
        val = super().serialize(instance)
        if val is not None:
            return val.value
        return None

123
aztk/core/models/model.py Normal file
Просмотреть файл

@ -0,0 +1,123 @@
import yaml
from aztk.error import InvalidModelError, InvalidModelFieldError, AztkError, AztkAttributeError
from aztk.core.models import fields
# pylint: disable=W0212,no-member
class ModelMeta(type):
    """
    Model Meta class. This takes all the class definition and builds the `_fields`
    attribute from all the field definitions (including inherited ones).
    """

    def __new__(mcs, name, bases, attrs):
        attrs['_fields'] = {}

        # Collect fields from base classes first, so declarations made
        # directly on this class (handled below) take precedence.
        for base in bases:
            if hasattr(base, '_fields'):
                for k, v in base._fields.items():
                    attrs['_fields'][k] = v
            # Also pick up Field descriptors declared on bases that are not
            # themselves ModelMeta classes.
            for k, v in base.__dict__.items():
                if isinstance(v, fields.Field):
                    attrs['_fields'][k] = v

        # Fields declared on this class override inherited ones.
        for k, v in attrs.items():
            if isinstance(v, fields.Field):
                attrs['_fields'][k] = v

        return super().__new__(mcs, name, bases, attrs)
class Model(metaclass=ModelMeta):
    """
    Base class for all aztk models

    To implement model wide validation implement `__validate__` method
    """

    def __new__(cls, *_args, **_kwargs):
        # Set up per-instance storage before __init__ runs so field
        # descriptors can read/write it even during construction.
        model = super().__new__(cls)
        model._data = {}      # field object -> explicitly assigned value
        model._defaults = {}  # field object -> lazily cached default
        return model

    def __init__(self, **kwargs):
        self._update(kwargs)

    def __getitem__(self, k):
        # Dict-style read access, restricted to declared fields.
        if k not in self._fields:
            raise AztkAttributeError("{0} doesn't have an attribute called {1}".format(self.__class__.__name__, k))

        return getattr(self, k)

    def __setitem__(self, k, v):
        # Dict-style assignment, restricted to declared fields.
        if k not in self._fields:
            raise AztkAttributeError("{0} doesn't have an attribute called {1}".format(self.__class__.__name__, k))

        try:
            setattr(self, k, v)
        except InvalidModelFieldError as e:
            # Attach field/model context before re-raising.
            self._process_field_error(e, k)

    def __getstate__(self):
        """
        For pickle serialization. This return the state of the model
        """
        return self.to_dict()

    def __setstate__(self, state):
        """
        For pickle serialization. This update the current model with the given state
        """
        self._update(state)

    def validate(self):
        """
        Validate the entire model
        """
        for name, field in self._fields.items():
            try:
                field.validate(getattr(self, name))
            except InvalidModelFieldError as e:
                self._process_field_error(e, name)
            except InvalidModelError as e:
                # Tag the model on nested model errors, then propagate.
                e.model = self
                raise e

        # Optional model-wide validation hook implemented by subclasses.
        if hasattr(self, '__validate__'):
            self.__validate__()

    def merge(self, other):
        # Merge `other` (must be the same type) into this model field by
        # field. Only fields explicitly set on `other` participate
        # (defaults are not merged). Returns self for chaining.
        if not isinstance(other, self.__class__):
            raise AztkError("Cannot merge {0} as is it not an instance of {1}".format(other, self.__class__.__name__))

        for field in other._fields.values():
            if field in other._data:
                field.merge(self, other._data[field])

        return self

    @classmethod
    def from_dict(cls, val: dict):
        # Alternate constructor from a plain dict of field values.
        return cls(**val)

    def to_dict(self):
        # Serialize every declared field (defaults included) to a dict.
        output = dict()
        for name, field in self._fields.items():
            output[name] = field.serialize(self)
        return output

    def __str__(self):
        return yaml.dump(self.to_dict(), default_flow_style=False)

    def _update(self, values):
        for k, v in values.items():
            self[k] = v

    def _process_field_error(self, e: InvalidModelFieldError, field: str):
        # Fill in missing field/model context on the error, then re-raise.
        if not e.field:
            e.field = field

        if not e.model:
            e.model = self

        raise e

Просмотреть файл

@ -0,0 +1,149 @@
import collections
from aztk.error import InvalidModelFieldError
class Validator:
    """
    Base class for a validator.

    Subclasses implement `validate`, raising InvalidModelFieldError for bad
    values. Instances are callable and simply delegate to `validate`.
    """

    def __call__(self, value):
        self.validate(value)

    def validate(self, value):
        raise NotImplementedError()
class Required(Validator):
"""
Validate the field valiue is not `None`
"""
def validate(self, value):
if value is None:
raise InvalidModelFieldError('is required')
class String(Validator):
"""
Validate the value of the field is a `str`
"""
def validate(self, value):
if not value:
return
if not isinstance(value, str):
raise InvalidModelFieldError('{0} should be a string'.format(value))
class Integer(Validator):
"""
Validate the value of the field is a `int`
"""
def validate(self, value):
if not value:
return
if not isinstance(value, int):
raise InvalidModelFieldError('{0} should be an integer'.format(value))
class Float(Validator):
"""
Validate the value of the field is a `float`
"""
def validate(self, value):
if not value:
return
if not isinstance(value, float):
raise InvalidModelFieldError('{0} should be a float'.format(value))
class Boolean(Validator):
"""This validator forces fields values to be an instance of `bool`."""
def validate(self, value):
if not value:
return
if not isinstance(value, bool):
raise InvalidModelFieldError('{0} should be a boolean'.format(value))
class In(Validator):
"""
Validate the field value is in the list of allowed choices
"""
def __init__(self, choices):
self.choices = choices
def validate(self, value):
if not value:
return
if value not in self.choices:
raise InvalidModelFieldError('{0} should be in {1}'.format(value, self.choices))
class InstanceOf(Validator):
"""
Check if the field is an instance of the given type
"""
def __init__(self, cls):
self.type = cls
def validate(self, value):
if not value:
return
if not isinstance(value, self.type):
raise InvalidModelFieldError(
"should be an instance of '{}'".format(self.type.__name__))
class Model(Validator):
"""
Validate the field is a model
"""
def __init__(self, model):
self.model = model
def validate(self, value):
if not value:
return
if not isinstance(value, self.model):
raise InvalidModelFieldError(
"should be an instance of '{}'".format(self.model.__name__))
value.validate()
class List(Validator):
"""
Validate the given item is a list
"""
def __init__(self, *validators):
self.validators = validators
def validate(self, value):
if not value:
return
if not isinstance(value, collections.MutableSequence):
raise InvalidModelFieldError('should be a list')
for i in value:
for validator in self.validators:
validator(i)

Просмотреть файл

@ -7,6 +7,10 @@ All error should inherit from `AztkError`
class AztkError(Exception):
pass
class AztkAttributeError(AztkError):
pass
class ClusterNotReadyError(AztkError):
pass
@ -17,7 +21,15 @@ class InvalidPluginConfigurationError(AztkError):
pass
class InvalidModelError(AztkError):
pass
def __init__(self, message: str, model=None):
super().__init__()
self.message = message
self.model = model
def __str__(self):
model_name = self.model and self.model.__class__.__name__
return "{model} {message}".format(model=model_name, message=self.message)
class MissingRequiredAttributeError(InvalidModelError):
pass
@ -27,3 +39,12 @@ class InvalidCustomScriptError(InvalidModelError):
class InvalidPluginReferenceError(InvalidModelError):
pass
class InvalidModelFieldError(InvalidModelError):
def __init__(self, message: str, model=None, field=None):
super().__init__(message, model)
self.field = field
def __str__(self):
model_name = self.model and self.model.__class__.__name__
return "{model} {field} {message}".format(model=model_name, field=self.field, message=self.message)

Просмотреть файл

@ -1,9 +1,13 @@
import io
import logging
import yaml
import azure.common
from .node_data import NodeData
import yaml
from aztk.models import ClusterConfiguration
from .blob_data import BlobData
from .node_data import NodeData
class ClusterData:

Просмотреть файл

@ -1,2 +1,18 @@
from .toolkit import Toolkit, TOOLKIT_MAP
from .models import *
from .cluster_configuration import ClusterConfiguration
from .custom_script import CustomScript
from .file_share import FileShare
from .toolkit import TOOLKIT_MAP, Toolkit
from .user_configuration import UserConfiguration
from .secrets_configuration import (
SecretsConfiguration,
ServicePrincipalConfiguration,
SharedKeyConfiguration,
DockerConfiguration,
)
from .file import File
from .remote_login import RemoteLogin
from .ssh_log import SSHLog
from .vm_image import VmImage
from .software import Software
from .cluster import Cluster
from .plugins import *

23
aztk/models/cluster.py Normal file
Просмотреть файл

@ -0,0 +1,23 @@
import azure.batch.models as batch_models
class Cluster:
    """
    Read-only view over an Azure Batch pool (and optionally its nodes).

    Args:
        pool: the Batch pool backing this cluster
        nodes: optional paged collection of the pool's compute nodes
    """

    def __init__(self,
                 pool: batch_models.CloudPool,
                 nodes: batch_models.ComputeNodePaged = None):
        self.id = pool.id
        self.pool = pool
        self.nodes = nodes
        self.vm_size = pool.vm_size
        # BUG FIX: the previous check compared `pool.state.value` (a plain
        # string) against the enum member with `is`, which is always False,
        # so the allocation-state branch was unreachable. Batch SDK enums
        # derive from `str`, so `==` matches whether `state` arrives as the
        # enum member or as a raw string.
        if pool.state == batch_models.PoolState.active:
            # While the pool is active, the allocation state (resizing,
            # steady, ...) is the more informative status to surface.
            self.visible_state = pool.allocation_state.value
        else:
            self.visible_state = pool.state.value
        self.total_current_nodes = pool.current_dedicated_nodes + \
            pool.current_low_priority_nodes
        self.total_target_nodes = pool.target_dedicated_nodes + \
            pool.target_low_priority_nodes
        self.current_dedicated_nodes = pool.current_dedicated_nodes
        self.current_low_pri_nodes = pool.current_low_priority_nodes
        self.target_dedicated_nodes = pool.target_dedicated_nodes
        self.target_low_pri_nodes = pool.target_low_priority_nodes

Просмотреть файл

@ -0,0 +1,83 @@
import aztk.error as error
from aztk.core.models import Model, fields
from aztk.utils import deprecate, helpers
from .custom_script import CustomScript
from .file_share import FileShare
from .plugins import PluginConfiguration
from .toolkit import Toolkit
from .user_configuration import UserConfiguration
class ClusterConfiguration(Model):
    """
    Cluster configuration model

    Args:
        cluster_id (str): Id of the Aztk cluster
        toolkit (aztk.models.Toolkit): Toolkit to be used for this cluster
        size (int): Number of dedicated nodes for this cluster
        size_low_priority (int): Number of low priority nodes for this cluster
        vm_size (int): Azure Vm size to be used for each node
        subnet_id (str): Full resource id of the subnet to be used(Required for mixed mode clusters)
        plugins (List[aztk.models.plugins.PluginConfiguration]): List of plugins to be used
        file_shares (List[aztk.models.FileShare]): List of File shares to be used
        user_configuration (aztk.models.UserConfiguration): Configuration of the user to
            be created on the master node to ssh into.
    """

    cluster_id = fields.String()
    toolkit = fields.Model(Toolkit)
    size = fields.Integer(default=0)
    size_low_priority = fields.Integer(default=0)
    vm_size = fields.String()
    subnet_id = fields.String(default=None)
    plugins = fields.List(PluginConfiguration)
    custom_scripts = fields.List(CustomScript)  # DEPRECATED, see __validate__
    file_shares = fields.List(FileShare)
    user_configuration = fields.Model(UserConfiguration, default=None)

    def __init__(self, *args, **kwargs):
        # Backward compatibility: map the old vm_count / vm_low_pri_count
        # keyword names onto size / size_low_priority with a warning.
        if 'vm_count' in kwargs:
            deprecate("vm_count is deprecated for ClusterConfiguration please use size instead")
            kwargs['size'] = kwargs.pop('vm_count')

        if 'vm_low_pri_count' in kwargs:
            deprecate("vm_low_pri_count is deprecated for ClusterConfiguration please use size_low_priority instead")
            kwargs['size_low_priority'] = kwargs.pop('vm_low_pri_count')

        super().__init__(*args, **kwargs)

    def mixed_mode(self) -> bool:
        """
        Return:
            if the pool is using mixed mode(Both dedicated and low priority nodes)
        """
        return self.size > 0 and self.size_low_priority > 0

    def gpu_enabled(self):
        # GPU support is derived from the configured VM size name.
        return helpers.is_gpu_enabled(self.vm_size)

    def get_docker_repo(self):
        # The docker image depends on whether the VM size is GPU enabled.
        return self.toolkit.get_docker_repo(self.gpu_enabled())

    def __validate__(self) -> bool:
        # Model-wide validation hook invoked by Model.validate().
        if self.size == 0 and self.size_low_priority == 0:
            raise error.InvalidModelError(
                "Please supply a valid (greater than 0) size or size_low_priority value either in the cluster.yaml configuration file or with a parameter (--size or --size-low-pri)"
            )

        if self.vm_size is None:
            raise error.InvalidModelError(
                "Please supply a vm_size in either the cluster.yaml configuration file or with a parameter (--vm-size)"
            )

        # Mixed mode (dedicated + low-priority nodes) requires a VNET subnet.
        if self.mixed_mode() and not self.subnet_id:
            raise error.InvalidModelError(
                "You must configure a VNET to use AZTK in mixed mode (dedicated and low priority nodes). Set the VNET's subnet_id in your cluster.yaml."
            )

        if self.custom_scripts:
            deprecate("Custom scripts are DEPRECATED and will be removed in 0.8.0. Use plugins instead See https://aztk.readthedocs.io/en/v0.7.0/15-plugins.html")

Просмотреть файл

@ -0,0 +1,6 @@
from aztk.core.models import Model, fields
class CustomScript(Model):
    """
    Custom script to be run on cluster nodes.

    NOTE: custom scripts are deprecated in favor of plugins (a deprecation
    warning is emitted by ClusterConfiguration when they are set).
    """

    # Friendly name for the script
    name = fields.String()
    # The script itself (content or path -- not determinable from this file)
    script = fields.String()
    # Which nodes the script runs on -- TODO confirm the allowed values
    run_on = fields.String()

6
aztk/models/file.py Normal file
Просмотреть файл

@ -0,0 +1,6 @@
import io
class File:
    """An in-memory file: a name paired with its StringIO payload."""

    def __init__(self, name: str, payload: io.StringIO):
        self.name = name
        self.payload = payload

Просмотреть файл

@ -0,0 +1,7 @@
from aztk.core.models import Model, fields
class FileShare(Model):
    """
    Azure Files share to be mounted on the cluster nodes.
    """

    # Storage account hosting the share, and its access key
    storage_account_name = fields.String()
    storage_account_key = fields.String()
    # Path of the share within the storage account
    file_share_path = fields.String()
    # Mount point on the node
    mount_path = fields.String()

Просмотреть файл

@ -1,312 +0,0 @@
import io
from typing import List
import azure.batch.models as batch_models
from aztk import error
from aztk.utils import helpers, deprecate
from aztk.models.plugins import PluginConfiguration
from aztk.internal import ConfigurationBase
from .toolkit import Toolkit
class FileShare:
def __init__(self,
storage_account_name: str = None,
storage_account_key: str = None,
file_share_path: str = None,
mount_path: str = None):
self.storage_account_name = storage_account_name
self.storage_account_key = storage_account_key
self.file_share_path = file_share_path
self.mount_path = mount_path
class File:
def __init__(self, name: str, payload: io.StringIO):
self.name = name
self.payload = payload
class CustomScript:
def __init__(self, name: str = None, script=None, run_on=None):
self.name = name
self.script = script
self.run_on = run_on
class UserConfiguration(ConfigurationBase):
def __init__(self,
username: str,
ssh_key: str = None,
password: str = None):
self.username = username
self.ssh_key = ssh_key
self.password = password
def merge(self, other):
self._merge_attributes(other, [
"username",
"ssh_key",
"password",
])
def validate(self):
pass
class ClusterConfiguration(ConfigurationBase):
"""
Cluster configuration model
Args:
toolkit
"""
def __init__(self,
toolkit: Toolkit = None,
custom_scripts: List[CustomScript] = None,
file_shares: List[FileShare] = None,
cluster_id: str = None,
vm_count=0,
vm_low_pri_count=0,
vm_size=None,
subnet_id=None,
plugins: List[PluginConfiguration] = None,
user_configuration: UserConfiguration = None):
super().__init__()
self.toolkit = toolkit
self.custom_scripts = custom_scripts
self.file_shares = file_shares
self.cluster_id = cluster_id
self.vm_count = vm_count
self.vm_size = vm_size
self.vm_low_pri_count = vm_low_pri_count
self.subnet_id = subnet_id
self.user_configuration = user_configuration
self.plugins = plugins
def merge(self, other):
"""
Merge other cluster config into this one.
:params other: ClusterConfiguration
"""
self._merge_attributes(other, [
"toolkit",
"custom_scripts",
"file_shares",
"cluster_id",
"vm_size",
"subnet_id",
"vm_count",
"vm_low_pri_count",
"plugins",
])
if other.user_configuration:
if self.user_configuration:
self.user_configuration.merge(other.user_configuration)
else:
self.user_configuration = other.user_configuration
if self.plugins:
for plugin in self.plugins:
plugin.validate()
def mixed_mode(self) -> bool:
return self.vm_count > 0 and self.vm_low_pri_count > 0
def gpu_enabled(self):
return helpers.is_gpu_enabled(self.vm_size)
def get_docker_repo(self):
return self.toolkit.get_docker_repo(self.gpu_enabled())
def validate(self) -> bool:
"""
Validate the config at its current state.
Raises: Error if invalid
"""
if self.toolkit is None:
raise error.InvalidModelError(
"Please supply a toolkit for the cluster")
self.toolkit.validate()
if self.cluster_id is None:
raise error.AztkError(
"Please supply an id for the cluster with a parameter (--id)")
if self.vm_count == 0 and self.vm_low_pri_count == 0:
raise error.AztkError(
"Please supply a valid (greater than 0) size or size_low_pri value either in the cluster.yaml configuration file or with a parameter (--size or --size-low-pri)"
)
if self.vm_size is None:
raise error.AztkError(
"Please supply a vm_size in either the cluster.yaml configuration file or with a parameter (--vm-size)"
)
if self.mixed_mode() and not self.subnet_id:
raise error.AztkError(
"You must configure a VNET to use AZTK in mixed mode (dedicated and low priority nodes). Set the VNET's subnet_id in your cluster.yaml."
)
if self.custom_scripts:
deprecate("Custom scripts are DEPRECATED and will be removed in 0.8.0. Use plugins instead See https://aztk.readthedocs.io/en/v0.7.0/15-plugins.html")
class RemoteLogin:
def __init__(self, ip_address, port):
self.ip_address = ip_address
self.port = port
class ServicePrincipalConfiguration(ConfigurationBase):
"""
Container class for AAD authentication
"""
def __init__(self,
tenant_id: str = None,
client_id: str = None,
credential: str = None,
batch_account_resource_id: str = None,
storage_account_resource_id: str = None):
self.tenant_id = tenant_id
self.client_id = client_id
self.credential = credential
self.batch_account_resource_id = batch_account_resource_id
self.storage_account_resource_id = storage_account_resource_id
def validate(self) -> bool:
"""
Validate the config at its current state.
Raises: Error if invalid
"""
self._validate_required([
"tenant_id",
"client_id",
"credential",
"batch_account_resource_id",
"storage_account_resource_id",
])
class SharedKeyConfiguration(ConfigurationBase):
"""
Container class for shared key authentication
"""
def __init__(self,
batch_account_name: str = None,
batch_account_key: str = None,
batch_service_url: str = None,
storage_account_name: str = None,
storage_account_key: str = None,
storage_account_suffix: str = None):
self.batch_account_name = batch_account_name
self.batch_account_key = batch_account_key
self.batch_service_url = batch_service_url
self.storage_account_name = storage_account_name
self.storage_account_key = storage_account_key
self.storage_account_suffix = storage_account_suffix
def validate(self) -> bool:
"""
Validate the config at its current state.
Raises: Error if invalid
"""
self._validate_required([
"batch_account_name",
"batch_account_key",
"batch_service_url",
"storage_account_name",
"storage_account_key",
"storage_account_suffix",
])
class DockerConfiguration(ConfigurationBase):
def __init__(self, endpoint=None, username=None, password=None):
self.endpoint = endpoint
self.username = username
self.password = password
def validate(self):
pass
class SecretsConfiguration(ConfigurationBase):
def __init__(self,
service_principal=None,
shared_key=None,
docker=None,
ssh_pub_key=None,
ssh_priv_key=None):
self.service_principal = service_principal
self.shared_key = shared_key
self.docker = docker
self.ssh_pub_key = ssh_pub_key
self.ssh_priv_key = ssh_priv_key
def validate(self):
if self.service_principal and self.shared_key:
raise error.AztkError(
"Both service_principal and shared_key auth are configured, must use only one"
)
elif self.service_principal:
self.service_principal.validate()
elif self.shared_key:
self.shared_key.validate()
else:
raise error.AztkError(
"Neither service_principal and shared_key auth are configured, must use only one"
)
def is_aad(self):
return self.service_principal is not None
class VmImage:
def __init__(self, publisher, offer, sku):
self.publisher = publisher
self.offer = offer
self.sku = sku
class Cluster:
def __init__(self,
pool: batch_models.CloudPool,
nodes: batch_models.ComputeNodePaged = None):
self.id = pool.id
self.pool = pool
self.nodes = nodes
self.vm_size = pool.vm_size
if pool.state.value is batch_models.PoolState.active:
self.visible_state = pool.allocation_state.value
else:
self.visible_state = pool.state.value
self.total_current_nodes = pool.current_dedicated_nodes + \
pool.current_low_priority_nodes
self.total_target_nodes = pool.target_dedicated_nodes + \
pool.target_low_priority_nodes
self.current_dedicated_nodes = pool.current_dedicated_nodes
self.current_low_pri_nodes = pool.current_low_priority_nodes
self.target_dedicated_nodes = pool.target_dedicated_nodes
self.target_low_pri_nodes = pool.target_low_priority_nodes
class SSHLog():
def __init__(self, output, node_id):
self.output = output
self.node_id = node_id
class Software:
"""
Enum with list of available softwares
"""
spark = "spark"

Просмотреть файл

@ -46,9 +46,9 @@ class PluginManager:
def get_args_for(self, cls):
signature = inspect.signature(cls)
args = dict()
for k, v in signature.parameters.items():
args[k] = PluginArgument(k, default=v.default, required=v.default is inspect.Parameter.empty)
for key, param in signature.parameters.items():
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.KEYWORD_ONLY:
args[key] = PluginArgument(key, default=param.default, required=param.default is inspect.Parameter.empty)
return args

Просмотреть файл

@ -1,14 +1,14 @@
import os
from aztk.error import InvalidModelError
from aztk.internal import ConfigurationBase
from aztk.core.models import Model, fields
from aztk.models import PluginConfiguration
from aztk.models.plugins import PluginFile, PluginTarget, PluginTargetRole
from .plugin_manager import plugin_manager
class PluginReference(ConfigurationBase):
class PluginReference(Model):
"""
Contains the configuration to use a plugin
@ -20,27 +20,12 @@ class PluginReference(ConfigurationBase):
target_role (PluginTargetRole): Target role default to All nodes. This can only be used if providing a script
args: (dict): If using name this is the arguments to pass to the plugin
"""
def __init__(self,
name: str = None,
script: str = None,
target: PluginTarget = None,
target_role: PluginTargetRole = None,
args: dict = None):
super().__init__()
self.name = name
self.script = script
self.target = target
self.target_role = target_role
self.args = args or dict()
@classmethod
def _from_dict(cls, args: dict):
if "target" in args:
args["target"] = PluginTarget(args["target"])
if "target_role" in args:
args["target_role"] = PluginTargetRole(args["target_role"])
return super()._from_dict(args)
name = fields.String(default=None)
script = fields.String(default=None)
target = fields.Enum(PluginTarget, default=None)
target_role = fields.Enum(PluginTargetRole, default=None)
args = fields.String(default=None)
def get_plugin(self) -> PluginConfiguration:
self.validate()
@ -50,7 +35,7 @@ class PluginReference(ConfigurationBase):
return plugin_manager.get_plugin(self.name, self.args)
def validate(self) -> bool:
def __validate__(self):
if not self.name and not self.script:
raise InvalidModelError("Plugin must either specify a name of an existing plugin or the path to a script.")

Просмотреть файл

@ -1,9 +1,10 @@
from enum import Enum
from typing import List, Union
from aztk.internal import ConfigurationBase
from aztk.error import InvalidPluginConfigurationError
from aztk.core.models import Model, fields
from .plugin_file import PluginFile
class PluginTarget(Enum):
"""
Where this plugin should run
@ -18,74 +19,51 @@ class PluginTargetRole(Enum):
All = "all-nodes"
class PluginPort:
class PluginPort(Model):
"""
Definition for a port that should be opened on node
:param internal: Port on the node
:param public: [Optional] Port available to the user. If none won't open any port to the user
:param name: [Optional] name to differentiate ports if you have multiple
"""
internal = fields.Integer()
public = fields.Field(default=None)
name = fields.Integer()
def __init__(self, internal: int, public: Union[int, bool]=False, name=None):
self.internal = internal
self.expose_publicly = bool(public)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.expose_publicly = bool(self.public)
self.public_port = None
if self.expose_publicly:
if public is True:
self.public_port = internal
if self.public is True:
self.public_port = self.internal
else:
self.public_port = public
self.name = name
self.public_port = self.public
class PluginConfiguration(ConfigurationBase):
class PluginConfiguration(Model):
"""
Plugin manifest that should be returned in the main.py of your plugin
:param name: Name of the plugin. Used to reference the plugin
:param runOn: Where the plugin should run
:param files: List of files to upload
:param args:
:param env:
"""
def __init__(self,
name: str,
ports: List[PluginPort] = None,
files: List[PluginFile] = None,
execute: str = None,
args=None,
env=None,
target_role: PluginTargetRole = None,
target: PluginTarget = None):
self.name = name
# self.docker_image = docker_image
self.target = target or PluginTarget.SparkContainer
self.target_role = target_role or PluginTargetRole.Master
self.ports = ports or []
self.files = files or []
self.args = args or []
self.env = env or dict()
self.execute = execute
Args
name: Name of the plugin. Used to reference the plugin
runOn: Where the plugin should run
execute: Path to the file to execute(This must match the target of one of the files)
files: List of files to upload
args: List of arguments to pass to the executing script
env: Dict of environment variables to pass to the script
"""
name = fields.String()
files = fields.List(PluginFile)
execute = fields.String()
args = fields.List(default=[])
env = fields.List(default=[])
target = fields.Enum(PluginTarget, default=PluginTarget.SparkContainer)
target_role = fields.Enum(PluginTargetRole, default=PluginTargetRole.Master)
ports = fields.List(PluginPort, default=[])
def has_arg(self, name: str):
for x in self.args:
if x.name == name:
return True
return False
def validate(self):
self._validate_required([
"name",
"execute",
])
if not isinstance(self.target, PluginTarget):
raise InvalidPluginConfigurationError(
"Target must be of type Plugin target but was {0}".format(self.target))
if not isinstance(self.target_role, PluginTargetRole):
raise InvalidPluginConfigurationError(
"Target role must be of type Plugin target role but was {0}".format(self.target))

Просмотреть файл

@ -1,24 +1,36 @@
import io
from typing import Union
from aztk.core.models import Model, fields
class PluginFile:
class PluginFile(Model):
"""
Reference to a file for a plugin.
"""
def __init__(self, target: str, local_path: str):
self.target = target
self.local_path = local_path
target = fields.String()
local_path = fields.String()
# TODO handle folders?
def __init__(self, target: str = None, local_path: str = None):
super().__init__(target=target, local_path=local_path)
def content(self):
with open(self.local_path, "r", encoding='UTF-8') as f:
return f.read()
class TextPluginFile:
class TextPluginFile(Model):
"""
Reference to a file for a plugin.
Args:
target (str): Where should the file be uploaded relative to the plugin working dir
content (str|io.StringIO): Content of the file. Can either be a string or a StringIO
"""
target = fields.String()
def __init__(self, target: str, content: Union[str,io.StringIO]):
super().__init__(target=target)
if isinstance(content, str):
self._content = content
else:

Просмотреть файл

@ -0,0 +1,4 @@
class RemoteLogin:
    """Endpoint (IP address and port) used to remote into a cluster node."""

    def __init__(self, ip_address, port):
        self.ip_address, self.port = ip_address, port

Просмотреть файл

@ -0,0 +1,59 @@
from aztk.core.models import Model, fields
from aztk.error import InvalidModelError
class ServicePrincipalConfiguration(Model):
    """
    Container class for AAD (service principal) authentication.
    """
    tenant_id = fields.String()
    client_id = fields.String()
    credential = fields.String()    # service principal credential
    # Full ARM resource IDs of the accounts to authenticate against
    batch_account_resource_id = fields.String()
    storage_account_resource_id = fields.String()
class SharedKeyConfiguration(Model):
    """
    Container class for shared key authentication.
    """
    # Batch account credentials
    batch_account_name = fields.String()
    batch_account_key = fields.String()
    batch_service_url = fields.String()
    # Storage account credentials
    storage_account_name = fields.String()
    storage_account_key = fields.String()
    storage_account_suffix = fields.String()
class DockerConfiguration(Model):
    """
    Configuration for connecting to a private docker registry.

    Args:
        endpoint (str): Which docker endpoint to use. Default to docker hub.
        username (str): Docker endpoint username
        password (str): Docker endpoint password
    """
    # All optional: when endpoint is None the public Docker Hub is used.
    endpoint = fields.String(default=None)
    username = fields.String(default=None)
    password = fields.String(default=None)
class SecretsConfiguration(Model):
    """
    Root secrets model. Exactly one of service principal (AAD) or shared key
    authentication must be configured.
    """
    service_principal = fields.Model(ServicePrincipalConfiguration, default=None)
    shared_key = fields.Model(SharedKeyConfiguration, default=None)
    docker = fields.Model(DockerConfiguration, default=None)
    ssh_pub_key = fields.String(default=None)
    ssh_priv_key = fields.String(default=None)

    def __validate__(self):
        # The two auth mechanisms are mutually exclusive and one is required.
        if self.service_principal and self.shared_key:
            raise InvalidModelError(
                "Both service_principal and shared_key auth are configured, must use only one"
            )
        if not self.service_principal and not self.shared_key:
            # Fixed message: original read "Neither ... and ... are configured,
            # must use only one", which was ungrammatical and misleading here.
            raise InvalidModelError(
                "Neither service_principal nor shared_key auth is configured, must use one"
            )

    def is_aad(self):
        """Return True when authenticating through Azure Active Directory."""
        return self.service_principal is not None

5
aztk/models/software.py Normal file
Просмотреть файл

@ -0,0 +1,5 @@
class Software:
    """
    Enum-like namespace listing the available software options.
    """
    spark = "spark"

4
aztk/models/ssh_log.py Normal file
Просмотреть файл

@ -0,0 +1,4 @@
class SSHLog:
    """Holds the output of an SSH command together with the node it ran on."""

    def __init__(self, output, node_id):
        self.output, self.node_id = output, node_id

Просмотреть файл

@ -1,7 +1,6 @@
from aztk.internal import ConfigurationBase
from aztk.error import InvalidModelError
from aztk.utils import constants, deprecate
from aztk.core.models import Model, fields
class ToolkitDefinition:
def __init__(self, versions, environments):
@ -26,7 +25,7 @@ TOOLKIT_MAP = dict(
)
class Toolkit(ConfigurationBase):
class Toolkit(Model):
"""
Toolkit for a cluster.
This will help pick the docker image needed
@ -36,24 +35,16 @@ class Toolkit(ConfigurationBase):
version (str): Version of the toolkit
environment (str): Which environment to use for this toolkit
environment_version (str): If there is multiple version for an environment you can specify which one
docker_repo (str): Optional docker repo
"""
def __init__(self,
software: str,
version: str,
environment: str = None,
environment_version: str = None,
docker_repo=None):
self.software = software
self.version = str(version)
self.environment = environment
self.environment_version = environment_version
self.docker_repo = docker_repo
def validate(self):
self._validate_required(["software", "version"])
software = fields.String()
version = fields.String()
environment = fields.String(default=None)
environment_version = fields.String(default=None)
docker_repo = fields.String(default=None)
def __validate__(self):
if self.software not in TOOLKIT_MAP:
raise InvalidModelError("Toolkit '{0}' is not in the list of allowed toolkits {1}".format(
self.software, list(TOOLKIT_MAP.keys())))

Просмотреть файл

@ -0,0 +1,6 @@
from aztk.core.models import Model, fields
class UserConfiguration(Model):
    # User account to create on the cluster nodes.
    username = fields.String()
    ssh_key = fields.String(default=None)    # SSH public key (optional)
    password = fields.String(default=None)    # password login (optional)

5
aztk/models/vm_image.py Normal file
Просмотреть файл

@ -0,0 +1,5 @@
class VmImage:
    """Azure VM image reference described by its publisher/offer/sku triple."""

    def __init__(self, publisher, offer, sku):
        self.publisher, self.offer, self.sku = publisher, offer, sku

Просмотреть файл

@ -13,11 +13,6 @@ from aztk.spark.helpers import cluster_diagnostic_helper
from aztk.spark.utils import util
from aztk.internal.cluster_data import NodeData
DEFAULT_CLUSTER_CONFIG = models.ClusterConfiguration(
worker_on_master=True,
)
class Client(BaseClient):
"""
Aztk Spark Client
@ -29,7 +24,7 @@ class Client(BaseClient):
def __init__(self, secrets_config):
super().__init__(secrets_config)
def create_cluster(self, configuration: models.ClusterConfiguration, wait: bool = False):
def create_cluster(self, cluster_conf: models.ClusterConfiguration, wait: bool = False):
"""
Create a new aztk spark cluster
@ -40,10 +35,8 @@ class Client(BaseClient):
Returns:
aztk.spark.models.Cluster
"""
cluster_conf = models.ClusterConfiguration()
cluster_conf.merge(DEFAULT_CLUSTER_CONFIG)
cluster_conf.merge(configuration)
cluster_conf.validate()
cluster_data = self._get_cluster_data(cluster_conf.cluster_id)
try:
zip_resource_files = None

Просмотреть файл

@ -4,6 +4,7 @@ import azure.batch.models as batch_models
import aztk.models
from aztk import error
from aztk.utils import constants, helpers
from aztk.core.models import Model, fields
class SparkToolkit(aztk.models.Toolkit):
def __init__(self, version: str, environment: str = None, environment_version: str = None):
@ -53,12 +54,14 @@ class File(aztk.models.File):
pass
class SparkConfiguration():
def __init__(self, spark_defaults_conf=None, spark_env_sh=None, core_site_xml=None, jars=None):
self.spark_defaults_conf = spark_defaults_conf
self.spark_env_sh = spark_env_sh
self.core_site_xml = core_site_xml
self.jars = jars
class SparkConfiguration(Model):
spark_defaults_conf = fields.String()
spark_env_sh = fields.String()
core_site_xml = fields.String()
jars = fields.List()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ssh_key_pair = self.__generate_ssh_key_pair()
def __generate_ssh_key_pair(self):
@ -96,37 +99,8 @@ class PluginConfiguration(aztk.models.PluginConfiguration):
class ClusterConfiguration(aztk.models.ClusterConfiguration):
def __init__(
self,
custom_scripts: List[CustomScript] = None,
file_shares: List[FileShare] = None,
cluster_id: str = None,
vm_count=0,
vm_low_pri_count=0,
vm_size=None,
subnet_id=None,
toolkit: SparkToolkit = None,
user_configuration: UserConfiguration = None,
spark_configuration: SparkConfiguration = None,
worker_on_master: bool = None):
super().__init__(
custom_scripts=custom_scripts,
cluster_id=cluster_id,
vm_count=vm_count,
vm_low_pri_count=vm_low_pri_count,
vm_size=vm_size,
toolkit=toolkit,
subnet_id=subnet_id,
file_shares=file_shares,
user_configuration=user_configuration,
)
self.spark_configuration = spark_configuration
self.worker_on_master = worker_on_master
def merge(self, other):
super().merge(other)
self._merge_attributes(other, ["spark_configuration", "worker_on_master"])
spark_configuration = fields.Model(SparkConfiguration, default=None)
worker_on_master = fields.Boolean(default=True)
class SecretsConfiguration(aztk.models.SecretsConfiguration):
pass
@ -234,8 +208,8 @@ class JobConfiguration:
custom_scripts=self.custom_scripts,
toolkit=self.toolkit,
vm_size=self.vm_size,
vm_count=self.max_dedicated_nodes,
vm_low_pri_count=self.max_low_pri_nodes,
size=self.max_dedicated_nodes,
size_low_priority=self.max_low_pri_nodes,
subnet_id=self.subnet_id,
worker_on_master=self.worker_on_master,
spark_configuration=self.spark_configuration,

Просмотреть файл

@ -4,19 +4,18 @@ from aztk.models.plugins.plugin_file import PluginFile
dir_path = os.path.dirname(os.path.realpath(__file__))
class JupyterPlugin(PluginConfiguration):
def __init__(self):
super().__init__(
name="jupyter",
ports=[
PluginPort(
internal=8888,
public=True,
),
],
target_role=PluginTargetRole.All,
execute="jupyter.sh",
files=[
PluginFile("jupyter.sh", os.path.join(dir_path, "jupyter.sh")),
],
)
def JupyterPlugin():
    """Build the plugin configuration for a Jupyter notebook server.

    Exposes internal port 8888 publicly and runs jupyter.sh on all nodes.
    """
    notebook_port = PluginPort(
        internal=8888,
        public=True,
    )
    script = PluginFile("jupyter.sh", os.path.join(dir_path, "jupyter.sh"))
    return PluginConfiguration(
        name="jupyter",
        ports=[notebook_port],
        target_role=PluginTargetRole.All,
        execute="jupyter.sh",
        files=[script],
    )

Просмотреть файл

@ -5,19 +5,18 @@ from aztk.utils import constants
dir_path = os.path.dirname(os.path.realpath(__file__))
class JupyterLabPlugin(PluginConfiguration):
def __init__(self):
super().__init__(
name="jupyterlab",
ports=[
PluginPort(
internal=8889,
public=True,
),
],
target_role=PluginTargetRole.All,
execute="jupyter_lab.sh",
files=[
PluginFile("jupyter_lab.sh", os.path.join(dir_path, "jupyter_lab.sh")),
],
)
def JupyterLabPlugin():
    """Build the plugin configuration for a JupyterLab server.

    Exposes internal port 8889 publicly and runs jupyter_lab.sh on all nodes.
    """
    lab_port = PluginPort(
        internal=8889,
        public=True,
    )
    script = PluginFile("jupyter_lab.sh", os.path.join(dir_path, "jupyter_lab.sh"))
    return PluginConfiguration(
        name="jupyterlab",
        ports=[lab_port],
        target_role=PluginTargetRole.All,
        execute="jupyter_lab.sh",
        files=[script],
    )

Просмотреть файл

@ -1,24 +1,23 @@
import os
from aztk.models.plugins.plugin_configuration import PluginConfiguration, PluginPort, PluginTargetRole
from aztk.models.plugins.plugin_file import PluginFile
from aztk.utils import constants
dir_path = os.path.dirname(os.path.realpath(__file__))
class RStudioServerPlugin(PluginConfiguration):
def __init__(self, version="1.1.383"):
super().__init__(
name="rstudio_server",
ports=[
PluginPort(
internal=8787,
public=True,
),
],
target_role=PluginTargetRole.Master,
execute="rstudio_server.sh",
files=[
PluginFile("rstudio_server.sh", os.path.join(dir_path, "rstudio_server.sh")),
],
env=dict(RSTUDIO_SERVER_VERSION=version),
)
def RStudioServerPlugin(version="1.1.383"):
    """Build the plugin configuration for RStudio Server on the master node.

    Args:
        version: Value passed to the script as RSTUDIO_SERVER_VERSION.
    """
    script_name = "rstudio_server.sh"
    ui_port = PluginPort(
        internal=8787,
        public=True,
    )
    return PluginConfiguration(
        name="rstudio_server",
        ports=[ui_port],
        target_role=PluginTargetRole.Master,
        execute=script_name,
        files=[
            PluginFile(script_name, os.path.join(dir_path, script_name)),
        ],
        env=dict(RSTUDIO_SERVER_VERSION=version),
    )

Просмотреть файл

@ -1,15 +1,11 @@
import os
import yaml
from aztk_cli import log
import aztk.spark
from aztk.spark.models import (
SecretsConfiguration,
ServicePrincipalConfiguration,
SharedKeyConfiguration,
DockerConfiguration,
ClusterConfiguration,
UserConfiguration,
)
from aztk.utils import deprecate
from aztk.models import Toolkit
from aztk.models.plugins.internal import PluginReference
@ -19,11 +15,9 @@ def load_aztk_secrets() -> SecretsConfiguration:
"""
secrets = SecretsConfiguration()
# read global ~/secrets.yaml
global_config = _load_secrets_config(
os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, '.aztk',
'secrets.yaml'))
global_config = _load_config_file(os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, '.aztk', 'secrets.yaml'))
# read current working directory secrets.yaml
local_config = _load_secrets_config()
local_config = _load_config_file(aztk.utils.constants.DEFAULT_SECRETS_PATH)
if not global_config and not local_config:
raise aztk.error.AztkError("There is no secrets.yaml in either ./.aztk/secrets.yaml or .aztk/secrets.yaml")
@ -37,12 +31,7 @@ def load_aztk_secrets() -> SecretsConfiguration:
secrets.validate()
return secrets
def _load_secrets_config(
path: str = aztk.utils.constants.DEFAULT_SECRETS_PATH):
"""
Loads the secrets.yaml file in the .aztk directory
"""
def _load_config_file(path: str):
if not os.path.isfile(path):
return None
@ -51,74 +40,16 @@ def _load_secrets_config(
return yaml.load(stream)
except yaml.YAMLError as err:
raise aztk.error.AztkError(
"Error in secrets.yaml: {0}".format(err))
"Error in {0}:\n {1}".format(path, err))
def _merge_secrets_dict(secrets: SecretsConfiguration, secrets_config):
service_principal_config = secrets_config.get('service_principal')
if service_principal_config:
secrets.service_principal = ServicePrincipalConfiguration(
tenant_id=service_principal_config.get('tenant_id'),
client_id=service_principal_config.get('client_id'),
credential=service_principal_config.get('credential'),
batch_account_resource_id=service_principal_config.get(
'batch_account_resource_id'),
storage_account_resource_id=service_principal_config.get(
'storage_account_resource_id'),
)
shared_key_config = secrets_config.get('shared_key')
batch = secrets_config.get('batch')
storage = secrets_config.get('storage')
if shared_key_config and (batch or storage):
raise aztk.error.AztkError(
"Shared keys must be configured either under 'sharedKey:' or under 'batch:' and 'storage:', not both."
)
if shared_key_config:
secrets.shared_key = SharedKeyConfiguration(
batch_account_name=shared_key_config.get('batch_account_name'),
batch_account_key=shared_key_config.get('batch_account_key'),
batch_service_url=shared_key_config.get('batch_service_url'),
storage_account_name=shared_key_config.get('storage_account_name'),
storage_account_key=shared_key_config.get('storage_account_key'),
storage_account_suffix=shared_key_config.get(
'storage_account_suffix'),
)
elif batch or storage:
secrets.shared_key = SharedKeyConfiguration()
if batch:
log.warning(
"Your secrets.yaml format is deprecated. To use shared key authentication use the shared_key key. See config/secrets.yaml.template"
)
secrets.shared_key.batch_account_name = batch.get(
'batchaccountname')
secrets.shared_key.batch_account_key = batch.get('batchaccountkey')
secrets.shared_key.batch_service_url = batch.get('batchserviceurl')
if storage:
secrets.shared_key.storage_account_name = storage.get(
'storageaccountname')
secrets.shared_key.storage_account_key = storage.get(
'storageaccountkey')
secrets.shared_key.storage_account_suffix = storage.get(
'storageaccountsuffix')
docker_config = secrets_config.get('docker')
if docker_config:
secrets.docker = DockerConfiguration(
endpoint=docker_config.get('endpoint'),
username=docker_config.get('username'),
password=docker_config.get('password'),
)
default_config = secrets_config.get('default')
# Check for ssh keys if they are provided
if default_config:
secrets.ssh_priv_key = default_config.get('ssh_priv_key')
secrets.ssh_pub_key = default_config.get('ssh_pub_key')
if 'default' in secrets_config:
deprecate("default key in secrets.yaml is deprecated. Place all child parameters directly at the root")
secrets_config = dict(**secrets_config, **secrets_config.pop('default'))
other = SecretsConfiguration.from_dict(secrets_config)
secrets.merge(other)
def read_cluster_config(
path: str = aztk.utils.constants.DEFAULT_CLUSTER_CONFIG_PATH
@ -126,82 +57,25 @@ def read_cluster_config(
"""
Reads the config file in the .aztk/ directory (.aztk/cluster.yaml)
"""
if not os.path.isfile(path):
return None
with open(path, 'r', encoding='UTF-8') as stream:
try:
config_dict = yaml.load(stream)
except yaml.YAMLError as err:
raise aztk.error.AztkError(
"Error in cluster.yaml: {0}".format(err))
if config_dict is None:
return None
return cluster_config_from_dict(config_dict)
config_dict = _load_config_file(path)
return cluster_config_from_dict(config_dict)
def cluster_config_from_dict(config: dict):
output = ClusterConfiguration()
wait = False
if config.get('id') is not None:
output.cluster_id = config['id']
if config.get('vm_size') is not None:
output.vm_size = config['vm_size']
if config.get('size'):
output.vm_count = config['size']
if config.get('size_low_pri'):
output.vm_low_pri_count = config['size_low_pri']
if config.get('subnet_id') is not None:
output.subnet_id = config['subnet_id']
if config.get('username') is not None:
output.user_configuration = UserConfiguration(
username=config['username'])
if config.get('password') is not None:
output.user_configuration.password = config['password']
if config.get('custom_scripts') not in [[None], None]:
output.custom_scripts = []
for custom_script in config['custom_scripts']:
output.custom_scripts.append(
aztk.spark.models.CustomScript(
script=custom_script['script'],
run_on=custom_script['runOn']))
if config.get('azure_files') not in [[None], None]:
output.file_shares = []
for file_share in config['azure_files']:
output.file_shares.append(
aztk.spark.models.FileShare(
storage_account_name=file_share['storage_account_name'],
storage_account_key=file_share['storage_account_key'],
file_share_path=file_share['file_share_path'],
mount_path=file_share['mount_path'],
))
if config.get('toolkit') is not None:
output.toolkit = Toolkit.from_dict(config['toolkit'])
if config.get('plugins') not in [[None], None]:
output.plugins = []
plugins = []
for plugin in config['plugins']:
ref = PluginReference.from_dict(plugin)
output.plugins.append(ref.get_plugin())
plugins.append(ref.get_plugin())
config["plugins"] = plugins
if config.get('worker_on_master') is not None:
output.worker_on_master = config['worker_on_master']
if config.get('username') is not None:
config['user_configuration'] = dict(username=config.pop('username'))
if config.get('wait') is not None:
wait = config['wait']
wait = config.pop('wait')
return output, wait
return ClusterConfiguration.from_dict(config), wait
class SshConfig:
@ -321,8 +195,8 @@ class JobConfig():
self.toolkit = Toolkit.from_dict(cluster_configuration.get('toolkit'))
if cluster_configuration.get('size') is not None:
self.max_dedicated_nodes = cluster_configuration.get('size')
if cluster_configuration.get('size_low_pri') is not None:
self.max_low_pri_nodes = cluster_configuration.get('size_low_pri')
if cluster_configuration.get('size_low_priority') is not None:
self.max_low_pri_nodes = cluster_configuration.get('size_low_priority')
self.custom_scripts = cluster_configuration.get('custom_scripts')
self.subnet_id = cluster_configuration.get('subnet_id')
self.worker_on_master = cluster_configuration.get("worker_on_master")

Просмотреть файл

@ -21,7 +21,7 @@ vm_size: standard_f2
# size: <number of dedicated nodes in the cluster, note that clusters must contain all dedicated or all low priority nodes>
size: 2
# size_low_pri: <number of low priority nodes in the cluster, mutually exclusive with size setting>
# size_low_priority: <number of low priority nodes in the cluster, mutually exclusive with size setting>
# username: <username for the linux user to be created> (optional)

Просмотреть файл

@ -8,7 +8,7 @@ job:
cluster_configuration:
vm_size: standard_d2_v2
size: 2
size_low_pri: 0
size_low_priority: 0
subnet_id:
# Toolkit configuration [Required]
toolkit:

Просмотреть файл

@ -24,7 +24,7 @@ docker:
# password:
# endpoint:
default:
# SSH keys used to create a user and connect to a server.
# The public key can either be the public key itself (ssh-rsa ...) or the path to the ssh key.
# ssh_pub_key: ~/.ssh/id_rsa.pub
# SSH keys used to create a user and connect to a server.
# The public key can either be the public key itself (ssh-rsa ...) or the path to the ssh key.
# ssh_pub_key: ~/.ssh/id_rsa.pub

Просмотреть файл

@ -29,7 +29,7 @@ def setup_parser(parser: argparse.ArgumentParser):
parser.add_argument('--no-wait', dest='wait', action='store_false')
parser.add_argument('--wait', dest='wait', action='store_true')
parser.set_defaults(wait=None, size=None, size_low_pri=None)
parser.set_defaults(wait=None, size=None, size_low_priority=None)
def execute(args: typing.NamedTuple):
@ -42,8 +42,8 @@ def execute(args: typing.NamedTuple):
cluster_conf.merge(file_config)
cluster_conf.merge(ClusterConfiguration(
cluster_id=args.cluster_id,
vm_count=args.size,
vm_low_pri_count=args.size_low_pri,
size=args.size,
size_low_priority=args.size_low_priority,
vm_size=args.vm_size,
subnet_id=args.subnet_id,
user_configuration=UserConfiguration(
@ -71,6 +71,7 @@ def execute(args: typing.NamedTuple):
else:
cluster_conf.user_configuration = None
cluster_conf.validate()
utils.print_cluster_conf(cluster_conf, wait)
with utils.Spinner():
# create spark cluster

Просмотреть файл

@ -3,6 +3,7 @@ import getpass
import sys
import threading
import time
import yaml
from subprocess import call
from typing import List
import azure.batch.models as batch_models
@ -420,14 +421,12 @@ def utc_to_local(utc_dt):
def print_cluster_conf(cluster_conf: ClusterConfiguration, wait: bool):
user_configuration = cluster_conf.user_configuration
log.info("-------------------------------------------")
log.info("cluster id: %s", cluster_conf.cluster_id)
log.info("cluster toolkit: %s %s", cluster_conf.toolkit.software, cluster_conf.toolkit.version)
log.info("cluster size: %s",
cluster_conf.vm_count + cluster_conf.vm_low_pri_count)
log.info("> dedicated: %s", cluster_conf.vm_count)
log.info("> low priority: %s", cluster_conf.vm_low_pri_count)
log.info("cluster toolkit: %s %s", cluster_conf.toolkit.software, cluster_conf.toolkit.version)
log.info("cluster size: %s", cluster_conf.size + cluster_conf.size_low_priority)
log.info("> dedicated: %s", cluster_conf.size)
log.info("> low priority: %s", cluster_conf.size_low_priority)
log.info("cluster vm size: %s", cluster_conf.vm_size)
log.info("custom scripts: %s", len(cluster_conf.custom_scripts) if cluster_conf.custom_scripts else 0)
log.info("subnet ID: %s", cluster_conf.subnet_id)

Просмотреть файл

@ -47,12 +47,11 @@ To finish setting up, you need to fill out your Azure Batch and Azure Storage se
Please note that if you use ssh keys and a have a non-standard ssh key file name or path, you will need to specify the location of your ssh public and private keys. To do so, set them as shown below:
```yaml
default:
# SSH keys used to create a user and connect to a server.
# The public key can either be the public key itself(ssh-rsa ...) or the path to the ssh key.
# The private key must be the path to the key.
ssh_pub_key: ~/.ssh/my-public-key.pub
ssh_priv_key: ~/.ssh/my-private-key
# SSH keys used to create a user and connect to a server.
# The public key can either be the public key itself(ssh-rsa ...) or the path to the ssh key.
# The private key must be the path to the key.
ssh_pub_key: ~/.ssh/my-public-key.pub
ssh_priv_key: ~/.ssh/my-private-key
```
0. Log into Azure

Просмотреть файл

@ -29,7 +29,7 @@ vm_size: standard_a2
# size: <number of dedicated nodes in the cluster, note that clusters must contain all dedicated or all low priority nodes>
size: 2
# size_low_pri: <number of low priority nodes in the cluster, mutually exclusive with size setting>
# size_low_priority: <number of low priority nodes in the cluster, mutually exclusive with size setting>
# username: <username for the linux user to be created> (optional)
username: spark

Просмотреть файл

@ -56,7 +56,7 @@ Jobs also require a definition of the cluster on which the Applications will run
```
_Please Note: For more information about Azure VM sizes, see [Azure Batch Pricing](https://azure.microsoft.com/en-us/pricing/details/batch/). And for more information about Docker repositories see [Docker](./12-docker-iamge.html)._
_The only required fields are vm_size and either size or size_low_pri, all other fields can be left blank or removed._
_The only required fields are vm_size and either size or size_low_priority, all other fields can be left blank or removed._
A Job definition may also include a default Spark Configuration. The following are the properties to define a Spark Configuration:
```yaml

Просмотреть файл

@ -1,10 +1,8 @@
aztk.models package
===================
aztk.models.models module
-------------------------
.. automodule:: aztk.models.models
.. automodule:: aztk.models
:members:
:undoc-members:
:show-inheritance:
:imported-members:

Просмотреть файл

@ -6,57 +6,50 @@ In `aztk/models` create a new file with the name of your model `my_model.py`
In `aztk/models/__init__.py` add `from .my_model import MyModel`
Create a new class `MyModel` that inherit `ConfigurationBase`
Create a new class `MyModel` that inherits `Model`
```python
from aztk.internal import ConfigurationBase
from aztk.core.models import Model, fields
class MyModel(ConfigurationBase):
class MyModel(Model):
"""
    MyModel is a sample model
Args:
input1 (str): This is the first input
"""
def __init__(self, input1: str):
self.input1 = input1
def validate(self):
input1 = fields.String()
def __validate__(self):
pass
```
### Available fields types
Check `aztk/core/models/fields.py` for the sources
* `Field`: Base field class
* `String`: Field that validates it is given a string
* `Integer`: Field that validates it is given an int
* `Float`: Field that validates it is given a float
* `Boolean`: Field that validates it is given a boolean
* `List`: Field that validates it is given a list and can also automatically convert entries to the given model type.
* `Model`: Field that maps to another model. If passed a dict it will automatically try to convert to the Model type
* `Enum`: Field whose value should be an enum. It will convert automatically to the enum if given the value.
## Add validation
The fields provide basic validation automatically. A field without a default will be marked as required.
In `def validate` do any kind of checks and raise a `InvalidModelError` if there is any problems with the values
### Validate required
To validate required attributes call the parent `_validate_required` method. Method takes a list of attributes which should not be None
To provide model-wide validation, implement a `__validate__` method and raise an `InvalidModelError` if there are any problems with the values
```python
def validate(self) -> bool:
self._validate_required(["input1"])
```
### Custom validation
```python
def validate(self) -> bool:
if "foo" in self.input1:
raise InvalidModelError("foo cannot be in input1")
def __validate__(self):
if 'secret' in self.input1:
raise InvalidModelError("Input1 contains secrets")
```
## Convert dict to model
When inheriting from `ConfigurationBase` it comes with a `from_dict` class method which allows to convert a dict to this class
It works great for simple case where values are simple types(str, int, etc). If however you need to process it you can override the `_from_dict` method.
** Important: Do not override the `from_dict` method as this one will handle error and display them nicely **
```python
@classmethod
def _from_dict(cls, args: dict):
if "input1" in args:
args["input1"] = MyInput1Model.from_dict(args["input1"])
return super()._from_dict(args)
```
When inheriting from `Model` it comes with a `from_dict` class method which allows to convert a dict to this class

Просмотреть файл

@ -192,7 +192,7 @@ max-nested-blocks=5
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
max-line-length=140
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

316
tests/core/test_models.py Normal file
Просмотреть файл

@ -0,0 +1,316 @@
from enum import Enum
import yaml
import pytest
from aztk.core.models import Model, fields, ModelMergeStrategy, ListMergeStrategy
from aztk.error import InvalidModelFieldError, AztkAttributeError
# pylint: disable=C1801
class UserState(Enum):
    # Lifecycle states used to exercise fields.Enum in the tests below.
    Creating = "creating"
    Ready = "ready"
    Deleting = "deleting"
class UserInfo(Model):
    # Simple nested model used as a field target throughout the tests.
    name = fields.String()
    age = fields.Integer()
class User(Model):
    # Exercises nested-model, boolean-with-default and enum-with-default fields.
    info = fields.Model(UserInfo)
    enabled = fields.Boolean(default=True)
    state = fields.Enum(UserState, default=UserState.Ready)
def test_models():
    """Building a fully-specified User keeps every field value unchanged."""
    user = User(
        info=UserInfo(
            name="Highlander",
            age=800,
        ),
        enabled=False,
        state=UserState.Creating,
    )
    assert user.info.name == "Highlander"
    assert user.info.age == 800
    assert user.enabled is False
    assert user.state == UserState.Creating
def test_inherited_models():
    """Fields declared on a subclass coexist with the parent's fields."""
    class ServiceUser(User):
        service = fields.String()
    user = ServiceUser(
        info=dict(
            name="Bob",
            age=59,
        ),
        enabled=False,
        service="bus",
    )
    user.validate()
    assert user.info.name == "Bob"
    assert user.info.age == 59
    assert user.enabled is False
    assert user.state == UserState.Ready    # default inherited from User
    assert user.service == "bus"
def test_raise_error_if_extra_parameters():
    """Unknown constructor kwargs are rejected with AztkAttributeError."""
    class SimpleNameModel(Model):
        name = fields.String()
    with pytest.raises(AztkAttributeError, match="SimpleNameModel doesn't have an attribute called abc"):
        SimpleNameModel(name="foo", abc="123")
def test_enum_invalid_type_raise_error():
    """A value outside the enum fails at validate(), not at construction."""
    class SimpleStateModel(Model):
        state = fields.Enum(UserState)
    with pytest.raises(
            InvalidModelFieldError,
            match="SimpleStateModel state unknown is not a valid option. Use one of \\['creating', 'ready', 'deleting'\\]"):
        obj = SimpleStateModel(state="unknown")
        obj.validate()
def test_enum_parse_string():
    """Enum fields accept the enum's string value and convert it."""
    class SimpleStateModel(Model):
        state = fields.Enum(UserState)
    obj = SimpleStateModel(state="creating")
    obj.validate()
    assert obj.state == UserState.Creating
def test_convert_nested_dict_to_model():
    """Dicts (and enum value strings) are converted to the declared field types."""
    user = User(
        info=dict(
            name="Highlander",
            age=800,
        ),
        enabled=False,
        state="deleting",
    )
    assert isinstance(user.info, UserInfo)
    assert user.info.name == "Highlander"
    assert user.info.age == 800
    assert user.enabled is False
    assert user.state == UserState.Deleting
def test_raise_error_if_missing_required_field():
    """A field with no default is required; validate() flags it when unset."""
    class SimpleRequiredModel(Model):
        name = fields.String()
    missing = SimpleRequiredModel()
    with pytest.raises(InvalidModelFieldError, match="SimpleRequiredModel name is required"):
        missing.validate()
def test_raise_error_if_string_field_invalid_type():
    """String fields reject non-string values at validate()."""
    class SimpleStringModel(Model):
        name = fields.String()
    missing = SimpleStringModel(name=123)
    with pytest.raises(InvalidModelFieldError, match="SimpleStringModel name 123 should be a string"):
        missing.validate()
def test_raise_error_if_int_field_invalid_type():
    """Integer fields reject non-int values (no implicit str parsing)."""
    class SimpleIntegerModel(Model):
        age = fields.Integer()
    missing = SimpleIntegerModel(age='123')
    with pytest.raises(InvalidModelFieldError, match="SimpleIntegerModel age 123 should be an integer"):
        missing.validate()
def test_raise_error_if_bool_field_invalid_type():
    """Boolean fields reject non-bool values (no "false" string coercion)."""
    class SimpleBoolModel(Model):
        enabled = fields.Boolean()
    missing = SimpleBoolModel(enabled="false")
    with pytest.raises(InvalidModelFieldError, match="SimpleBoolModel enabled false should be a boolean"):
        missing.validate()
def test_merge_with_default_value():
    """merge() only overrides values the other instance set explicitly."""
    class SimpleMergeModel(Model):
        name = fields.String()
        enabled = fields.Boolean(default=True)
    record1 = SimpleMergeModel(enabled=False)
    assert record1.enabled is False
    record2 = SimpleMergeModel(name="foo")
    assert record2.enabled is True
    record1.merge(record2)
    assert record1.name == 'foo'
    # record2's enabled was only the default, so record1's explicit False wins.
    assert record1.enabled is False
def test_merge_nested_model_merge_strategy():
    """ModelMergeStrategy.Merge combines nested models field by field."""
    class ComplexModel(Model):
        model_id = fields.String()
        info = fields.Model(UserInfo, merge_strategy=ModelMergeStrategy.Merge)
    obj1 = ComplexModel(
        info=dict(
            name="John",
            age=29,
        )
    )
    obj2 = ComplexModel(
        info=dict(
            age=38,
        )
    )
    assert obj1.info.age == 29
    assert obj2.info.age == 38
    obj1.merge(obj2)
    # Merge keeps obj1's name (obj2 didn't set one) but takes obj2's age.
    assert obj1.info.name == "John"
    assert obj1.info.age == 38
def test_merge_nested_model_override_strategy():
    """ModelMergeStrategy.Override replaces the nested model wholesale."""
    class ComplexModel(Model):
        model_id = fields.String()
        info = fields.Model(UserInfo, merge_strategy=ModelMergeStrategy.Override)
    obj1 = ComplexModel(
        info=dict(
            name="John",
            age=29,
        )
    )
    obj2 = ComplexModel(
        info=dict(
            age=38,
        )
    )
    assert obj1.info.age == 29
    assert obj2.info.age == 38
    obj1.merge(obj2)
    # Override drops obj1's name entirely: obj2's info wins as a whole.
    assert obj1.info.name is None
    assert obj1.info.age == 38
def test_list_field_convert_model_correctly():
    """Plain dicts inside a List field are converted into model instances."""
    class UserList(Model):
        infos = fields.List(UserInfo)

    user_list = UserList(infos=[{"name": "John", "age": 29}])
    user_list.validate()

    assert len(user_list.infos) == 1
    first = user_list.infos[0]
    assert isinstance(first, UserInfo)
    assert first.name == "John"
    assert first.age == 29
def test_list_field_is_never_required():
    """An unset (or None) List field defaults to an empty, mutable list."""
    class UserList(Model):
        infos = fields.List(UserInfo)

    user_list = UserList()
    user_list.validate()
    assert isinstance(user_list.infos, list)
    assert len(user_list.infos) == 0

    # The returned list is the live backing store, not a copy:
    # mutating it is reflected on the model.
    alias = user_list.infos
    alias.append(UserInfo())
    assert len(user_list.infos) == 1

    # Passing None explicitly also normalizes to an empty list.
    other = UserList(infos=None)
    assert isinstance(other.infos, list)
    assert len(other.infos) == 0
def test_list_field_ignore_none_entries():
    """None entries passed into a List field are silently dropped."""
    class UserList(Model):
        infos = fields.List(UserInfo)

    user_list = UserList(infos=[None, None])
    user_list.validate()
    assert isinstance(user_list.infos, list)
    assert len(user_list.infos) == 0
def test_merge_nested_model_append_strategy():
    """With ListMergeStrategy.Append, merging concatenates the two lists in order."""
    class UserList(Model):
        infos = fields.List(UserInfo, merge_strategy=ListMergeStrategy.Append)

    base = UserList(infos=[{"name": "John", "age": 29}])
    other = UserList(infos=[{"name": "Frank", "age": 38}])

    assert len(base.infos) == 1
    assert len(other.infos) == 1
    assert base.infos[0].name == "John"
    assert base.infos[0].age == 29
    assert other.infos[0].name == "Frank"
    assert other.infos[0].age == 38

    base.merge(other)

    # base keeps its own entry first, with other's entries appended after.
    assert len(base.infos) == 2
    assert base.infos[0].name == "John"
    assert base.infos[0].age == 29
    assert base.infos[1].name == "Frank"
    assert base.infos[1].age == 38
def test_serialize_simple_model_to_yaml():
    """A flat model round-trips through YAML dump/load."""
    info = UserInfo(name="John", age=29)

    dumped = yaml.dump(info)
    assert dumped == "!!python/object:test_models.UserInfo {age: 29, name: John}\n"

    # NOTE(review): yaml.load without an explicit Loader is deprecated in
    # newer PyYAML and constructs arbitrary Python objects — acceptable in a
    # test round-trip, but verify the pinned PyYAML version still allows it.
    restored = yaml.load(dumped)
    assert isinstance(restored, UserInfo)
    assert restored.name == "John"
    assert restored.age == 29
def test_serialize_nested_model_to_yaml():
    """A model with a nested model and an enum round-trips through YAML."""
    user = User(
        info={"name": "John", "age": 29},
        enabled=True,
        state=UserState.Deleting,
    )

    dumped = yaml.dump(user)
    assert dumped == "!!python/object:test_models.User\nenabled: true\ninfo: {age: 29, name: John}\nstate: deleting\n"

    restored = yaml.load(dumped)
    assert isinstance(restored, User)
    assert isinstance(restored.info, UserInfo)
    assert restored.info.name == "John"
    assert restored.info.age == 29
    assert restored.state == UserState.Deleting
    assert restored.enabled is True

Просмотреть файл

@ -8,9 +8,8 @@ dir_path = os.path.dirname(os.path.realpath(__file__))
fake_plugin_dir = os.path.join(dir_path, "fake_plugins")
class RequiredArgPlugin(PluginConfiguration):
def __init__(self, req_arg):
super().__init__(name="required-arg")
def RequiredArgPlugin(req_arg):
return PluginConfiguration(name="required-arg")
def test_missing_plugin():

Просмотреть файл

@ -1,6 +1,6 @@
import pytest
from aztk.error import AztkError
from aztk.error import AztkError, AztkAttributeError
from aztk.models.plugins.internal import PluginReference, PluginTarget, PluginTargetRole
@ -19,7 +19,7 @@ def test_from_dict():
def test_from_dict_invalid_param():
with pytest.raises(AztkError):
with pytest.raises(AztkAttributeError):
PluginReference.from_dict(dict(
name2="invalid"
))

Просмотреть файл

@ -0,0 +1,61 @@
import pytest
from aztk.models import ClusterConfiguration, Toolkit, UserConfiguration
from aztk.spark.models.plugins import JupyterPlugin, HDFSPlugin
def test_vm_count_deprecated():
    """Legacy vm_count / vm_low_pri_count kwargs warn but still map to the new size fields."""
    with pytest.warns(DeprecationWarning):
        legacy = ClusterConfiguration(vm_count=3)
    assert legacy.size == 3

    with pytest.warns(DeprecationWarning):
        legacy_low_pri = ClusterConfiguration(vm_low_pri_count=10)
    assert legacy_low_pri.size_low_priority == 10
def test_size_none():
    """size=None is normalized to 0 while size_low_priority keeps its value."""
    config = ClusterConfiguration(size=None, size_low_priority=2)
    assert config.size == 0
    assert config.size_low_priority == 2
def test_size_low_priority_none():
    """size_low_priority=None is normalized to 0 while size keeps its value."""
    config = ClusterConfiguration(size=1, size_low_priority=None)
    assert config.size == 1
    assert config.size_low_priority == 0
def test_cluster_configuration():
    """from_dict builds a full ClusterConfiguration, converting nested dicts to models."""
    subnet = '/subscriptions/21abd678-18c5-4660-9fdd-8c5ba6b6fe1f/resourceGroups/abc/providers/Microsoft.Network/virtualNetworks/prodtest5vnet/subnets/default'
    raw = {
        'toolkit': {
            'software': 'spark',
            'version': '2.3.0',
            'environment': 'anaconda'
        },
        'vm_size': 'standard_a2',
        'size': 2,
        'size_low_priority': 3,
        'subnet_id': subnet,
        'plugins': [
            JupyterPlugin(),
            HDFSPlugin(),
        ],
        'user_configuration': {'username': 'spark'}
    }

    config = ClusterConfiguration.from_dict(raw)

    # Nested toolkit dict is converted into a Toolkit model.
    assert isinstance(config.toolkit, Toolkit)
    assert config.toolkit.software == 'spark'
    assert config.toolkit.version == '2.3.0'
    assert config.toolkit.environment == 'anaconda'

    # Scalar fields pass straight through.
    assert config.size == 2
    assert config.size_low_priority == 3
    assert config.vm_size == 'standard_a2'
    assert config.subnet_id == subnet

    # Nested user configuration dict is converted into a UserConfiguration model.
    assert isinstance(config.user_configuration, UserConfiguration)
    assert config.user_configuration.username == 'spark'

    # Already-instantiated plugin objects are kept as-is, in order.
    assert len(config.plugins) == 2
    assert config.plugins[0].name == 'jupyter'
    assert config.plugins[1].name == 'hdfs'