Mirror of https://github.com/Azure/aztk.git
Feature: New Models design with auto validation, default and merging (#543)
This commit is contained in:
Parent
f6735cc6dd
Commit
02f336b0a0
@@ -49,3 +49,6 @@ tmp/
# Built docs
docs/_build/

# PyCharm
.idea/
@@ -3,4 +3,4 @@ based_on_style = pep8
spaces_before_comment = 4
split_before_logical_operator = true
indent_width = 4
column_limit = 120
column_limit = 140
@@ -16,5 +16,6 @@
        "--style=.style.yapf"
    ],
    "python.venvPath": "${workspaceFolder}/.venv/",
    "python.pythonPath": "${workspaceFolder}/.venv/Scripts/python.exe"
    "python.pythonPath": "${workspaceFolder}/.venv/Scripts/python.exe",
    "python.unitTest.pyTestEnabled": true
}
@@ -90,7 +90,7 @@ class Client:
        network_conf = batch_models.NetworkConfiguration(
            subnet_id=cluster_conf.subnet_id)
        auto_scale_formula = "$TargetDedicatedNodes={0}; $TargetLowPriorityNodes={1}".format(
            cluster_conf.vm_count, cluster_conf.vm_low_pri_count)
            cluster_conf.size, cluster_conf.size_low_priority)

        # Configure the pool
        pool = batch_models.PoolAddParameter(
@@ -110,7 +110,7 @@ class Client:
                batch_models.MetadataItem(
                    name=constants.AZTK_SOFTWARE_METADATA_KEY, value=software_metadata_key),
                batch_models.MetadataItem(
                    name=constants.AZTK_MODE_METADATA_KEY, value=constants.AZTK_CLUSTER_MODE_METADATA)
                    name=constants.AZTK_MODE_METADATA_KEY, value=constants.AZTK_CLUSTER_MODE_METADATA)
            ])

        # Create the pool + create user for the pool
@@ -0,0 +1,2 @@
from .model import Model
from .fields import String, Integer, Boolean, Float, List, ModelMergeStrategy, ListMergeStrategy
@@ -0,0 +1,241 @@
import collections
import enum

from aztk.error import InvalidModelFieldError
from . import validators as aztk_validators

class ModelMergeStrategy(enum.Enum):
    Override = 1
    """
    Override the value with the other value
    """
    Merge = 2
    """
    Try to merge the nested values
    """

class ListMergeStrategy(enum.Enum):
    Replace = 1
    """
    Override the value with the other value
    """
    Append = 2
    """
    Append all the values of the new list
    """

# pylint: disable=W0212
class Field:
    """
    Base class for all model fields
    """
    def __init__(self, *validators, **kwargs):
        self.default = kwargs.get('default')
        self.required = 'default' not in kwargs
        self.validators = []

        if self.required:
            self.validators.append(aztk_validators.Required())

        self.validators.extend(validators)

        choices = kwargs.get('choices')
        if choices:
            self.validators.append(aztk_validators.In(choices))

    def validate(self, value):
        for validator in self.validators:
            validator(value)

    def __get__(self, instance, owner):
        if instance is not None:
            value = instance._data.get(self)
            if value is None:
                return instance._defaults.setdefault(self, self._default(instance))
            return value

        return self

    def __set__(self, instance, value):
        instance._data[self] = value

    def merge(self, instance, value):
        """
        Method called when merging 2 models together.
        This is overridden in some of the fields where merge can be handled differently
        """
        if value is not None:
            instance._data[self] = value

    def serialize(self, instance):
        return self.__get__(instance, None)

    def _default(self, model):
        if callable(self.default):
            return self.__call_default(model)

        return self.default

    def __call_default(self, *args):
        try:
            return self.default()
        except TypeError as error:
            try:
                return self.default(*args)
            except TypeError:
                raise error


class String(Field):
    """
    Model String field
    """

    def __init__(self, *args, **kwargs):
        super().__init__(aztk_validators.String(), *args, **kwargs)


class Integer(Field):
    """
    Model Integer field
    """
    def __init__(self, *args, **kwargs):
        super().__init__(aztk_validators.Integer(), *args, **kwargs)


class Float(Field):
    """
    Model Float field
    """

    def __init__(self, *args, **kwargs):
        super().__init__(aztk_validators.Float(), *args, **kwargs)


class Boolean(Field):
    """
    Model Boolean field
    """

    def __init__(self, *args, **kwargs):
        super().__init__(aztk_validators.Boolean(), *args, **kwargs)


class List(Field):
    """
    Field that should be a list
    """

    def __init__(self, model=None, **kwargs):
        self.model = model
        kwargs.setdefault('default', list)
        self.merge_strategy = kwargs.get('merge_strategy', ListMergeStrategy.Append)
        self.skip_none = kwargs.get('skip_none', True)

        super().__init__(
            aztk_validators.List(*kwargs.get('inner_validators', [])), **kwargs)

    def __set__(self, instance, value):
        if isinstance(value, collections.MutableSequence):
            value = self._resolve(value)
        if value is None:
            value = []
        super().__set__(instance, value)

    def _resolve(self, value):
        result = []
        for item in value:
            if item is None and self.skip_none:  # Skip none values
                continue

            if self.model and isinstance(item, collections.MutableMapping):
                item = self.model(**item)
            result.append(item)
        return result

    def merge(self, instance, value):
        if value is None:
            value = []

        if self.merge_strategy == ListMergeStrategy.Append:
            current = instance._data.get(self)
            if current is None:
                current = []
            value = current + value

        instance._data[self] = value

    def serialize(self, instance):
        items = super().serialize(instance)
        output = []
        if items is not None:
            for item in items:
                if hasattr(item, 'to_dict'):
                    output.append(item.to_dict())
                else:
                    output.append(item)
        return output


class Model(Field):
    """
    Field that is another model

    Args:
        model (aztk.core.models.Model): Model type that the field should be
        merge_strategy (ModelMergeStrategy): How the nested model should be merged when merging models.
            Default: `ModelMergeStrategy.Merge`
    """

    def __init__(self, model, *args, **kwargs):
        super().__init__(aztk_validators.Model(model), *args, **kwargs)

        self.model = model
        self.merge_strategy = kwargs.get('merge_strategy', ModelMergeStrategy.Merge)

    def __set__(self, instance, value):
        if isinstance(value, collections.MutableMapping):
            value = self.model(**value)

        super().__set__(instance, value)

    def merge(self, instance, value):
        if self.merge_strategy == ModelMergeStrategy.Merge:
            current = instance._data.get(self)
            if current is not None:
                current.merge(value)
                value = current

        instance._data[self] = value

    def serialize(self, instance):
        val = super().serialize(instance)
        if val is not None:
            return val.to_dict()
        else:
            return None

class Enum(Field):
    """
    Field that should be an enum
    """
    def __init__(self, model, *args, **kwargs):
        super().__init__(aztk_validators.InstanceOf(model), *args, **kwargs)

        self.model = model

    def __set__(self, instance, value):
        if value is not None and not isinstance(value, self.model):
            try:
                value = self.model(value)
            except ValueError:
                available = [e.value for e in self.model]
                raise InvalidModelFieldError("{0} is not a valid option. Use one of {1}".format(value, available))
        super().__set__(instance, value)

    def serialize(self, instance):
        val = super().serialize(instance)
        if val is not None:
            return val.value
        else:
            return None
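
A note on the design above: each Field is a data descriptor that keys its value by the field object itself in the owning instance's `_data` dict, so defaults stay lazy and per-instance (callable defaults like `list` are invoked, not shared). A minimal standalone sketch of the same pattern (illustrative, not part of the commit):

class TinyField:
    """Data descriptor storing per-instance values keyed by the field object."""
    def __init__(self, default=None):
        self.default = default

    def __get__(self, instance, owner):
        if instance is None:
            return self
        value = instance._data.get(self)
        if value is None:
            # Like Field._default: call the default if it is callable
            return self.default() if callable(self.default) else self.default
        return value

    def __set__(self, instance, value):
        instance._data[self] = value

class Tiny:
    items = TinyField(default=list)
    def __init__(self):
        self._data = {}

t = Tiny()
assert t.items == []      # lazy default, fresh list each access (no cache here)
t.items = [1, 2]
assert t.items == [1, 2]
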
@@ -0,0 +1,123 @@
import yaml

from aztk.error import InvalidModelError, InvalidModelFieldError, AztkError, AztkAttributeError
from aztk.core.models import fields


# pylint: disable=W0212,no-member
class ModelMeta(type):
    """
    Model metaclass. This takes the class definition and builds the attributes from all the field definitions.
    """

    def __new__(mcs, name, bases, attrs):
        attrs['_fields'] = {}

        for base in bases:
            if hasattr(base, '_fields'):
                for k, v in base._fields.items():
                    attrs['_fields'][k] = v
            for k, v in base.__dict__.items():
                if isinstance(v, fields.Field):
                    attrs['_fields'][k] = v

        for k, v in attrs.items():
            if isinstance(v, fields.Field):
                attrs['_fields'][k] = v

        return super().__new__(mcs, name, bases, attrs)


class Model(metaclass=ModelMeta):
    """
    Base class for all aztk models

    To implement model-wide validation, implement the `__validate__` method
    """

    def __new__(cls, *_args, **_kwargs):
        model = super().__new__(cls)
        model._data = {}
        model._defaults = {}
        return model

    def __init__(self, **kwargs):
        self._update(kwargs)

    def __getitem__(self, k):
        if k not in self._fields:
            raise AztkAttributeError("{0} doesn't have an attribute called {1}".format(self.__class__.__name__, k))

        return getattr(self, k)

    def __setitem__(self, k, v):
        if k not in self._fields:
            raise AztkAttributeError("{0} doesn't have an attribute called {1}".format(self.__class__.__name__, k))
        try:
            setattr(self, k, v)
        except InvalidModelFieldError as e:
            self._process_field_error(e, k)

    def __getstate__(self):
        """
        For pickle serialization. This returns the state of the model
        """
        return self.to_dict()

    def __setstate__(self, state):
        """
        For pickle serialization. This updates the current model with the given state
        """
        self._update(state)

    def validate(self):
        """
        Validate the entire model
        """

        for name, field in self._fields.items():
            try:
                field.validate(getattr(self, name))
            except InvalidModelFieldError as e:
                self._process_field_error(e, name)
            except InvalidModelError as e:
                e.model = self
                raise e

        if hasattr(self, '__validate__'):
            self.__validate__()

    def merge(self, other):
        if not isinstance(other, self.__class__):
            raise AztkError("Cannot merge {0} as it is not an instance of {1}".format(other, self.__class__.__name__))

        for field in other._fields.values():
            if field in other._data:
                field.merge(self, other._data[field])

        return self

    @classmethod
    def from_dict(cls, val: dict):
        return cls(**val)

    def to_dict(self):
        output = dict()
        for name, field in self._fields.items():
            output[name] = field.serialize(self)
        return output

    def __str__(self):
        return yaml.dump(self.to_dict(), default_flow_style=False)

    def _update(self, values):
        for k, v in values.items():
            self[k] = v

    def _process_field_error(self, e: InvalidModelFieldError, field: str):
        if not e.field:
            e.field = field

        if not e.model:
            e.model = self
        raise e
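
Together, ModelMeta and the field descriptors give declarative models with validation, defaults, merging, and dict round-tripping. A hedged usage sketch (the Endpoint class here is hypothetical, not from the commit):

from aztk.core.models import Model, fields

class Endpoint(Model):
    host = fields.String()               # no default, so Required is added
    port = fields.Integer(default=8080)

a = Endpoint(host="localhost")
a.validate()                  # passes: host is set, port falls back to 8080
print(a.to_dict())            # {'host': 'localhost', 'port': 8080}

b = Endpoint(port=9000)
a.merge(b)                    # merge only copies explicitly-set fields
assert a.port == 9000 and a.host == "localhost"
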
@@ -0,0 +1,149 @@
import collections

from aztk.error import InvalidModelFieldError


class Validator:
    """
    Base class for a validator.
    To write your own validator, extend this class and implement the validate method.
    To signal a failure, raise InvalidModelFieldError
    """
    def __call__(self, value):
        self.validate(value)

    def validate(self, value):
        raise NotImplementedError()


class Required(Validator):
    """
    Validate that the field value is not `None`
    """

    def validate(self, value):
        if value is None:
            raise InvalidModelFieldError('is required')


class String(Validator):
    """
    Validate that the value of the field is a `str`
    """

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, str):
            raise InvalidModelFieldError('{0} should be a string'.format(value))


class Integer(Validator):
    """
    Validate that the value of the field is an `int`
    """

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, int):
            raise InvalidModelFieldError('{0} should be an integer'.format(value))


class Float(Validator):
    """
    Validate that the value of the field is a `float`
    """

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, float):
            raise InvalidModelFieldError('{0} should be a float'.format(value))


class Boolean(Validator):
    """This validator forces field values to be an instance of `bool`."""

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, bool):
            raise InvalidModelFieldError('{0} should be a boolean'.format(value))


class In(Validator):
    """
    Validate that the field value is in the list of allowed choices
    """

    def __init__(self, choices):
        self.choices = choices

    def validate(self, value):
        if not value:
            return

        if value not in self.choices:
            raise InvalidModelFieldError('{0} should be in {1}'.format(value, self.choices))

class InstanceOf(Validator):
    """
    Check that the field is an instance of the given type
    """

    def __init__(self, cls):
        self.type = cls

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, self.type):
            raise InvalidModelFieldError(
                "should be an instance of '{}'".format(self.type.__name__))


class Model(Validator):
    """
    Validate that the field is a model
    """

    def __init__(self, model):
        self.model = model

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, self.model):
            raise InvalidModelFieldError(
                "should be an instance of '{}'".format(self.model.__name__))

        value.validate()


class List(Validator):
    """
    Validate that the given item is a list
    """

    def __init__(self, *validators):
        self.validators = validators

    def validate(self, value):
        if not value:
            return

        if not isinstance(value, collections.MutableSequence):
            raise InvalidModelFieldError('should be a list')

        for i in value:
            for validator in self.validators:
                validator(i)
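
Extending the rule set is just subclassing Validator and raising InvalidModelFieldError on failure. A hypothetical custom validator, not part of the commit:

class Positive(Validator):
    """Validate that a numeric field, when set, is strictly positive."""
    def validate(self, value):
        if value is None:
            return  # leave presence checks to the Required validator
        if value <= 0:
            raise InvalidModelFieldError('{0} should be positive'.format(value))

# e.g. in a model definition: size = fields.Integer(Positive(), default=1)
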
@@ -7,6 +7,10 @@ All errors should inherit from `AztkError`
class AztkError(Exception):
    pass


class AztkAttributeError(AztkError):
    pass

class ClusterNotReadyError(AztkError):
    pass
@@ -17,7 +21,15 @@ class InvalidPluginConfigurationError(AztkError):
    pass

class InvalidModelError(AztkError):
    pass
    def __init__(self, message: str, model=None):
        super().__init__()
        self.message = message
        self.model = model

    def __str__(self):
        model_name = self.model and self.model.__class__.__name__
        return "{model} {message}".format(model=model_name, message=self.message)


class MissingRequiredAttributeError(InvalidModelError):
    pass
@@ -27,3 +39,12 @@ class InvalidCustomScriptError(InvalidModelError):

class InvalidPluginReferenceError(InvalidModelError):
    pass

class InvalidModelFieldError(InvalidModelError):
    def __init__(self, message: str, model=None, field=None):
        super().__init__(message, model)
        self.field = field

    def __str__(self):
        model_name = self.model and self.model.__class__.__name__
        return "{model} {field} {message}".format(model=model_name, field=self.field, message=self.message)
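
The model and field back-references let callers render precise messages. A hedged sketch using the UserConfiguration model added later in this commit:

from aztk.models import UserConfiguration
from aztk.error import InvalidModelFieldError

try:
    UserConfiguration(username=123).validate()
except InvalidModelFieldError as e:
    # Model._process_field_error fills in e.field and e.model
    print(e)    # UserConfiguration username 123 should be a string
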
@@ -1,9 +1,13 @@
import io
import logging
import yaml

import azure.common
from .node_data import NodeData
import yaml

from aztk.models import ClusterConfiguration

from .blob_data import BlobData
from .node_data import NodeData


class ClusterData:
@@ -1,2 +1,18 @@
from .toolkit import Toolkit, TOOLKIT_MAP
from .models import *
from .cluster_configuration import ClusterConfiguration
from .custom_script import CustomScript
from .file_share import FileShare
from .toolkit import TOOLKIT_MAP, Toolkit
from .user_configuration import UserConfiguration
from .secrets_configuration import (
    SecretsConfiguration,
    ServicePrincipalConfiguration,
    SharedKeyConfiguration,
    DockerConfiguration,
)
from .file import File
from .remote_login import RemoteLogin
from .ssh_log import SSHLog
from .vm_image import VmImage
from .software import Software
from .cluster import Cluster
from .plugins import *
@@ -0,0 +1,23 @@
import azure.batch.models as batch_models

class Cluster:
    def __init__(self,
                 pool: batch_models.CloudPool,
                 nodes: batch_models.ComputeNodePaged = None):
        self.id = pool.id
        self.pool = pool
        self.nodes = nodes
        self.vm_size = pool.vm_size
        if pool.state.value is batch_models.PoolState.active:
            self.visible_state = pool.allocation_state.value
        else:
            self.visible_state = pool.state.value
        self.total_current_nodes = pool.current_dedicated_nodes + \
            pool.current_low_priority_nodes
        self.total_target_nodes = pool.target_dedicated_nodes + \
            pool.target_low_priority_nodes
        self.current_dedicated_nodes = pool.current_dedicated_nodes
        self.current_low_pri_nodes = pool.current_low_priority_nodes
        self.target_dedicated_nodes = pool.target_dedicated_nodes
        self.target_low_pri_nodes = pool.target_low_priority_nodes
@@ -0,0 +1,83 @@
import aztk.error as error
from aztk.core.models import Model, fields
from aztk.utils import deprecate, helpers

from .custom_script import CustomScript
from .file_share import FileShare
from .plugins import PluginConfiguration
from .toolkit import Toolkit
from .user_configuration import UserConfiguration


class ClusterConfiguration(Model):
    """
    Cluster configuration model

    Args:
        cluster_id (str): Id of the Aztk cluster
        toolkit (aztk.models.Toolkit): Toolkit to be used for this cluster
        size (int): Number of dedicated nodes for this cluster
        size_low_priority (int): Number of low-priority nodes for this cluster
        vm_size (str): Azure VM size to be used for each node
        subnet_id (str): Full resource id of the subnet to be used (required for mixed-mode clusters)
        plugins (List[aztk.models.plugins.PluginConfiguration]): List of plugins to be used
        file_shares (List[aztk.models.FileShare]): List of file shares to be used
        user_configuration (aztk.models.UserConfiguration): Configuration of the user to
            be created on the master node to ssh into.
    """

    cluster_id = fields.String()
    toolkit = fields.Model(Toolkit)
    size = fields.Integer(default=0)
    size_low_priority = fields.Integer(default=0)
    vm_size = fields.String()

    subnet_id = fields.String(default=None)
    plugins = fields.List(PluginConfiguration)
    custom_scripts = fields.List(CustomScript)
    file_shares = fields.List(FileShare)
    user_configuration = fields.Model(UserConfiguration, default=None)

    def __init__(self, *args, **kwargs):
        if 'vm_count' in kwargs:
            deprecate("vm_count is deprecated for ClusterConfiguration; please use size instead")
            kwargs['size'] = kwargs.pop('vm_count')

        if 'vm_low_pri_count' in kwargs:
            deprecate("vm_low_pri_count is deprecated for ClusterConfiguration; please use size_low_priority instead")
            kwargs['size_low_priority'] = kwargs.pop('vm_low_pri_count')

        super().__init__(*args, **kwargs)

    def mixed_mode(self) -> bool:
        """
        Return:
            Whether the pool is using mixed mode (both dedicated and low-priority nodes)
        """
        return self.size > 0 and self.size_low_priority > 0

    def gpu_enabled(self):
        return helpers.is_gpu_enabled(self.vm_size)

    def get_docker_repo(self):
        return self.toolkit.get_docker_repo(self.gpu_enabled())

    def __validate__(self) -> bool:
        if self.size == 0 and self.size_low_priority == 0:
            raise error.InvalidModelError(
                "Please supply a valid (greater than 0) size or size_low_priority value either in the cluster.yaml configuration file or with a parameter (--size or --size-low-pri)"
            )

        if self.vm_size is None:
            raise error.InvalidModelError(
                "Please supply a vm_size in either the cluster.yaml configuration file or with a parameter (--vm-size)"
            )

        if self.mixed_mode() and not self.subnet_id:
            raise error.InvalidModelError(
                "You must configure a VNET to use AZTK in mixed mode (dedicated and low priority nodes). Set the VNET's subnet_id in your cluster.yaml."
            )

        if self.custom_scripts:
            deprecate("Custom scripts are DEPRECATED and will be removed in 0.8.0. Use plugins instead. See https://aztk.readthedocs.io/en/v0.7.0/15-plugins.html")
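
The constructor shim keeps old vm_count-style callers working while __validate__ now runs as part of Model.validate(). A hedged sketch (assuming "2.3.0" is an accepted Spark toolkit version):

conf = ClusterConfiguration(
    cluster_id="demo",
    toolkit=Toolkit(software="spark", version="2.3.0"),
    vm_size="Standard_F2",
    vm_count=2,          # deprecation warning; remapped to size=2
)
assert conf.size == 2 and not conf.mixed_mode()
conf.validate()          # size > 0, vm_size set, not mixed mode: passes
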
@@ -0,0 +1,6 @@
from aztk.core.models import Model, fields

class CustomScript(Model):
    name = fields.String()
    script = fields.String()
    run_on = fields.String()
@@ -0,0 +1,6 @@
import io

class File:
    def __init__(self, name: str, payload: io.StringIO):
        self.name = name
        self.payload = payload
@@ -0,0 +1,7 @@
from aztk.core.models import Model, fields

class FileShare(Model):
    storage_account_name = fields.String()
    storage_account_key = fields.String()
    file_share_path = fields.String()
    mount_path = fields.String()
@@ -1,312 +0,0 @@
import io
from typing import List
import azure.batch.models as batch_models
from aztk import error
from aztk.utils import helpers, deprecate
from aztk.models.plugins import PluginConfiguration
from aztk.internal import ConfigurationBase
from .toolkit import Toolkit


class FileShare:
    def __init__(self,
                 storage_account_name: str = None,
                 storage_account_key: str = None,
                 file_share_path: str = None,
                 mount_path: str = None):
        self.storage_account_name = storage_account_name
        self.storage_account_key = storage_account_key
        self.file_share_path = file_share_path
        self.mount_path = mount_path

class File:
    def __init__(self, name: str, payload: io.StringIO):
        self.name = name
        self.payload = payload


class CustomScript:
    def __init__(self, name: str = None, script=None, run_on=None):
        self.name = name
        self.script = script
        self.run_on = run_on


class UserConfiguration(ConfigurationBase):
    def __init__(self,
                 username: str,
                 ssh_key: str = None,
                 password: str = None):
        self.username = username
        self.ssh_key = ssh_key
        self.password = password

    def merge(self, other):
        self._merge_attributes(other, [
            "username",
            "ssh_key",
            "password",
        ])

    def validate(self):
        pass


class ClusterConfiguration(ConfigurationBase):
    """
    Cluster configuration model

    Args:
        toolkit
    """

    def __init__(self,
                 toolkit: Toolkit = None,
                 custom_scripts: List[CustomScript] = None,
                 file_shares: List[FileShare] = None,
                 cluster_id: str = None,
                 vm_count=0,
                 vm_low_pri_count=0,
                 vm_size=None,
                 subnet_id=None,
                 plugins: List[PluginConfiguration] = None,
                 user_configuration: UserConfiguration = None):
        super().__init__()
        self.toolkit = toolkit
        self.custom_scripts = custom_scripts
        self.file_shares = file_shares
        self.cluster_id = cluster_id
        self.vm_count = vm_count
        self.vm_size = vm_size
        self.vm_low_pri_count = vm_low_pri_count
        self.subnet_id = subnet_id
        self.user_configuration = user_configuration
        self.plugins = plugins

    def merge(self, other):
        """
        Merge other cluster config into this one.
        :params other: ClusterConfiguration
        """

        self._merge_attributes(other, [
            "toolkit",
            "custom_scripts",
            "file_shares",
            "cluster_id",
            "vm_size",
            "subnet_id",
            "vm_count",
            "vm_low_pri_count",
            "plugins",
        ])

        if other.user_configuration:
            if self.user_configuration:
                self.user_configuration.merge(other.user_configuration)
            else:
                self.user_configuration = other.user_configuration

        if self.plugins:
            for plugin in self.plugins:
                plugin.validate()

    def mixed_mode(self) -> bool:
        return self.vm_count > 0 and self.vm_low_pri_count > 0


    def gpu_enabled(self):
        return helpers.is_gpu_enabled(self.vm_size)

    def get_docker_repo(self):
        return self.toolkit.get_docker_repo(self.gpu_enabled())

    def validate(self) -> bool:
        """
        Validate the config at its current state.
        Raises: Error if invalid
        """
        if self.toolkit is None:
            raise error.InvalidModelError(
                "Please supply a toolkit for the cluster")

        self.toolkit.validate()

        if self.cluster_id is None:
            raise error.AztkError(
                "Please supply an id for the cluster with a parameter (--id)")

        if self.vm_count == 0 and self.vm_low_pri_count == 0:
            raise error.AztkError(
                "Please supply a valid (greater than 0) size or size_low_pri value either in the cluster.yaml configuration file or with a parameter (--size or --size-low-pri)"
            )

        if self.vm_size is None:
            raise error.AztkError(
                "Please supply a vm_size in either the cluster.yaml configuration file or with a parameter (--vm-size)"
            )

        if self.mixed_mode() and not self.subnet_id:
            raise error.AztkError(
                "You must configure a VNET to use AZTK in mixed mode (dedicated and low priority nodes). Set the VNET's subnet_id in your cluster.yaml."
            )

        if self.custom_scripts:
            deprecate("Custom scripts are DEPRECATED and will be removed in 0.8.0. Use plugins instead. See https://aztk.readthedocs.io/en/v0.7.0/15-plugins.html")


class RemoteLogin:
    def __init__(self, ip_address, port):
        self.ip_address = ip_address
        self.port = port


class ServicePrincipalConfiguration(ConfigurationBase):
    """
    Container class for AAD authentication
    """

    def __init__(self,
                 tenant_id: str = None,
                 client_id: str = None,
                 credential: str = None,
                 batch_account_resource_id: str = None,
                 storage_account_resource_id: str = None):
        self.tenant_id = tenant_id
        self.client_id = client_id
        self.credential = credential
        self.batch_account_resource_id = batch_account_resource_id
        self.storage_account_resource_id = storage_account_resource_id

    def validate(self) -> bool:
        """
        Validate the config at its current state.
        Raises: Error if invalid
        """
        self._validate_required([
            "tenant_id",
            "client_id",
            "credential",
            "batch_account_resource_id",
            "storage_account_resource_id",
        ])


class SharedKeyConfiguration(ConfigurationBase):
    """
    Container class for shared key authentication
    """

    def __init__(self,
                 batch_account_name: str = None,
                 batch_account_key: str = None,
                 batch_service_url: str = None,
                 storage_account_name: str = None,
                 storage_account_key: str = None,
                 storage_account_suffix: str = None):

        self.batch_account_name = batch_account_name
        self.batch_account_key = batch_account_key
        self.batch_service_url = batch_service_url
        self.storage_account_name = storage_account_name
        self.storage_account_key = storage_account_key
        self.storage_account_suffix = storage_account_suffix

    def validate(self) -> bool:
        """
        Validate the config at its current state.
        Raises: Error if invalid
        """
        self._validate_required([
            "batch_account_name",
            "batch_account_key",
            "batch_service_url",
            "storage_account_name",
            "storage_account_key",
            "storage_account_suffix",
        ])


class DockerConfiguration(ConfigurationBase):
    def __init__(self, endpoint=None, username=None, password=None):

        self.endpoint = endpoint
        self.username = username
        self.password = password

    def validate(self):
        pass


class SecretsConfiguration(ConfigurationBase):
    def __init__(self,
                 service_principal=None,
                 shared_key=None,
                 docker=None,
                 ssh_pub_key=None,
                 ssh_priv_key=None):
        self.service_principal = service_principal
        self.shared_key = shared_key
        self.docker = docker

        self.ssh_pub_key = ssh_pub_key
        self.ssh_priv_key = ssh_priv_key

    def validate(self):
        if self.service_principal and self.shared_key:
            raise error.AztkError(
                "Both service_principal and shared_key auth are configured; use only one"
            )
        elif self.service_principal:
            self.service_principal.validate()
        elif self.shared_key:
            self.shared_key.validate()
        else:
            raise error.AztkError(
                "Neither service_principal nor shared_key auth is configured; configure one"
            )

    def is_aad(self):
        return self.service_principal is not None


class VmImage:
    def __init__(self, publisher, offer, sku):
        self.publisher = publisher
        self.offer = offer
        self.sku = sku


class Cluster:
    def __init__(self,
                 pool: batch_models.CloudPool,
                 nodes: batch_models.ComputeNodePaged = None):
        self.id = pool.id
        self.pool = pool
        self.nodes = nodes
        self.vm_size = pool.vm_size
        if pool.state.value is batch_models.PoolState.active:
            self.visible_state = pool.allocation_state.value
        else:
            self.visible_state = pool.state.value
        self.total_current_nodes = pool.current_dedicated_nodes + \
            pool.current_low_priority_nodes
        self.total_target_nodes = pool.target_dedicated_nodes + \
            pool.target_low_priority_nodes
        self.current_dedicated_nodes = pool.current_dedicated_nodes
        self.current_low_pri_nodes = pool.current_low_priority_nodes
        self.target_dedicated_nodes = pool.target_dedicated_nodes
        self.target_low_pri_nodes = pool.target_low_priority_nodes


class SSHLog():
    def __init__(self, output, node_id):
        self.output = output
        self.node_id = node_id


class Software:
    """
    Enum with the list of available software
    """
    spark = "spark"
@@ -46,9 +46,9 @@ class PluginManager:
    def get_args_for(self, cls):
        signature = inspect.signature(cls)
        args = dict()

        for k, v in signature.parameters.items():
            args[k] = PluginArgument(k, default=v.default, required=v.default is inspect.Parameter.empty)
        for key, param in signature.parameters.items():
            if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.KEYWORD_ONLY:
                args[key] = PluginArgument(key, default=param.default, required=param.default is inspect.Parameter.empty)

        return args
@@ -1,14 +1,14 @@
import os

from aztk.error import InvalidModelError
from aztk.internal import ConfigurationBase
from aztk.core.models import Model, fields
from aztk.models import PluginConfiguration
from aztk.models.plugins import PluginFile, PluginTarget, PluginTargetRole

from .plugin_manager import plugin_manager


class PluginReference(ConfigurationBase):
class PluginReference(Model):
    """
    Contains the configuration to use a plugin

@@ -20,27 +20,12 @@ class PluginReference(Model):
        target_role (PluginTargetRole): Target role, defaults to all nodes. This can only be used if providing a script
        args (dict): If using name, these are the arguments to pass to the plugin
    """
    def __init__(self,
                 name: str = None,
                 script: str = None,
                 target: PluginTarget = None,
                 target_role: PluginTargetRole = None,
                 args: dict = None):
        super().__init__()
        self.name = name
        self.script = script
        self.target = target
        self.target_role = target_role
        self.args = args or dict()

    @classmethod
    def _from_dict(cls, args: dict):
        if "target" in args:
            args["target"] = PluginTarget(args["target"])
        if "target_role" in args:
            args["target_role"] = PluginTargetRole(args["target_role"])

        return super()._from_dict(args)
    name = fields.String(default=None)
    script = fields.String(default=None)
    target = fields.Enum(PluginTarget, default=None)
    target_role = fields.Enum(PluginTargetRole, default=None)
    args = fields.String(default=None)

    def get_plugin(self) -> PluginConfiguration:
        self.validate()
@@ -50,7 +35,7 @@ class PluginReference(Model):

        return plugin_manager.get_plugin(self.name, self.args)

    def validate(self) -> bool:
    def __validate__(self):
        if not self.name and not self.script:
            raise InvalidModelError("Plugin must either specify the name of an existing plugin or the path to a script.")
@@ -1,9 +1,10 @@
from enum import Enum
from typing import List, Union
from aztk.internal import ConfigurationBase
from aztk.error import InvalidPluginConfigurationError

from aztk.core.models import Model, fields

from .plugin_file import PluginFile


class PluginTarget(Enum):
    """
    Where this plugin should run
@@ -18,74 +19,51 @@ class PluginTargetRole(Enum):
    All = "all-nodes"


class PluginPort:
class PluginPort(Model):
    """
    Definition for a port that should be opened on a node
    :param internal: Port on the node
    :param public: [Optional] Port available to the user. If none, no port is opened to the user
    :param name: [Optional] Name to differentiate ports if you have multiple
    """
    internal = fields.Integer()
    public = fields.Field(default=None)
    name = fields.Integer()

    def __init__(self, internal: int, public: Union[int, bool]=False, name=None):
        self.internal = internal
        self.expose_publicly = bool(public)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expose_publicly = bool(self.public)
        self.public_port = None
        if self.expose_publicly:
            if public is True:
                self.public_port = internal
            if self.public is True:
                self.public_port = self.internal
            else:
                self.public_port = public

        self.name = name
                self.public_port = self.public


class PluginConfiguration(ConfigurationBase):
class PluginConfiguration(Model):
    """
    Plugin manifest that should be returned in the main.py of your plugin
    :param name: Name of the plugin. Used to reference the plugin
    :param runOn: Where the plugin should run
    :param files: List of files to upload
    :param args:
    :param env:
    """

    def __init__(self,
                 name: str,
                 ports: List[PluginPort] = None,
                 files: List[PluginFile] = None,
                 execute: str = None,
                 args=None,
                 env=None,
                 target_role: PluginTargetRole = None,
                 target: PluginTarget = None):
        self.name = name
        # self.docker_image = docker_image
        self.target = target or PluginTarget.SparkContainer
        self.target_role = target_role or PluginTargetRole.Master
        self.ports = ports or []
        self.files = files or []
        self.args = args or []
        self.env = env or dict()
        self.execute = execute
    Args
        name: Name of the plugin. Used to reference the plugin
        runOn: Where the plugin should run
        execute: Path to the file to execute (this must match the target of one of the files)
        files: List of files to upload
        args: List of arguments to pass to the executing script
        env: Dict of environment variables to pass to the script
    """
    name = fields.String()
    files = fields.List(PluginFile)
    execute = fields.String()
    args = fields.List(default=[])
    env = fields.List(default=[])
    target = fields.Enum(PluginTarget, default=PluginTarget.SparkContainer)
    target_role = fields.Enum(PluginTargetRole, default=PluginTargetRole.Master)
    ports = fields.List(PluginPort, default=[])

    def has_arg(self, name: str):
        for x in self.args:
            if x.name == name:
                return True
        return False

    def validate(self):
        self._validate_required([
            "name",
            "execute",
        ])

        if not isinstance(self.target, PluginTarget):
            raise InvalidPluginConfigurationError(
                "Target must be of type Plugin target but was {0}".format(self.target))

        if not isinstance(self.target_role, PluginTargetRole):
            raise InvalidPluginConfigurationError(
                "Target role must be of type Plugin target role but was {0}".format(self.target))
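
fields.Enum coerces raw values into enum members on assignment, so plain strings from YAML round-trip cleanly. A hedged sketch (assuming PluginTarget.Host has the value "host" in the full file, which this hunk truncates):

port = PluginPort(internal=8787, public=True)
assert port.expose_publicly and port.public_port == 8787

conf = PluginConfiguration(name="demo", execute="demo.sh", target="host")
assert conf.target is PluginTarget.Host   # coerced from the string value
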
@@ -1,24 +1,36 @@
import io
from typing import Union
from aztk.core.models import Model, fields

class PluginFile:
class PluginFile(Model):
    """
    Reference to a file for a plugin.
    """
    def __init__(self, target: str, local_path: str):
        self.target = target
        self.local_path = local_path

    target = fields.String()
    local_path = fields.String()

    # TODO handle folders?
    def __init__(self, target: str = None, local_path: str = None):
        super().__init__(target=target, local_path=local_path)

    def content(self):
        with open(self.local_path, "r", encoding='UTF-8') as f:
            return f.read()


class TextPluginFile:
class TextPluginFile(Model):
    """
    Reference to a file for a plugin.

    Args:
        target (str): Where the file should be uploaded, relative to the plugin working dir
        content (str|io.StringIO): Content of the file. Can either be a string or a StringIO
    """

    target = fields.String()

    def __init__(self, target: str, content: Union[str, io.StringIO]):
        super().__init__(target=target)
        if isinstance(content, str):
            self._content = content
        else:
@@ -0,0 +1,4 @@
class RemoteLogin:
    def __init__(self, ip_address, port):
        self.ip_address = ip_address
        self.port = port
@@ -0,0 +1,59 @@
from aztk.core.models import Model, fields
from aztk.error import InvalidModelError

class ServicePrincipalConfiguration(Model):
    """
    Container class for AAD authentication
    """
    tenant_id = fields.String()
    client_id = fields.String()
    credential = fields.String()
    batch_account_resource_id = fields.String()
    storage_account_resource_id = fields.String()

class SharedKeyConfiguration(Model):
    """
    Container class for shared key authentication
    """
    batch_account_name = fields.String()
    batch_account_key = fields.String()
    batch_service_url = fields.String()
    storage_account_name = fields.String()
    storage_account_key = fields.String()
    storage_account_suffix = fields.String()


class DockerConfiguration(Model):
    """
    Configuration for connecting to a private docker registry

    Args:
        endpoint (str): Which docker endpoint to use. Defaults to Docker Hub.
        username (str): Docker endpoint username
        password (str): Docker endpoint password
    """
    endpoint = fields.String(default=None)
    username = fields.String(default=None)
    password = fields.String(default=None)


class SecretsConfiguration(Model):
    service_principal = fields.Model(ServicePrincipalConfiguration, default=None)
    shared_key = fields.Model(SharedKeyConfiguration, default=None)
    docker = fields.Model(DockerConfiguration, default=None)
    ssh_pub_key = fields.String(default=None)
    ssh_priv_key = fields.String(default=None)

    def __validate__(self):
        if self.service_principal and self.shared_key:
            raise InvalidModelError(
                "Both service_principal and shared_key auth are configured; use only one"
            )

        if not self.service_principal and not self.shared_key:
            raise InvalidModelError(
                "Neither service_principal nor shared_key auth is configured; configure one"
            )

    def is_aad(self):
        return self.service_principal is not None
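
Because fields.Model coerces plain dicts into nested models on assignment, a parsed secrets.yaml mapping can be fed in directly. A hedged sketch with placeholder values:

secrets = SecretsConfiguration(
    shared_key=dict(                     # dict is coerced to SharedKeyConfiguration
        batch_account_name="batch",
        batch_account_key="<key>",
        batch_service_url="https://myaccount.westus.batch.azure.com",
        storage_account_name="storage",
        storage_account_key="<key>",
        storage_account_suffix="core.windows.net",
    ))
secrets.validate()          # exactly one auth method configured: passes
assert not secrets.is_aad()
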
@@ -0,0 +1,5 @@
class Software:
    """
    Enum with the list of available software
    """
    spark = "spark"
@@ -0,0 +1,4 @@
class SSHLog():
    def __init__(self, output, node_id):
        self.output = output
        self.node_id = node_id
@@ -1,7 +1,6 @@
from aztk.internal import ConfigurationBase
from aztk.error import InvalidModelError
from aztk.utils import constants, deprecate

from aztk.core.models import Model, fields

class ToolkitDefinition:
    def __init__(self, versions, environments):
@@ -26,7 +25,7 @@ TOOLKIT_MAP = dict(
)


class Toolkit(ConfigurationBase):
class Toolkit(Model):
    """
    Toolkit for a cluster.
    This will help pick the docker image needed
@@ -36,24 +35,16 @@ class Toolkit(Model):
        version (str): Version of the toolkit
        environment (str): Which environment to use for this toolkit
        environment_version (str): If there are multiple versions for an environment you can specify which one
        docker_repo (str): Optional docker repo
    """
    def __init__(self,
                 software: str,
                 version: str,
                 environment: str = None,
                 environment_version: str = None,
                 docker_repo=None):

        self.software = software
        self.version = str(version)
        self.environment = environment
        self.environment_version = environment_version
        self.docker_repo = docker_repo


    def validate(self):
        self._validate_required(["software", "version"])
    software = fields.String()
    version = fields.String()
    environment = fields.String(default=None)
    environment_version = fields.String(default=None)
    docker_repo = fields.String(default=None)

    def __validate__(self):
        if self.software not in TOOLKIT_MAP:
            raise InvalidModelError("Toolkit '{0}' is not in the list of allowed toolkits {1}".format(
                self.software, list(TOOLKIT_MAP.keys())))
@@ -0,0 +1,6 @@
from aztk.core.models import Model, fields

class UserConfiguration(Model):
    username = fields.String()
    ssh_key = fields.String(default=None)
    password = fields.String(default=None)
@@ -0,0 +1,5 @@
class VmImage:
    def __init__(self, publisher, offer, sku):
        self.publisher = publisher
        self.offer = offer
        self.sku = sku
@@ -13,11 +13,6 @@ from aztk.spark.helpers import cluster_diagnostic_helper
from aztk.spark.utils import util
from aztk.internal.cluster_data import NodeData


DEFAULT_CLUSTER_CONFIG = models.ClusterConfiguration(
    worker_on_master=True,
)

class Client(BaseClient):
    """
    Aztk Spark Client
@@ -29,7 +24,7 @@ class Client(BaseClient):
    def __init__(self, secrets_config):
        super().__init__(secrets_config)

    def create_cluster(self, configuration: models.ClusterConfiguration, wait: bool = False):
    def create_cluster(self, cluster_conf: models.ClusterConfiguration, wait: bool = False):
        """
        Create a new aztk spark cluster

@@ -40,10 +35,8 @@ class Client(BaseClient):
        Returns:
            aztk.spark.models.Cluster
        """
        cluster_conf = models.ClusterConfiguration()
        cluster_conf.merge(DEFAULT_CLUSTER_CONFIG)
        cluster_conf.merge(configuration)
        cluster_conf.validate()

        cluster_data = self._get_cluster_data(cluster_conf.cluster_id)
        try:
            zip_resource_files = None
@@ -4,6 +4,7 @@ import azure.batch.models as batch_models
import aztk.models
from aztk import error
from aztk.utils import constants, helpers
from aztk.core.models import Model, fields

class SparkToolkit(aztk.models.Toolkit):
    def __init__(self, version: str, environment: str = None, environment_version: str = None):
@@ -53,12 +54,14 @@ class File(aztk.models.File):
    pass


class SparkConfiguration():
    def __init__(self, spark_defaults_conf=None, spark_env_sh=None, core_site_xml=None, jars=None):
        self.spark_defaults_conf = spark_defaults_conf
        self.spark_env_sh = spark_env_sh
        self.core_site_xml = core_site_xml
        self.jars = jars
class SparkConfiguration(Model):
    spark_defaults_conf = fields.String()
    spark_env_sh = fields.String()
    core_site_xml = fields.String()
    jars = fields.List()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_key_pair = self.__generate_ssh_key_pair()

    def __generate_ssh_key_pair(self):
@@ -96,37 +99,8 @@ class PluginConfiguration(aztk.models.PluginConfiguration):


class ClusterConfiguration(aztk.models.ClusterConfiguration):
    def __init__(
            self,
            custom_scripts: List[CustomScript] = None,
            file_shares: List[FileShare] = None,
            cluster_id: str = None,
            vm_count=0,
            vm_low_pri_count=0,
            vm_size=None,
            subnet_id=None,
            toolkit: SparkToolkit = None,
            user_configuration: UserConfiguration = None,
            spark_configuration: SparkConfiguration = None,
            worker_on_master: bool = None):
        super().__init__(
            custom_scripts=custom_scripts,
            cluster_id=cluster_id,
            vm_count=vm_count,
            vm_low_pri_count=vm_low_pri_count,
            vm_size=vm_size,
            toolkit=toolkit,
            subnet_id=subnet_id,
            file_shares=file_shares,
            user_configuration=user_configuration,
        )
        self.spark_configuration = spark_configuration
        self.worker_on_master = worker_on_master

    def merge(self, other):
        super().merge(other)
        self._merge_attributes(other, ["spark_configuration", "worker_on_master"])

    spark_configuration = fields.Model(SparkConfiguration, default=None)
    worker_on_master = fields.Boolean(default=True)

class SecretsConfiguration(aztk.models.SecretsConfiguration):
    pass
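
ModelMeta copies _fields from base classes, so this Spark ClusterConfiguration inherits every core field and only declares the two Spark-specific ones. A hedged sketch (assuming SparkToolkit fills in software="spark"):

conf = ClusterConfiguration(
    cluster_id="demo", vm_size="Standard_F2", size=2,
    toolkit=SparkToolkit(version="2.3.0"),
)
assert conf.worker_on_master is True    # declared default replaces DEFAULT_CLUSTER_CONFIG
assert "cluster_id" in conf._fields     # field definitions inherited from the base model
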
@@ -234,8 +208,8 @@ class JobConfiguration:
            custom_scripts=self.custom_scripts,
            toolkit=self.toolkit,
            vm_size=self.vm_size,
            vm_count=self.max_dedicated_nodes,
            vm_low_pri_count=self.max_low_pri_nodes,
            size=self.max_dedicated_nodes,
            size_low_priority=self.max_low_pri_nodes,
            subnet_id=self.subnet_id,
            worker_on_master=self.worker_on_master,
            spark_configuration=self.spark_configuration,
@@ -4,19 +4,18 @@ from aztk.models.plugins.plugin_file import PluginFile

dir_path = os.path.dirname(os.path.realpath(__file__))

class JupyterPlugin(PluginConfiguration):
    def __init__(self):
        super().__init__(
            name="jupyter",
            ports=[
                PluginPort(
                    internal=8888,
                    public=True,
                ),
            ],
            target_role=PluginTargetRole.All,
            execute="jupyter.sh",
            files=[
                PluginFile("jupyter.sh", os.path.join(dir_path, "jupyter.sh")),
            ],
        )
def JupyterPlugin():
    return PluginConfiguration(
        name="jupyter",
        ports=[
            PluginPort(
                internal=8888,
                public=True,
            ),
        ],
        target_role=PluginTargetRole.All,
        execute="jupyter.sh",
        files=[
            PluginFile("jupyter.sh", os.path.join(dir_path, "jupyter.sh")),
        ],
    )
@@ -5,19 +5,18 @@ from aztk.utils import constants

dir_path = os.path.dirname(os.path.realpath(__file__))

class JupyterLabPlugin(PluginConfiguration):
    def __init__(self):
        super().__init__(
            name="jupyterlab",
            ports=[
                PluginPort(
                    internal=8889,
                    public=True,
                ),
            ],
            target_role=PluginTargetRole.All,
            execute="jupyter_lab.sh",
            files=[
                PluginFile("jupyter_lab.sh", os.path.join(dir_path, "jupyter_lab.sh")),
            ],
        )
def JupyterLabPlugin():
    return PluginConfiguration(
        name="jupyterlab",
        ports=[
            PluginPort(
                internal=8889,
                public=True,
            ),
        ],
        target_role=PluginTargetRole.All,
        execute="jupyter_lab.sh",
        files=[
            PluginFile("jupyter_lab.sh", os.path.join(dir_path, "jupyter_lab.sh")),
        ],
    )
@@ -1,24 +1,23 @@
import os
from aztk.models.plugins.plugin_configuration import PluginConfiguration, PluginPort, PluginTargetRole
from aztk.models.plugins.plugin_file import PluginFile
from aztk.utils import constants

dir_path = os.path.dirname(os.path.realpath(__file__))


class RStudioServerPlugin(PluginConfiguration):
    def __init__(self, version="1.1.383"):
        super().__init__(
            name="rstudio_server",
            ports=[
                PluginPort(
                    internal=8787,
                    public=True,
                ),
            ],
            target_role=PluginTargetRole.Master,
            execute="rstudio_server.sh",
            files=[
                PluginFile("rstudio_server.sh", os.path.join(dir_path, "rstudio_server.sh")),
            ],
            env=dict(RSTUDIO_SERVER_VERSION=version),
        )
def RStudioServerPlugin(version="1.1.383"):
    return PluginConfiguration(
        name="rstudio_server",
        ports=[
            PluginPort(
                internal=8787,
                public=True,
            ),
        ],
        target_role=PluginTargetRole.Master,
        execute="rstudio_server.sh",
        files=[
            PluginFile("rstudio_server.sh", os.path.join(dir_path, "rstudio_server.sh")),
        ],
        env=dict(RSTUDIO_SERVER_VERSION=version),
    )
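
With PluginConfiguration now a Model, the built-in plugins become factory functions that return configured instances instead of subclasses. Usage stays a call expression (illustrative snippet):

cluster_conf = ClusterConfiguration(
    cluster_id="demo",
    vm_size="Standard_F2",
    size=2,
    toolkit=Toolkit(software="spark", version="2.3.0"),
    plugins=[JupyterPlugin(), RStudioServerPlugin()],   # factory calls, not classes
)
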
@@ -1,15 +1,11 @@
import os
import yaml
from aztk_cli import log
import aztk.spark
from aztk.spark.models import (
    SecretsConfiguration,
    ServicePrincipalConfiguration,
    SharedKeyConfiguration,
    DockerConfiguration,
    ClusterConfiguration,
    UserConfiguration,
)
from aztk.utils import deprecate
from aztk.models import Toolkit
from aztk.models.plugins.internal import PluginReference

@@ -19,11 +15,9 @@ def load_aztk_secrets() -> SecretsConfiguration:
    """
    secrets = SecretsConfiguration()
    # read global ~/secrets.yaml
    global_config = _load_secrets_config(
        os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, '.aztk',
                     'secrets.yaml'))
    global_config = _load_config_file(os.path.join(aztk.utils.constants.HOME_DIRECTORY_PATH, '.aztk', 'secrets.yaml'))
    # read current working directory secrets.yaml
    local_config = _load_secrets_config()
    local_config = _load_config_file(aztk.utils.constants.DEFAULT_SECRETS_PATH)

    if not global_config and not local_config:
        raise aztk.error.AztkError("There is no secrets.yaml in either ./.aztk/secrets.yaml or .aztk/secrets.yaml")
@@ -37,12 +31,7 @@ def load_aztk_secrets() -> SecretsConfiguration:
    secrets.validate()
    return secrets


def _load_secrets_config(
        path: str = aztk.utils.constants.DEFAULT_SECRETS_PATH):
    """
    Loads the secrets.yaml file in the .aztk directory
    """
def _load_config_file(path: str):
    if not os.path.isfile(path):
        return None

@ -51,74 +40,16 @@ def _load_secrets_config(
|
|||
return yaml.load(stream)
|
||||
except yaml.YAMLError as err:
|
||||
raise aztk.error.AztkError(
|
||||
"Error in secrets.yaml: {0}".format(err))
|
||||
"Error in {0}:\n {1}".format(path, err))
|
||||
|
||||
|
||||
def _merge_secrets_dict(secrets: SecretsConfiguration, secrets_config):
    service_principal_config = secrets_config.get('service_principal')
    if service_principal_config:
        secrets.service_principal = ServicePrincipalConfiguration(
            tenant_id=service_principal_config.get('tenant_id'),
            client_id=service_principal_config.get('client_id'),
            credential=service_principal_config.get('credential'),
            batch_account_resource_id=service_principal_config.get(
                'batch_account_resource_id'),
            storage_account_resource_id=service_principal_config.get(
                'storage_account_resource_id'),
        )

    shared_key_config = secrets_config.get('shared_key')
    batch = secrets_config.get('batch')
    storage = secrets_config.get('storage')

    if shared_key_config and (batch or storage):
        raise aztk.error.AztkError(
            "Shared keys must be configured either under 'sharedKey:' or under 'batch:' and 'storage:', not both."
        )

    if shared_key_config:
        secrets.shared_key = SharedKeyConfiguration(
            batch_account_name=shared_key_config.get('batch_account_name'),
            batch_account_key=shared_key_config.get('batch_account_key'),
            batch_service_url=shared_key_config.get('batch_service_url'),
            storage_account_name=shared_key_config.get('storage_account_name'),
            storage_account_key=shared_key_config.get('storage_account_key'),
            storage_account_suffix=shared_key_config.get(
                'storage_account_suffix'),
        )
    elif batch or storage:
        secrets.shared_key = SharedKeyConfiguration()
        if batch:
            log.warning(
                "Your secrets.yaml format is deprecated. To use shared key authentication use the shared_key key. See config/secrets.yaml.template"
            )
            secrets.shared_key.batch_account_name = batch.get(
                'batchaccountname')
            secrets.shared_key.batch_account_key = batch.get('batchaccountkey')
            secrets.shared_key.batch_service_url = batch.get('batchserviceurl')

        if storage:
            secrets.shared_key.storage_account_name = storage.get(
                'storageaccountname')
            secrets.shared_key.storage_account_key = storage.get(
                'storageaccountkey')
            secrets.shared_key.storage_account_suffix = storage.get(
                'storageaccountsuffix')

    docker_config = secrets_config.get('docker')
    if docker_config:
        secrets.docker = DockerConfiguration(
            endpoint=docker_config.get('endpoint'),
            username=docker_config.get('username'),
            password=docker_config.get('password'),
        )

    default_config = secrets_config.get('default')
    # Check for ssh keys if they are provided
    if default_config:
        secrets.ssh_priv_key = default_config.get('ssh_priv_key')
        secrets.ssh_pub_key = default_config.get('ssh_pub_key')
    if 'default' in secrets_config:
        deprecate("default key in secrets.yaml is deprecated. Place all child parameters directly at the root")
        secrets_config = dict(**secrets_config, **secrets_config.pop('default'))

    other = SecretsConfiguration.from_dict(secrets_config)
    secrets.merge(other)

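In other words, the deprecation shim flattens a legacy `default` block into the root of the dict before handing it to `SecretsConfiguration.from_dict`. A small sketch with hypothetical values:

```python
# Hypothetical secrets dict as parsed from YAML.
secrets_config = {
    "docker": {"username": "foo"},
    "default": {"ssh_pub_key": "~/.ssh/id_rsa.pub"},
}

# Equivalent to the dict(**..., **...pop('default')) expression above:
default_section = secrets_config.pop("default", None)
if default_section:
    secrets_config.update(default_section)

assert secrets_config == {
    "docker": {"username": "foo"},
    "ssh_pub_key": "~/.ssh/id_rsa.pub",
}
```
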
def read_cluster_config(
        path: str = aztk.utils.constants.DEFAULT_CLUSTER_CONFIG_PATH

@@ -126,82 +57,25 @@ def read_cluster_config(
    """
    Reads the config file in the .aztk/ directory (.aztk/cluster.yaml)
    """
    if not os.path.isfile(path):
        return None

    with open(path, 'r', encoding='UTF-8') as stream:
        try:
            config_dict = yaml.load(stream)
        except yaml.YAMLError as err:
            raise aztk.error.AztkError(
                "Error in cluster.yaml: {0}".format(err))

    if config_dict is None:
        return None

    return cluster_config_from_dict(config_dict)

    config_dict = _load_config_file(path)
    return cluster_config_from_dict(config_dict)

def cluster_config_from_dict(config: dict):
    output = ClusterConfiguration()
    wait = False
    if config.get('id') is not None:
        output.cluster_id = config['id']

    if config.get('vm_size') is not None:
        output.vm_size = config['vm_size']

    if config.get('size'):
        output.vm_count = config['size']

    if config.get('size_low_pri'):
        output.vm_low_pri_count = config['size_low_pri']

    if config.get('subnet_id') is not None:
        output.subnet_id = config['subnet_id']

    if config.get('username') is not None:
        output.user_configuration = UserConfiguration(
            username=config['username'])

    if config.get('password') is not None:
        output.user_configuration.password = config['password']

    if config.get('custom_scripts') not in [[None], None]:
        output.custom_scripts = []
        for custom_script in config['custom_scripts']:
            output.custom_scripts.append(
                aztk.spark.models.CustomScript(
                    script=custom_script['script'],
                    run_on=custom_script['runOn']))

    if config.get('azure_files') not in [[None], None]:
        output.file_shares = []
        for file_share in config['azure_files']:
            output.file_shares.append(
                aztk.spark.models.FileShare(
                    storage_account_name=file_share['storage_account_name'],
                    storage_account_key=file_share['storage_account_key'],
                    file_share_path=file_share['file_share_path'],
                    mount_path=file_share['mount_path'],
                ))

    if config.get('toolkit') is not None:
        output.toolkit = Toolkit.from_dict(config['toolkit'])

    if config.get('plugins') not in [[None], None]:
        output.plugins = []
        plugins = []
        for plugin in config['plugins']:
            ref = PluginReference.from_dict(plugin)
            output.plugins.append(ref.get_plugin())
            plugins.append(ref.get_plugin())
        config["plugins"] = plugins

    if config.get('worker_on_master') is not None:
        output.worker_on_master = config['worker_on_master']
    if config.get('username') is not None:
        config['user_configuration'] = dict(username=config.pop('username'))

    if config.get('wait') is not None:
        wait = config['wait']
        wait = config.pop('wait')

    return output, wait
    return ClusterConfiguration.from_dict(config), wait

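The slimmed-down version now only translates the handful of keys whose names differ from the model fields, then delegates everything else to `ClusterConfiguration.from_dict`. A sketch with hypothetical input values:

```python
config = {"vm_size": "standard_a2", "size": 2, "username": "spark", "wait": True}

# username becomes a nested user_configuration dict; wait is popped off
# because it is CLI state rather than a model field.
config["user_configuration"] = dict(username=config.pop("username"))
wait = config.pop("wait")

cluster_config = ClusterConfiguration.from_dict(config)
assert cluster_config.user_configuration.username == "spark"
```
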

class SshConfig:

@@ -321,8 +195,8 @@ class JobConfig():
        self.toolkit = Toolkit.from_dict(cluster_configuration.get('toolkit'))
        if cluster_configuration.get('size') is not None:
            self.max_dedicated_nodes = cluster_configuration.get('size')
        if cluster_configuration.get('size_low_pri') is not None:
            self.max_low_pri_nodes = cluster_configuration.get('size_low_pri')
        if cluster_configuration.get('size_low_priority') is not None:
            self.max_low_pri_nodes = cluster_configuration.get('size_low_priority')
        self.custom_scripts = cluster_configuration.get('custom_scripts')
        self.subnet_id = cluster_configuration.get('subnet_id')
        self.worker_on_master = cluster_configuration.get("worker_on_master")

@@ -21,7 +21,7 @@ vm_size: standard_f2
# size: <number of dedicated nodes in the cluster; note that clusters must contain all dedicated or all low priority nodes>
size: 2

# size_low_pri: <number of low priority nodes in the cluster, mutually exclusive with size setting>
# size_low_priority: <number of low priority nodes in the cluster, mutually exclusive with size setting>

# username: <username for the linux user to be created> (optional)

@@ -8,7 +8,7 @@ job:
  cluster_configuration:
    vm_size: standard_d2_v2
    size: 2
    size_low_pri: 0
    size_low_priority: 0
    subnet_id:

    # Toolkit configuration [Required]
    toolkit:

@@ -24,7 +24,7 @@ docker:
    # password:
    # endpoint:

default:
    # SSH keys used to create a user and connect to a server.
    # The public key can either be the public key itself (ssh-rsa ...) or the path to the ssh key.
    # ssh_pub_key: ~/.ssh/id_rsa.pub

# SSH keys used to create a user and connect to a server.
# The public key can either be the public key itself (ssh-rsa ...) or the path to the ssh key.
# ssh_pub_key: ~/.ssh/id_rsa.pub

@@ -29,7 +29,7 @@ def setup_parser(parser: argparse.ArgumentParser):

    parser.add_argument('--no-wait', dest='wait', action='store_false')
    parser.add_argument('--wait', dest='wait', action='store_true')
    parser.set_defaults(wait=None, size=None, size_low_pri=None)
    parser.set_defaults(wait=None, size=None, size_low_priority=None)


def execute(args: typing.NamedTuple):
@@ -42,8 +42,8 @@ def execute(args: typing.NamedTuple):
    cluster_conf.merge(file_config)
    cluster_conf.merge(ClusterConfiguration(
        cluster_id=args.cluster_id,
        vm_count=args.size,
        vm_low_pri_count=args.size_low_pri,
        size=args.size,
        size_low_priority=args.size_low_priority,
        vm_size=args.vm_size,
        subnet_id=args.subnet_id,
        user_configuration=UserConfiguration(
@@ -71,6 +71,7 @@ def execute(args: typing.NamedTuple):
    else:
        cluster_conf.user_configuration = None

    cluster_conf.validate()
    utils.print_cluster_conf(cluster_conf, wait)
    with utils.Spinner():
        # create spark cluster

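The layered merge gives CLI arguments precedence over file values while leaving unset arguments alone, per the merge semantics pinned down in the model tests below. A sketch with hypothetical values:

```python
file_config = ClusterConfiguration(cluster_id="from-file", vm_size="standard_a2", size=2)
cli_config = ClusterConfiguration(size=4)  # only size was passed on the CLI

file_config.merge(cli_config)
assert file_config.size == 4                 # CLI value wins
assert file_config.vm_size == "standard_a2"  # unset CLI fields keep file values
```
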
@@ -3,6 +3,7 @@ import getpass
import sys
import threading
import time
import yaml
from subprocess import call
from typing import List
import azure.batch.models as batch_models
@@ -420,14 +421,12 @@ def utc_to_local(utc_dt):

def print_cluster_conf(cluster_conf: ClusterConfiguration, wait: bool):
    user_configuration = cluster_conf.user_configuration

    log.info("-------------------------------------------")
    log.info("cluster id: %s", cluster_conf.cluster_id)
    log.info("cluster toolkit: %s %s", cluster_conf.toolkit.software, cluster_conf.toolkit.version)
    log.info("cluster size: %s",
             cluster_conf.vm_count + cluster_conf.vm_low_pri_count)
    log.info("> dedicated: %s", cluster_conf.vm_count)
    log.info("> low priority: %s", cluster_conf.vm_low_pri_count)
    log.info("cluster toolkit: %s %s", cluster_conf.toolkit.software, cluster_conf.toolkit.version)
    log.info("cluster size: %s", cluster_conf.size + cluster_conf.size_low_priority)
    log.info("> dedicated: %s", cluster_conf.size)
    log.info("> low priority: %s", cluster_conf.size_low_priority)
    log.info("cluster vm size: %s", cluster_conf.vm_size)
    log.info("custom scripts: %s", len(cluster_conf.custom_scripts) if cluster_conf.custom_scripts else 0)
    log.info("subnet ID: %s", cluster_conf.subnet_id)

@@ -47,12 +47,11 @@ To finish setting up, you need to fill out your Azure Batch and Azure Storage se

Please note that if you use ssh keys and have a non-standard ssh key file name or path, you will need to specify the location of your ssh public and private keys. To do so, set them as shown below:
```yaml
default:
    # SSH keys used to create a user and connect to a server.
    # The public key can either be the public key itself (ssh-rsa ...) or the path to the ssh key.
    # The private key must be the path to the key.
    ssh_pub_key: ~/.ssh/my-public-key.pub
    ssh_priv_key: ~/.ssh/my-private-key
# SSH keys used to create a user and connect to a server.
# The public key can either be the public key itself (ssh-rsa ...) or the path to the ssh key.
# The private key must be the path to the key.
ssh_pub_key: ~/.ssh/my-public-key.pub
ssh_priv_key: ~/.ssh/my-private-key
```

0. Log into Azure

@@ -29,7 +29,7 @@ vm_size: standard_a2
# size: <number of dedicated nodes in the cluster; note that clusters must contain all dedicated or all low priority nodes>
size: 2

# size_low_pri: <number of low priority nodes in the cluster, mutually exclusive with size setting>
# size_low_priority: <number of low priority nodes in the cluster, mutually exclusive with size setting>

# username: <username for the linux user to be created> (optional)
username: spark

@@ -56,7 +56,7 @@ Jobs also require a definition of the cluster on which the Applications will run
```
_Please Note: For more information about Azure VM sizes, see [Azure Batch Pricing](https://azure.microsoft.com/en-us/pricing/details/batch/). And for more information about Docker repositories see [Docker](./12-docker-iamge.html)._

_The only required fields are vm_size and either size or size_low_pri, all other fields can be left blank or removed._
_The only required fields are vm_size and either size or size_low_priority, all other fields can be left blank or removed._

A Job definition may also include a default Spark Configuration. The following are the properties to define a Spark Configuration:
```yaml

@@ -1,10 +1,8 @@
aztk.models package
===================

aztk.models.models module
-------------------------

.. automodule:: aztk.models.models
.. automodule:: aztk.models
   :members:
   :undoc-members:
   :show-inheritance:
   :imported-members:

@@ -6,57 +6,50 @@ In `aztk/models` create a new file with the name of your model `my_model.py`

In `aztk/models/__init__.py` add `from .my_model import MyModel`

Create a new class `MyModel` that inherits `ConfigurationBase`
Create a new class `MyModel` that inherits `Model`
```python
from aztk.internal import ConfigurationBase
from aztk.core.models import Model, fields

class MyModel(ConfigurationBase):
class MyModel(Model):
    """
    MyModel is a sample model

    Args:
        input1 (str): This is the first input
    """
    def __init__(self, input1: str):
        self.input1 = input1

    def validate(self):
    input1 = fields.String()

    def __validate__(self):
        pass

```
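Once declared this way, fields are set through the constructor and checked by `validate()`. A short sketch using the class above:

```python
model = MyModel(input1="hello")
model.validate()  # raises if input1 were missing or not a string
```
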

### Available fields types

Check `aztk/core/models/fields.py` for the sources; a short example combining several of these follows the list.

* `Field`: Base field class
* `String`: Field that validates it is given a string
* `Integer`: Field that validates it is given an int
* `Float`: Field that validates it is given a float
* `Boolean`: Field that validates it is given a boolean
* `List`: Field that validates it is given a list and can also automatically convert entries to the given model type.
* `Model`: Field that maps to another model. If passed a dict it will automatically try to convert to the Model type.
* `Enum`: Field whose value should be an enum. It will automatically convert to the enum if given its value.
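For instance, a model combining several of these field types might look like the following sketch (`Widget` and `Tag` are made up for illustration; the field classes are the real ones listed above):

```python
from aztk.core.models import Model, fields

class Tag(Model):
    key = fields.String()
    value = fields.String(default="")

class Widget(Model):
    name = fields.String()              # required: no default given
    weight = fields.Float(default=1.0)  # optional, with a default
    enabled = fields.Boolean(default=True)
    tags = fields.List(Tag)             # dict entries converted to Tag
```
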

## Add validation
The fields provide basic validation automatically. A field without a default will be marked as required.

In `def validate` do any kind of checks and raise an `InvalidModelError` if there are any problems with the values

### Validate required
To validate required attributes call the parent `_validate_required` method. The method takes a list of attributes which should not be None
To provide model-wide validation implement a `__validate__` method and raise an `InvalidModelError` if there are any problems with the values

```python
def validate(self) -> bool:
    self._validate_required(["input1"])
```

### Custom validation
```python
def validate(self) -> bool:
    if "foo" in self.input1:
        raise InvalidModelError("foo cannot be in input1")
def __validate__(self):
    if 'secret' in self.input1:
        raise InvalidModelError("Input1 contains secrets")

```

## Convert dict to model

When inheriting from `ConfigurationBase` it comes with a `from_dict` class method which lets you convert a dict to this class
It works great for simple cases where values are simple types (str, int, etc.). If, however, you need to process them, you can override the `_from_dict` method.

**Important: Do not override the `from_dict` method, as it handles errors and displays them nicely.**

```python
@classmethod
def _from_dict(cls, args: dict):
    if "input1" in args:
        args["input1"] = MyInput1Model.from_dict(args["input1"])

    return super()._from_dict(args)
```
When inheriting from `Model` it comes with a `from_dict` class method which lets you convert a dict to this class
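A short sketch of the round trip, using the `MyModel` class above; since `Model` and `List` fields convert nested dicts automatically, overriding `_from_dict` should rarely be needed anymore:

```python
my_model = MyModel.from_dict({"input1": "hello"})
assert my_model.input1 == "hello"
```
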

pylintrc

@@ -192,7 +192,7 @@ max-nested-blocks=5
[FORMAT]

# Maximum number of characters on a single line.
max-line-length=120
max-line-length=140

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

@@ -0,0 +1,316 @@
from enum import Enum

import yaml
import pytest

from aztk.core.models import Model, fields, ModelMergeStrategy, ListMergeStrategy
from aztk.error import InvalidModelFieldError, AztkAttributeError

# pylint: disable=C1801

class UserState(Enum):
    Creating = "creating"
    Ready = "ready"
    Deleting = "deleting"

class UserInfo(Model):
    name = fields.String()
    age = fields.Integer()

class User(Model):
    info = fields.Model(UserInfo)
    enabled = fields.Boolean(default=True)
    state = fields.Enum(UserState, default=UserState.Ready)


def test_models():
    user = User(
        info=UserInfo(
            name="Highlander",
            age=800,
        ),
        enabled=False,
        state=UserState.Creating,
    )

    assert user.info.name == "Highlander"
    assert user.info.age == 800
    assert user.enabled is False
    assert user.state == UserState.Creating

def test_inherited_models():
    class ServiceUser(User):
        service = fields.String()

    user = ServiceUser(
        info=dict(
            name="Bob",
            age=59,
        ),
        enabled=False,
        service="bus",
    )
    user.validate()

    assert user.info.name == "Bob"
    assert user.info.age == 59
    assert user.enabled is False
    assert user.state == UserState.Ready
    assert user.service == "bus"

def test_raise_error_if_extra_parameters():
    class SimpleNameModel(Model):
        name = fields.String()

    with pytest.raises(AztkAttributeError, match="SimpleNameModel doesn't have an attribute called abc"):
        SimpleNameModel(name="foo", abc="123")

def test_enum_invalid_type_raise_error():
    class SimpleStateModel(Model):
        state = fields.Enum(UserState)

    with pytest.raises(
            InvalidModelFieldError,
            match="SimpleStateModel state unknown is not a valid option. Use one of \\['creating', 'ready', 'deleting'\\]"):
        obj = SimpleStateModel(state="unknown")
        obj.validate()

def test_enum_parse_string():
    class SimpleStateModel(Model):
        state = fields.Enum(UserState)

    obj = SimpleStateModel(state="creating")
    obj.validate()

    assert obj.state == UserState.Creating


def test_convert_nested_dict_to_model():
    user = User(
        info=dict(
            name="Highlander",
            age=800,
        ),
        enabled=False,
        state="deleting",
    )
    assert isinstance(user.info, UserInfo)
    assert user.info.name == "Highlander"
    assert user.info.age == 800
    assert user.enabled is False
    assert user.state == UserState.Deleting

def test_raise_error_if_missing_required_field():
    class SimpleRequiredModel(Model):
        name = fields.String()

    missing = SimpleRequiredModel()

    with pytest.raises(InvalidModelFieldError, match="SimpleRequiredModel name is required"):
        missing.validate()

def test_raise_error_if_string_field_invalid_type():
    class SimpleStringModel(Model):
        name = fields.String()

    missing = SimpleStringModel(name=123)

    with pytest.raises(InvalidModelFieldError, match="SimpleStringModel name 123 should be a string"):
        missing.validate()

def test_raise_error_if_int_field_invalid_type():
    class SimpleIntegerModel(Model):
        age = fields.Integer()

    missing = SimpleIntegerModel(age='123')

    with pytest.raises(InvalidModelFieldError, match="SimpleIntegerModel age 123 should be an integer"):
        missing.validate()

def test_raise_error_if_bool_field_invalid_type():
    class SimpleBoolModel(Model):
        enabled = fields.Boolean()

    missing = SimpleBoolModel(enabled="false")

    with pytest.raises(InvalidModelFieldError, match="SimpleBoolModel enabled false should be a boolean"):
        missing.validate()

def test_merge_with_default_value():
    class SimpleMergeModel(Model):
        name = fields.String()
        enabled = fields.Boolean(default=True)

    record1 = SimpleMergeModel(enabled=False)
    assert record1.enabled is False

    record2 = SimpleMergeModel(name="foo")
    assert record2.enabled is True

    record1.merge(record2)
    assert record1.name == 'foo'
    assert record1.enabled is False


def test_merge_nested_model_merge_strategy():
    class ComplexModel(Model):
        model_id = fields.String()
        info = fields.Model(UserInfo, merge_strategy=ModelMergeStrategy.Merge)

    obj1 = ComplexModel(
        info=dict(
            name="John",
            age=29,
        )
    )
    obj2 = ComplexModel(
        info=dict(
            age=38,
        )
    )

    assert obj1.info.age == 29
    assert obj2.info.age == 38

    obj1.merge(obj2)
    assert obj1.info.name == "John"
    assert obj1.info.age == 38

def test_merge_nested_model_override_strategy():
    class ComplexModel(Model):
        model_id = fields.String()
        info = fields.Model(UserInfo, merge_strategy=ModelMergeStrategy.Override)

    obj1 = ComplexModel(
        info=dict(
            name="John",
            age=29,
        )
    )
    obj2 = ComplexModel(
        info=dict(
            age=38,
        )
    )

    assert obj1.info.age == 29
    assert obj2.info.age == 38

    obj1.merge(obj2)
    assert obj1.info.name is None
    assert obj1.info.age == 38

def test_list_field_convert_model_correctly():
    class UserList(Model):
        infos = fields.List(UserInfo)

    obj = UserList(
        infos=[
            dict(
                name="John",
                age=29,
            )
        ]
    )
    obj.validate()

    assert len(obj.infos) == 1
    assert isinstance(obj.infos[0], UserInfo)
    assert obj.infos[0].name == "John"
    assert obj.infos[0].age == 29

def test_list_field_is_never_required():
    class UserList(Model):
        infos = fields.List(UserInfo)

    obj = UserList()
    obj.validate()

    assert isinstance(obj.infos, (list,))
    assert len(obj.infos) == 0

    infos = obj.infos
    infos.append(UserInfo())
    assert len(obj.infos) == 1

    obj2 = UserList(infos=None)
    assert isinstance(obj2.infos, (list,))
    assert len(obj2.infos) == 0

def test_list_field_ignore_none_entries():
    class UserList(Model):
        infos = fields.List(UserInfo)

    obj = UserList(infos=[None, None])
    obj.validate()

    assert isinstance(obj.infos, (list,))
    assert len(obj.infos) == 0

def test_merge_nested_model_append_strategy():
    class UserList(Model):
        infos = fields.List(UserInfo, merge_strategy=ListMergeStrategy.Append)

    obj1 = UserList(
        infos=[
            dict(
                name="John",
                age=29,
            )
        ]
    )
    obj2 = UserList(
        infos=[dict(
            name="Frank",
            age=38,
        )]
    )

    assert len(obj1.infos) == 1
    assert len(obj2.infos) == 1
    assert obj1.infos[0].name == "John"
    assert obj1.infos[0].age == 29
    assert obj2.infos[0].name == "Frank"
    assert obj2.infos[0].age == 38

    obj1.merge(obj2)
    assert len(obj1.infos) == 2
    assert obj1.infos[0].name == "John"
    assert obj1.infos[0].age == 29
    assert obj1.infos[1].name == "Frank"
    assert obj1.infos[1].age == 38


def test_serialize_simple_model_to_yaml():
    info = UserInfo(name="John", age=29)
    output = yaml.dump(info)

    assert output == "!!python/object:test_models.UserInfo {age: 29, name: John}\n"

    info_parsed = yaml.load(output)

    assert isinstance(info_parsed, UserInfo)
    assert info_parsed.name == "John"
    assert info_parsed.age == 29

def test_serialize_nested_model_to_yaml():
    user = User(
        info=dict(name="John", age=29),
        enabled=True,
        state=UserState.Deleting,
    )
    output = yaml.dump(user)

    assert output == "!!python/object:test_models.User\nenabled: true\ninfo: {age: 29, name: John}\nstate: deleting\n"

    user_parsed = yaml.load(output)

    assert isinstance(user_parsed, User)
    assert isinstance(user_parsed.info, UserInfo)
    assert user_parsed.info.name == "John"
    assert user_parsed.info.age == 29
    assert user_parsed.state == UserState.Deleting
    assert user_parsed.enabled is True

@@ -8,9 +8,8 @@ dir_path = os.path.dirname(os.path.realpath(__file__))
fake_plugin_dir = os.path.join(dir_path, "fake_plugins")


class RequiredArgPlugin(PluginConfiguration):
    def __init__(self, req_arg):
        super().__init__(name="required-arg")
def RequiredArgPlugin(req_arg):
    return PluginConfiguration(name="required-arg")


def test_missing_plugin():

@@ -1,6 +1,6 @@
import pytest

from aztk.error import AztkError
from aztk.error import AztkError, AztkAttributeError
from aztk.models.plugins.internal import PluginReference, PluginTarget, PluginTargetRole


@@ -19,7 +19,7 @@ def test_from_dict():


def test_from_dict_invalid_param():
    with pytest.raises(AztkError):
    with pytest.raises(AztkAttributeError):
        PluginReference.from_dict(dict(
            name2="invalid"
        ))

@@ -0,0 +1,61 @@
import pytest

from aztk.models import ClusterConfiguration, Toolkit, UserConfiguration
from aztk.spark.models.plugins import JupyterPlugin, HDFSPlugin

def test_vm_count_deprecated():
    with pytest.warns(DeprecationWarning):
        config = ClusterConfiguration(vm_count=3)
    assert config.size == 3

    with pytest.warns(DeprecationWarning):
        config = ClusterConfiguration(vm_low_pri_count=10)
    assert config.size_low_priority == 10

def test_size_none():
    config = ClusterConfiguration(size=None, size_low_priority=2)
    assert config.size == 0
    assert config.size_low_priority == 2


def test_size_low_priority_none():
    config = ClusterConfiguration(size=1, size_low_priority=None)
    assert config.size == 1
    assert config.size_low_priority == 0


def test_cluster_configuration():
    data = {
        'toolkit': {
            'software': 'spark',
            'version': '2.3.0',
            'environment': 'anaconda'
        },
        'vm_size': 'standard_a2',
        'size': 2,
        'size_low_priority': 3,
        'subnet_id': '/subscriptions/21abd678-18c5-4660-9fdd-8c5ba6b6fe1f/resourceGroups/abc/providers/Microsoft.Network/virtualNetworks/prodtest5vnet/subnets/default',
        'plugins': [
            JupyterPlugin(),
            HDFSPlugin(),
        ],
        'user_configuration': {'username': 'spark'}
    }

    config = ClusterConfiguration.from_dict(data)

    assert isinstance(config.toolkit, Toolkit)
    assert config.toolkit.software == 'spark'
    assert config.toolkit.version == '2.3.0'
    assert config.toolkit.environment == 'anaconda'
    assert config.size == 2
    assert config.size_low_priority == 3
    assert config.vm_size == 'standard_a2'
    assert config.subnet_id == '/subscriptions/21abd678-18c5-4660-9fdd-8c5ba6b6fe1f/resourceGroups/abc/providers/Microsoft.Network/virtualNetworks/prodtest5vnet/subnets/default'

    assert isinstance(config.user_configuration, UserConfiguration)
    assert config.user_configuration.username == 'spark'

    assert len(config.plugins) == 2
    assert config.plugins[0].name == 'jupyter'
    assert config.plugins[1].name == 'hdfs'