[AIRFLOW-6572] Move AWS classes to providers.amazon.aws package (#7178)

* [AIP-21] Move contrib.hooks.aws_glue_catalog_hook to providers.amazon.aws.hooks.glue_catalog

* [AIP-21] Move contrib.hooks.aws_logs_hook to providers.amazon.aws.hooks.logs

* [AIP-21] Move contrib.hooks.emr_hook to providers.amazon.aws.hooks.emr

* [AIP-21] Move contrib.operators.ecs_operator to providers.amazon.aws.operators.ecs

* [AIP-21] Move contrib.operators.emr_add_steps_operator to providers.amazon.aws.operators.emr_add_steps

* [AIP-21] Move contrib.operators.emr_create_job_flow_operator to providers.amazon.aws.operators.emr_create_job_flow

* [AIP-21] Move contrib.operators.emr_terminate_job_flow_operator to providers.amazon.aws.operators.emr_terminate_job_flow

* [AIP-21] Move contrib.operators.s3_copy_object_operator to providers.amazon.aws.operators.s3_copy_object

* [AIP-21] Move contrib.operators.s3_delete_objects_operator to providers.amazon.aws.operators.s3_delete_objects

* [AIP-21] Move contrib.operators.s3_list_operator to providers.amazon.aws.operators.s3_list

* [AIP-21] Move contrib.operators.sagemaker_base_operator to providers.amazon.aws.operators.sagemaker_base

* [AIP-21] Move contrib.operators.sagemaker_endpoint_config_operator to providers.amazon.aws.operators.sagemaker_endpoint_config

* [AIP-21] Move contrib.operators.sagemaker_endpoint_operator to providers.amazon.aws.operators.sagemaker_endpoint

* [AIP-21] Move contrib.operators.sagemaker_model_operator to providers.amazon.aws.operators.sagemaker_model

* [AIP-21] Move contrib.operators.sagemaker_training_operator to providers.amazon.aws.operators.sagemaker_training

* [AIP-21] Move contrib.operators.sagemaker_transform_operator to providers.amazon.aws.operators.sagemaker_transform

* [AIP-21] Move contrib.operators.sagemaker_tuning_operator to providers.amazon.aws.operators.sagemaker_tuning

* [AIP-21] Move contrib.sensors.aws_glue_catalog_partition_sensor to providers.amazon.aws.sensors.glue_catalog_partition

* [AIP-21] Move contrib.sensors.emr_base_sensor to providers.amazon.aws.sensors.emr_base

* [AIP-21] Move contrib.sensors.emr_job_flow_sensor to providers.amazon.aws.sensors.emr_job_flow

* [AIP-21] Move contrib.sensors.emr_step_sensor to providers.amazon.aws.sensors.emr_step

* [AIP-21] Move contrib.sensors.sagemaker_base_sensor to providers.amazon.aws.sensors.sagemaker_base

* [AIP-21] Move contrib.sensors.sagemaker_endpoint_sensor to providers.amazon.aws.sensors.sagemaker_endpoint

* [AIP-21] Move contrib.sensors.sagemaker_training_sensor to providers.amazon.aws.sensors.sagemaker_training

* [AIP-21] Move contrib.sensors.sagemaker_transform_sensor to providers.amazon.aws.sensors.sagemaker_transform

* [AIP-21] Move contrib.sensors.sagemaker_tuning_sensor to providers.amazon.aws.sensors.sagemaker_tuning

* [AIP-21] Move operators.s3_file_transform_operator to providers.amazon.aws.operators.s3_file_transform

* [AIP-21] Move sensors.s3_key_sensor to providers.amazon.aws.sensors.s3_key

* [AIP-21] Move sensors.s3_prefix_sensor to providers.amazon.aws.sensors.s3_prefix

* [AIP-21] Move contrib.hooks.sagemaker_hook to providers.amazon.aws.hooks.sagemaker
Kamil Breguła 2020-01-17 09:39:04 +01:00 committed by GitHub
Parent 50efda5c69
Commit c319e81cae
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
98 changed files: 4018 additions and 2996 deletions
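For DAG authors the change is purely an import-path move; the example DAG diffs below show it directly. A minimal before/after sketch, using paths taken from the diffs in this commit:

# Deprecated contrib import paths (still work, but now emit DeprecationWarning):
# from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
# from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor

# New provider package paths introduced by this commit:
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor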

View file

@@ -22,8 +22,8 @@ This is an example dag for a AWS EMR Pipeline with auto steps.
from datetime import timedelta
from airflow import DAG
from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor
from airflow.utils.dates import days_ago
DEFAULT_ARGS = {

View file

@@ -25,10 +25,10 @@ terminating the cluster.
from datetime import timedelta
from airflow import DAG
from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator
from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
from airflow.utils.dates import days_ago
DEFAULT_ARGS = {

View file

@@ -16,137 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`."""
"""
This module contains AWS Glue Catalog Hook
"""
from airflow.contrib.hooks.aws_hook import AwsHook
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook # noqa
class AwsGlueCatalogHook(AwsHook):
"""
Interact with AWS Glue Catalog
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:type aws_conn_id: str
:param region_name: aws region name (example: us-east-1)
:type region_name: str
"""
def __init__(self,
aws_conn_id='aws_default',
region_name=None,
*args,
**kwargs):
self.region_name = region_name
self.conn = None
super().__init__(aws_conn_id=aws_conn_id, *args, **kwargs)
def get_conn(self):
"""
Returns glue connection object.
"""
self.conn = self.get_client_type('glue', self.region_name)
return self.conn
def get_partitions(self,
database_name,
table_name,
expression='',
page_size=None,
max_items=None):
"""
Retrieves the partition values for a table.
:param database_name: The name of the catalog database where the partitions reside.
:type database_name: str
:param table_name: The name of the partitions' table.
:type table_name: str
:param expression: An expression filtering the partitions to be returned.
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
:type expression: str
:param page_size: pagination size
:type page_size: int
:param max_items: maximum items to return
:type max_items: int
:return: set of partition values where each value is a tuple since
a partition may be composed of multiple columns. For example:
``{('2018-01-01','1'), ('2018-01-01','2')}``
"""
config = {
'PageSize': page_size,
'MaxItems': max_items,
}
paginator = self.get_conn().get_paginator('get_partitions')
response = paginator.paginate(
DatabaseName=database_name,
TableName=table_name,
Expression=expression,
PaginationConfig=config
)
partitions = set()
for page in response:
for partition in page['Partitions']:
partitions.add(tuple(partition['Values']))
return partitions
def check_for_partition(self, database_name, table_name, expression):
"""
Checks whether a partition exists
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table @partition belongs to
:type table_name: str
:expression: Expression that matches the partitions to check for
(eg `a = 'b' AND c = 'd'`)
:type expression: str
:rtype: bool
>>> hook = AwsGlueCatalogHook()
>>> t = 'static_babynames_partitioned'
>>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
True
"""
partitions = self.get_partitions(database_name, table_name, expression, max_items=1)
return bool(partitions)
def get_table(self, database_name, table_name):
"""
Get the information of the table
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table
:type table_name: str
:rtype: dict
>>> hook = AwsGlueCatalogHook()
>>> r = hook.get_table('db', 'table_foo')
>>> r['Name'] = 'table_foo'
"""
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name)
return result['Table']
def get_table_location(self, database_name, table_name):
"""
Get the physical location of the table
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table
:type table_name: str
:return: str
"""
table = self.get_table(database_name, table_name)
return table['StorageDescriptor']['Location']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`.",
DeprecationWarning, stacklevel=2
)
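The old module is reduced to a re-export plus a module-level DeprecationWarning. A quick way to exercise that behaviour (a hypothetical test sketch, not part of this commit):

import importlib
import sys
import warnings

def test_old_glue_catalog_import_warns():
    # Drop any cached copy so the module-level warning is emitted again.
    sys.modules.pop("airflow.contrib.hooks.aws_glue_catalog_hook", None)
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        module = importlib.import_module("airflow.contrib.hooks.aws_glue_catalog_hook")
    # The shim still exposes the class and warns about the new location.
    assert hasattr(module, "AwsGlueCatalogHook")
    assert any(issubclass(w.category, DeprecationWarning) for w in captured)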

View file

@@ -16,87 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`."""
"""
This module contains a hook (AwsLogsHook) with some very basic
functionality for interacting with AWS CloudWatch.
"""
import warnings
from airflow.contrib.hooks.aws_hook import AwsHook
# pylint: disable=unused-import
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook # noqa
class AwsLogsHook(AwsHook):
"""
Interact with AWS CloudWatch Logs
:param region_name: AWS Region Name (example: us-west-2)
:type region_name: str
"""
def __init__(self, region_name=None, *args, **kwargs):
self.region_name = region_name
super().__init__(*args, **kwargs)
def get_conn(self):
"""
Establish an AWS connection for retrieving logs.
:rtype: CloudWatchLogs.Client
"""
return self.get_client_type('logs', region_name=self.region_name)
def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start_from_head=True):
"""
A generator for log items in a single stream. This will yield all the
items that are available at the current moment.
:param log_group: The name of the log group.
:type log_group: str
:param log_stream_name: The name of the specific stream.
:type log_stream_name: str
:param start_time: The time stamp value to start reading the logs from (default: 0).
:type start_time: int
:param skip: The number of log entries to skip at the start (default: 0).
This is for when there are multiple entries at the same timestamp.
:type skip: int
:param start_from_head: whether to start from the beginning (True) of the log or
at the end of the log (False).
:type start_from_head: bool
:rtype: dict
:return: | A CloudWatch log event with the following key-value pairs:
| 'timestamp' (int): The time in milliseconds of the event.
| 'message' (str): The log event data.
| 'ingestionTime' (int): The time in milliseconds the event was ingested.
"""
next_token = None
event_count = 1
while event_count > 0:
if next_token is not None:
token_arg = {'nextToken': next_token}
else:
token_arg = {}
response = self.get_conn().get_log_events(logGroupName=log_group,
logStreamName=log_stream_name,
startTime=start_time,
startFromHead=start_from_head,
**token_arg)
events = response['events']
event_count = len(events)
if event_count > skip:
events = events[skip:]
skip = 0
else:
skip = skip - event_count
events = []
yield from events
if 'nextForwardToken' in response:
next_token = response['nextForwardToken']
else:
return
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`.",
DeprecationWarning, stacklevel=2
)
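Since get_log_events is a generator that drains whatever the stream holds at call time, callers simply iterate it. A usage sketch against the new module path (the log group and stream names are placeholders; the constructor arguments are assumed to match the old signature shown above):

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
# Events arrive oldest-first; each carries 'timestamp', 'message' and 'ingestionTime'.
for event in hook.get_log_events(
        log_group="/aws/sagemaker/TrainingJobs",   # placeholder log group
        log_stream_name="example-job/algo-1",      # placeholder stream name
        start_time=0,
        skip=0):
    print(event["timestamp"], event["message"])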

View file

@@ -16,65 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.hooks.emr import EmrHook # noqa
class EmrHook(AwsHook):
"""
Interact with AWS EMR. emr_conn_id is only necessary for using the
create_job_flow method.
"""
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
self.emr_conn_id = emr_conn_id
self.region_name = region_name
self.conn = None
super().__init__(*args, **kwargs)
def get_conn(self):
if not self.conn:
self.conn = self.get_client_type('emr', self.region_name)
return self.conn
def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
conn = self.get_conn()
response = conn.list_clusters(
ClusterStates=cluster_states
)
matching_clusters = list(
filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
)
if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]['Id']
self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException('More than one cluster found for name %s', emr_cluster_name)
else:
self.log.info('No cluster found for name %s', emr_cluster_name)
return None
def create_job_flow(self, job_flow_overrides):
"""
Creates a job flow using the config from the EMR connection.
Keys of the json extra hash may have the arguments of the boto3
run_job_flow method.
Overrides for this config may be passed as the job_flow_overrides.
"""
if not self.emr_conn_id:
raise AirflowException('emr_conn_id must be present to use create_job_flow')
emr_conn = self.get_connection(self.emr_conn_id)
config = emr_conn.extra_dejson.copy()
config.update(job_flow_overrides)
response = self.get_conn().run_job_flow(**config)
return response
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.",
DeprecationWarning, stacklevel=2
)
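create_job_flow takes the JSON extras stored on the EMR connection as the base config and applies the overrides dict on top before calling boto3's run_job_flow. A usage sketch (connection IDs, cluster name and overrides are illustrative):

from airflow.providers.amazon.aws.hooks.emr import EmrHook

hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default", region_name="us-east-1")

# Keys given here override whatever is stored in the emr_default connection extras.
response = hook.create_job_flow({"Name": "example-cluster"})
job_flow_id = response["JobFlowId"]

# Resolve an existing cluster id by name, limited to the given states;
# returns None if nothing matches, raises if the name is ambiguous.
cluster_id = hook.get_cluster_id_by_name("example-cluster", ["RUNNING", "WAITING"])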

View file

@@ -16,726 +16,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import collections
import os
import tarfile
import tempfile
import time
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`."""
import warnings
from botocore.exceptions import ClientError
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils import timezone
class LogState:
STARTING = 1
WAIT_IN_PROGRESS = 2
TAILING = 3
JOB_COMPLETE = 4
COMPLETE = 5
# Position is a tuple that includes the last read timestamp and the number of items that were read
# at that time. This is used to figure out which event to start with on the next read.
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
def argmin(arr, f):
"""Return the index, i, in arr that minimizes f(arr[i])"""
m = None
i = None
for idx, item in enumerate(arr):
if item is not None:
if m is None or f(item) < m:
m = f(item)
i = idx
return i
def secondary_training_status_changed(current_job_description, prev_job_description):
"""
Returns true if training job's secondary status message has changed.
:param current_job_description: Current job description, returned from DescribeTrainingJob call.
:type current_job_description: dict
:param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
:type prev_job_description: dict
:return: Whether the secondary status message of a training job changed or not.
"""
current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
return False
prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \
if prev_job_description is not None else None
last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \
if prev_job_secondary_status_transitions is not None \
and len(prev_job_secondary_status_transitions) > 0 else ''
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
return message != last_message
def secondary_training_status_message(job_description, prev_description):
"""
Returns a string contains start time and the secondary training job status message.
:param job_description: Returned response from DescribeTrainingJob call
:type job_description: dict
:param prev_description: Previous job description from DescribeTrainingJob call
:type prev_description: dict
:return: Job status string to be printed.
"""
if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
or len(job_description.get('SecondaryStatusTransitions')) == 0:
return ''
prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
if prev_description is not None else None
prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
if prev_description_secondary_transitions is not None else 0
current_transitions = job_description['SecondaryStatusTransitions']
transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
current_transitions[prev_transitions_num - len(current_transitions):]
status_strs = []
for transition in transitions_to_print:
message = transition['StatusMessage']
time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))
return '\n'.join(status_strs)
class SageMakerHook(AwsHook):
"""
Interact with Amazon SageMaker.
"""
non_terminal_states = {'InProgress', 'Stopping'}
endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating',
'RollingBack', 'Deleting'}
failed_states = {'Failed'}
def __init__(self,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
def tar_and_s3_upload(self, path, key, bucket):
"""
Tar the local file or directory and upload to s3
:param path: local file or directory
:type path: str
:param key: s3 key
:type key: str
:param bucket: s3 bucket
:type bucket: str
:return: None
"""
with tempfile.TemporaryFile() as temp_file:
if os.path.isdir(path):
files = [os.path.join(path, name) for name in os.listdir(path)]
else:
files = [path]
with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file:
for f in files:
tar_file.add(f, arcname=os.path.basename(f))
temp_file.seek(0)
self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
def configure_s3_resources(self, config):
"""
Extract the S3 operations from the configuration and execute them.
:param config: config of SageMaker operation
:type config: dict
:rtype: dict
"""
s3_operations = config.pop('S3Operations', None)
if s3_operations is not None:
create_bucket_ops = s3_operations.get('S3CreateBucket', [])
upload_ops = s3_operations.get('S3Upload', [])
for op in create_bucket_ops:
self.s3_hook.create_bucket(bucket_name=op['Bucket'])
for op in upload_ops:
if op['Tar']:
self.tar_and_s3_upload(op['Path'], op['Key'],
op['Bucket'])
else:
self.s3_hook.load_file(op['Path'], op['Key'],
op['Bucket'])
def check_s3_url(self, s3url):
"""
Check if an S3 URL exists
:param s3url: S3 url
:type s3url: str
:rtype: bool
"""
bucket, key = S3Hook.parse_s3_url(s3url)
if not self.s3_hook.check_for_bucket(bucket_name=bucket):
raise AirflowException(
"The input S3 Bucket {} does not exist ".format(bucket))
if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
and not self.s3_hook.check_for_prefix(
prefix=key, bucket_name=bucket, delimiter='/'):
# check if s3 key exists in the case user provides a single file
# or if s3 prefix exists in the case user provides multiple files in
# a prefix
raise AirflowException("The input S3 Key "
"or Prefix {} does not exist in the Bucket {}"
.format(s3url, bucket))
return True
def check_training_config(self, training_config):
"""
Check if a training configuration is valid
:param training_config: training_config
:type training_config: dict
:return: None
"""
if "InputDataConfig" in training_config:
for channel in training_config['InputDataConfig']:
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
def check_tuning_config(self, tuning_config):
"""
Check if a tuning configuration is valid
:param tuning_config: tuning_config
:type tuning_config: dict
:return: None
"""
for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
def get_conn(self):
"""
Establish an AWS connection for SageMaker
:rtype: :py:class:`SageMaker.Client`
"""
return self.get_client_type('sagemaker')
def get_log_conn(self):
"""
This method is deprecated.
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_conn` instead.
"""
warnings.warn("Method `get_log_conn` has been deprecated. "
"Please use `airflow.contrib.hooks.AwsLogsHook.get_conn` instead.",
category=DeprecationWarning,
stacklevel=2)
return self.logs_hook.get_conn()
def log_stream(self, log_group, stream_name, start_time=0, skip=0):
"""
This method is deprecated.
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.
"""
warnings.warn("Method `log_stream` has been deprecated. "
"Please use `airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.",
category=DeprecationWarning,
stacklevel=2)
return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)
def multi_stream_iter(self, log_group, streams, positions=None):
"""
Iterate over the available events coming from a set of log streams in a single log group
interleaving the events from each stream so they're yielded in timestamp order.
:param log_group: The name of the log group.
:type log_group: str
:param streams: A list of the log stream names. The position of the stream in this list is
the stream number.
:type streams: list
:param positions: A list of pairs of (timestamp, skip) which represents the last record
read from each stream.
:type positions: list
:return: A tuple of (stream number, cloudwatch log event).
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
event_iters = [self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
for s in streams]
events = []
for s in event_iters:
if not s:
events.append(None)
continue
try:
events.append(next(s))
except StopIteration:
events.append(None)
while any(events):
i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
yield (i, events[i])
try:
events[i] = next(event_iters[i])
except StopIteration:
events[i] = None
def create_training_job(self, config, wait_for_completion=True, print_log=True,
check_interval=30, max_ingestion_time=None):
"""
Create a training job
:param config: the config for training
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to training job creation
"""
self.check_training_config(config)
response = self.get_conn().create_training_job(**config)
if print_log:
self.check_training_status_with_log(config['TrainingJobName'],
self.non_terminal_states,
self.failed_states,
wait_for_completion,
check_interval, max_ingestion_time
)
elif wait_for_completion:
describe_response = self.check_status(config['TrainingJobName'],
'TrainingJobStatus',
self.describe_training_job,
check_interval, max_ingestion_time
)
billable_time = \
(describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \
describe_response['ResourceConfig']['InstanceCount']
self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
return response
def create_tuning_job(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Create a tuning job
:param config: the config for tuning
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to tuning job creation
"""
self.check_tuning_config(config)
response = self.get_conn().create_hyper_parameter_tuning_job(**config)
if wait_for_completion:
self.check_status(config['HyperParameterTuningJobName'],
'HyperParameterTuningJobStatus',
self.describe_tuning_job,
check_interval, max_ingestion_time
)
return response
def create_transform_job(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Create a transform job
:param config: the config for transform job
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to transform job creation
"""
self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
response = self.get_conn().create_transform_job(**config)
if wait_for_completion:
self.check_status(config['TransformJobName'],
'TransformJobStatus',
self.describe_transform_job,
check_interval, max_ingestion_time
)
return response
def create_model(self, config):
"""
Create a model job
:param config: the config for model
:type config: dict
:return: A response to model creation
"""
return self.get_conn().create_model(**config)
def create_endpoint_config(self, config):
"""
Create an endpoint config
:param config: the config for endpoint-config
:type config: dict
:return: A response to endpoint config creation
"""
return self.get_conn().create_endpoint_config(**config)
def create_endpoint(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Create an endpoint
:param config: the config for endpoint
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to endpoint creation
"""
response = self.get_conn().create_endpoint(**config)
if wait_for_completion:
self.check_status(config['EndpointName'],
'EndpointStatus',
self.describe_endpoint,
check_interval, max_ingestion_time,
non_terminal_states=self.endpoint_non_terminal_states
)
return response
def update_endpoint(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Update an endpoint
:param config: the config for endpoint
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to endpoint update
"""
response = self.get_conn().update_endpoint(**config)
if wait_for_completion:
self.check_status(config['EndpointName'],
'EndpointStatus',
self.describe_endpoint,
check_interval, max_ingestion_time,
non_terminal_states=self.endpoint_non_terminal_states
)
return response
def describe_training_job(self, name):
"""
Return the training job info associated with the name
:param name: the name of the training job
:type name: str
:return: A dict contains all the training job info
"""
return self.get_conn().describe_training_job(TrainingJobName=name)
def describe_training_job_with_log(self, job_name, positions, stream_names,
instance_count, state, last_description,
last_describe_job_call):
"""
Return the training job info associated with job_name and print CloudWatch logs
"""
log_group = '/aws/sagemaker/TrainingJobs'
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
# may be dynamic until we have a stream for every instance.
logs_conn = self.logs_hook.get_conn()
try:
streams = logs_conn.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + '/',
orderBy='LogStreamName',
limit=instance_count
)
stream_names = [s['logStreamName'] for s in streams['logStreams']]
positions.update([(s, Position(timestamp=0, skip=0))
for s in stream_names if s not in positions])
except logs_conn.exceptions.ResourceNotFoundException:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
pass
if len(stream_names) > 0:
for idx, event in self.multi_stream_iter(log_group, stream_names, positions):
self.log.info(event['message'])
ts, count = positions[stream_names[idx]]
if event['timestamp'] == ts:
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1)
if state == LogState.COMPLETE:
return state, last_description, last_describe_job_call
if state == LogState.JOB_COMPLETE:
state = LogState.COMPLETE
elif time.time() - last_describe_job_call >= 30:
description = self.describe_training_job(job_name)
last_describe_job_call = time.time()
if secondary_training_status_changed(description, last_description):
self.log.info(secondary_training_status_message(description, last_description))
last_description = description
status = description['TrainingJobStatus']
if status not in self.non_terminal_states:
state = LogState.JOB_COMPLETE
return state, last_description, last_describe_job_call
def describe_tuning_job(self, name):
"""
Return the tuning job info associated with the name
:param name: the name of the tuning job
:type name: str
:return: A dict contains all the tuning job info
"""
return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
def describe_model(self, name):
"""
Return the SageMaker model info associated with the name
:param name: the name of the SageMaker model
:type name: str
:return: A dict contains all the model info
"""
return self.get_conn().describe_model(ModelName=name)
def describe_transform_job(self, name):
"""
Return the transform job info associated with the name
:param name: the name of the transform job
:type name: str
:return: A dict contains all the transform job info
"""
return self.get_conn().describe_transform_job(TransformJobName=name)
def describe_endpoint_config(self, name):
"""
Return the endpoint config info associated with the name
:param name: the name of the endpoint config
:type name: str
:return: A dict contains all the endpoint config info
"""
return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
def describe_endpoint(self, name):
"""
:param name: the name of the endpoint
:type name: str
:return: A dict contains all the endpoint info
"""
return self.get_conn().describe_endpoint(EndpointName=name)
def check_status(self, job_name, key,
describe_function, check_interval,
max_ingestion_time,
non_terminal_states=None):
"""
Check status of a SageMaker job
:param job_name: name of the job to check status
:type job_name: str
:param key: the key of the response dict
that points to the state
:type key: str
:param describe_function: the function used to retrieve the status
:type describe_function: python callable
:param args: the arguments for the function
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:param non_terminal_states: the set of nonterminal states
:type non_terminal_states: set
:return: response of describe call after job is done
"""
if not non_terminal_states:
non_terminal_states = self.non_terminal_states
sec = 0
running = True
while running:
time.sleep(check_interval)
sec = sec + check_interval
try:
response = describe_function(job_name)
status = response[key]
self.log.info('Job still running for %s seconds... '
'current status is %s' % (sec, status))
except KeyError:
raise AirflowException('Could not get status of the SageMaker job')
except ClientError:
raise AirflowException('AWS request failed, check logs for more info')
if status in non_terminal_states:
running = True
elif status in self.failed_states:
raise AirflowException('SageMaker job failed because %s' % response['FailureReason'])
else:
running = False
if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
raise AirflowException('SageMaker job took more than %s seconds', max_ingestion_time)
self.log.info('SageMaker Job Compeleted')
response = describe_function(job_name)
return response
def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
wait_for_completion, check_interval, max_ingestion_time):
"""
Display the logs for a given training job, optionally tailing them until the
job is complete.
:param job_name: name of the training job to check status and display logs for
:type job_name: str
:param non_terminal_states: the set of non_terminal states
:type non_terminal_states: set
:param failed_states: the set of failed states
:type failed_states: set
:param wait_for_completion: Whether to keep looking for new log entries
until the job completes
:type wait_for_completion: bool
:param check_interval: The interval in seconds between polling for new log entries and job completion
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: None
"""
sec = 0
description = self.describe_training_job(job_name)
self.log.info(secondary_training_status_message(description, None))
instance_count = description['ResourceConfig']['InstanceCount']
status = description['TrainingJobStatus']
stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position
job_already_completed = status not in non_terminal_states
state = LogState.TAILING if wait_for_completion and not job_already_completed else LogState.COMPLETE
# The loop below implements a state machine that alternates between checking the job status and
# reading whatever is available in the logs at this point. Note, that if we were called with
# wait_for_completion == False, we never check the job status.
#
# If wait_for_completion == TRUE and job is not completed, the initial state is TAILING
# If wait_for_completion == FALSE, the initial state is COMPLETE
# (doesn't matter if the job really is complete).
#
# The state table:
#
# STATE ACTIONS CONDITION NEW STATE
# ---------------- ---------------- ----------------- ----------------
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
# Else TAILING
# JOB_COMPLETE Read logs, Pause Any COMPLETE
# COMPLETE Read logs, Exit N/A
#
# Notes:
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that
# got to Cloudwatch after the job was marked complete.
last_describe_job_call = time.time()
last_description = description
while True:
time.sleep(check_interval)
sec = sec + check_interval
state, last_description, last_describe_job_call = \
self.describe_training_job_with_log(job_name, positions, stream_names,
instance_count, state, last_description,
last_describe_job_call)
if state == LogState.COMPLETE:
break
if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
raise AirflowException('SageMaker job took more than %s seconds', max_ingestion_time)
if wait_for_completion:
status = last_description['TrainingJobStatus']
if status in failed_states:
reason = last_description.get('FailureReason', '(No reason provided)')
raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
* instance_count
self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
# pylint: disable=unused-import
from airflow.providers.amazon.aws.hooks.sagemaker import ( # noqa
LogState, Position, SageMakerHook, argmin, secondary_training_status_changed,
secondary_training_status_message,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`.",
DeprecationWarning, stacklevel=2
)
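check_status is the polling primitive the create_* helpers above delegate to: sleep for check_interval, call the describe function, and stop once the status leaves the non-terminal set or max_ingestion_time is exceeded. A minimal sketch against the new module path (job name and timings are placeholders):

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")

# Poll DescribeTrainingJob every 30 seconds, failing if the job is
# still running after one hour.
response = hook.check_status(
    job_name="example-training-job",          # placeholder job name
    key="TrainingJobStatus",
    describe_function=hook.describe_training_job,
    check_interval=30,
    max_ingestion_time=3600,
)
print(response["TrainingJobStatus"])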

View file

@@ -16,221 +16,33 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import sys
from datetime import datetime
from typing import Optional
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.typing_compat import Protocol
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from typing_extensions import Protocol, runtime_checkable
from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSProtocol as NewECSProtocol # noqa
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`.",
DeprecationWarning, stacklevel=2
)
class ECSProtocol(Protocol):
def run_task(self, **kwargs):
...
def get_waiter(self, x: str):
...
def describe_tasks(self, cluster, tasks):
...
def stop_task(self, cluster, task, reason: str):
...
class ECSOperator(BaseOperator):
@runtime_checkable
class ECSProtocol(NewECSProtocol, Protocol):
"""
Execute a task on AWS EC2 Container Service
:param task_definition: the task definition name on EC2 Container Service
:type task_definition: str
:param cluster: the cluster name on EC2 Container Service
:type cluster: str
:param overrides: the same parameter that boto3 will receive (templated):
http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
:type overrides: dict
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
:type aws_conn_id: str
:param region_name: region name to use in AWS Hook.
Override the region_name in connection (if provided)
:type region_name: str
:param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
:type launch_type: str
:param group: the name of the task group associated with the task
:type group: str
:param placement_constraints: an array of placement constraint objects to use for
the task
:type placement_constraints: list
:param platform_version: the platform version on which your task is running
:type platform_version: str
:param network_configuration: the network configuration for the task
:type network_configuration: dict
:param tags: a dictionary of tags in the form of {'tagKey': 'tagValue'}.
:type tags: dict
:param awslogs_group: the CloudWatch group where your ECS container logs are stored.
Only required if you want logs to be shown in the Airflow UI after your job has
finished.
:type awslogs_group: str
:param awslogs_region: the region in which your CloudWatch logs are stored.
If None, this is the same as the `region_name` parameter. If that is also None,
this is the default AWS region based on your connection settings.
:type awslogs_region: str
:param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs.
This is usually based on some custom name combined with the name of the container.
Only required if you want logs to be shown in the Airflow UI after your job has
finished.
:type awslogs_stream_prefix: str
This class is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs.ECSProtocol`.
"""
ui_color = '#f0ede4'
client = None # type: Optional[ECSProtocol]
arn = None # type: Optional[str]
template_fields = ('overrides',)
# A Protocol cannot be instantiated
@apply_defaults
def __init__(self, task_definition, cluster, overrides,
aws_conn_id=None, region_name=None, launch_type='EC2',
group=None, placement_constraints=None, platform_version='LATEST',
network_configuration=None, tags=None, awslogs_group=None,
awslogs_region=None, awslogs_stream_prefix=None, **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.task_definition = task_definition
self.cluster = cluster
self.overrides = overrides
self.launch_type = launch_type
self.group = group
self.placement_constraints = placement_constraints
self.platform_version = platform_version
self.network_configuration = network_configuration
self.tags = tags
self.awslogs_group = awslogs_group
self.awslogs_stream_prefix = awslogs_stream_prefix
self.awslogs_region = awslogs_region
if self.awslogs_region is None:
self.awslogs_region = region_name
self.hook = self.get_hook()
def execute(self, context):
self.log.info(
'Running ECS Task - Task definition: %s - on cluster %s',
self.task_definition, self.cluster
def __new__(cls, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.amazon.aws.operators.ecs.ECSProtocol`.""",
DeprecationWarning,
stacklevel=2,
)
self.log.info('ECSOperator overrides: %s', self.overrides)
self.client = self.hook.get_client_type(
'ecs',
region_name=self.region_name
)
run_opts = {
'cluster': self.cluster,
'taskDefinition': self.task_definition,
'overrides': self.overrides,
'startedBy': self.owner,
'launchType': self.launch_type,
}
if self.launch_type == 'FARGATE':
run_opts['platformVersion'] = self.platform_version
if self.group is not None:
run_opts['group'] = self.group
if self.placement_constraints is not None:
run_opts['placementConstraints'] = self.placement_constraints
if self.network_configuration is not None:
run_opts['networkConfiguration'] = self.network_configuration
if self.tags is not None:
run_opts['tags'] = [{'key': k, 'value': v} for (k, v) in self.tags.items()]
response = self.client.run_task(**run_opts)
failures = response['failures']
if len(failures) > 0:
raise AirflowException(response)
self.log.info('ECS Task started: %s', response)
self.arn = response['tasks'][0]['taskArn']
self._wait_for_task_ended()
self._check_success_task()
self.log.info('ECS Task has been successfully executed: %s', response)
def _wait_for_task_ended(self):
waiter = self.client.get_waiter('tasks_stopped')
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(
cluster=self.cluster,
tasks=[self.arn]
)
def _check_success_task(self):
response = self.client.describe_tasks(
cluster=self.cluster,
tasks=[self.arn]
)
self.log.info('ECS Task stopped, check status: %s', response)
# Get logs from CloudWatch if the awslogs log driver was used
if self.awslogs_group and self.awslogs_stream_prefix:
self.log.info('ECS Task logs output:')
task_id = self.arn.split("/")[-1]
stream_name = "{}/{}".format(self.awslogs_stream_prefix, task_id)
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
self.log.info("[{}] {}".format(dt.isoformat(), event['message']))
if len(response.get('failures', [])) > 0:
raise AirflowException(response)
for task in response['tasks']:
# This is a `stoppedReason` that indicates a task has not
# successfully finished, but there is no other indication of failure
# in the response.
# See, https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html # noqa E501
if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.',
task.get('stoppedReason', '')):
raise AirflowException(
'The task was stopped because the host instance terminated: {}'.
format(task.get('stoppedReason', '')))
containers = task['containers']
for container in containers:
if container.get('lastStatus') == 'STOPPED' and \
container['exitCode'] != 0:
raise AirflowException(
'This task is not in success state {}'.format(task))
elif container.get('lastStatus') == 'PENDING':
raise AirflowException('This task is still pending {}'.format(task))
elif 'error' in container.get('reason', '').lower():
raise AirflowException(
'This containers encounter an error during launching : {}'.
format(container.get('reason', '').lower()))
def get_hook(self):
return AwsHook(
aws_conn_id=self.aws_conn_id
)
def get_logs_hook(self):
return AwsLogsHook(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region
)
def on_kill(self):
response = self.client.stop_task(
cluster=self.cluster,
task=self.arn,
reason='Task killed by the user')
self.log.info(response)
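Under the new path the operator itself is unchanged; a minimal DAG sketch (cluster, task definition and overrides are placeholders):

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
from airflow.utils.dates import days_ago

with DAG(dag_id="example_ecs", schedule_interval=None, start_date=days_ago(1)) as dag:
    run_task = ECSOperator(
        task_id="run_example_task",
        task_definition="example-task-def",    # placeholder ECS task definition
        cluster="example-cluster",             # placeholder ECS cluster
        overrides={"containerOverrides": []},  # forwarded to boto3 run_task (templated)
        aws_conn_id="aws_default",
    )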

View file

@@ -16,75 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`."""
import warnings
class EmrAddStepsOperator(BaseOperator):
"""
An operator that adds steps to an existing EMR job_flow.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator # noqa
:param job_flow_id: id of the JobFlow to add steps to. (templated)
:type job_flow_id: Optional[str]
:param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing
job_flow_id. will search for id of JobFlow with matching name in one of the states in
param cluster_states. Exactly one cluster like this should exist or will fail. (templated)
:type job_flow_name: Optional[str]
:param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
(templated)
:type cluster_states: list
:param aws_conn_id: aws connection to uses
:type aws_conn_id: str
:param steps: boto3 style steps to be added to the jobflow. (templated)
:type steps: list
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
:type do_xcom_push: bool
"""
template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
job_flow_id=None,
job_flow_name=None,
cluster_states=None,
aws_conn_id='aws_default',
steps=None,
*args, **kwargs):
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
if not ((job_flow_id is None) ^ (job_flow_name is None)):
raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
super().__init__(*args, **kwargs)
steps = steps or []
self.aws_conn_id = aws_conn_id
self.job_flow_id = job_flow_id
self.job_flow_name = job_flow_name
self.cluster_states = cluster_states
self.steps = steps
def execute(self, context):
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
emr = emr_hook.get_conn()
job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
self.cluster_states)
if not job_flow_id:
raise AirflowException(f'No cluster found for name: {self.job_flow_name}')
if self.do_xcom_push:
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
self.log.info('Adding steps to %s', job_flow_id)
response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=self.steps)
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('Adding steps failed: %s' % response)
else:
self.log.info('Steps %s added to JobFlow', response['StepIds'])
return response['StepIds']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`.",
DeprecationWarning, stacklevel=2
)
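As the removed constructor shows, exactly one of job_flow_id or job_flow_name must be given; when the name form is used, the id is resolved through EmrHook.get_cluster_id_by_name at execute time. A sketch of the name-based form (cluster name and step definition are placeholders; attach to a DAG as usual):

from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator

add_steps = EmrAddStepsOperator(
    task_id="add_steps",
    job_flow_name="example-cluster",         # placeholder; resolved to a cluster id at runtime
    cluster_states=["RUNNING", "WAITING"],   # states considered when matching by name
    steps=[{
        "Name": "example-step",
        "ActionOnFailure": "CONTINUE",
        "HadoopJarStep": {"Jar": "command-runner.jar", "Args": ["echo", "hello"]},
    }],
    aws_conn_id="aws_default",
)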

View file

@@ -16,59 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`."""
import warnings
class EmrCreateJobFlowOperator(BaseOperator):
"""
Creates an EMR JobFlow, reading the config from the EMR connection.
A dictionary of JobFlow overrides can be passed that override
the config from the connection.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator # noqa
:param aws_conn_id: aws connection to uses
:type aws_conn_id: str
:param emr_conn_id: emr connection to use
:type emr_conn_id: str
:param job_flow_overrides: boto3 style arguments to override
emr_connection extra. (templated)
:type job_flow_overrides: dict
"""
template_fields = ['job_flow_overrides']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
aws_conn_id='aws_default',
emr_conn_id='emr_default',
job_flow_overrides=None,
region_name=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
if job_flow_overrides is None:
job_flow_overrides = {}
self.job_flow_overrides = job_flow_overrides
self.region_name = region_name
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id,
emr_conn_id=self.emr_conn_id,
region_name=self.region_name)
self.log.info(
'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',
self.aws_conn_id, self.emr_conn_id
)
response = emr.create_job_flow(self.job_flow_overrides)
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('JobFlow creation failed: %s' % response)
else:
self.log.info('JobFlow with id %s created', response['JobFlowId'])
return response['JobFlowId']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`.",
DeprecationWarning, stacklevel=2
)
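job_flow_overrides is merged over the emr_default connection extras, so the connection can hold the stable cluster config while the DAG only overrides what varies. A sketch (overrides are illustrative):

from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator

create_cluster = EmrCreateJobFlowOperator(
    task_id="create_job_flow",
    aws_conn_id="aws_default",
    emr_conn_id="emr_default",
    # Only the keys that differ from the emr_default connection extras need to be listed.
    job_flow_overrides={"Name": "example-cluster", "Instances": {"InstanceCount": 1}},
)

The returned JobFlowId is the operator's return value, which is how the example DAGs above hand the cluster id on to the sensors and step operators via XCom.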

View file

@@ -16,42 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`."""
import warnings
class EmrTerminateJobFlowOperator(BaseOperator):
"""
Operator to terminate EMR JobFlows.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator # noqa
:param job_flow_id: id of the JobFlow to terminate. (templated)
:type job_flow_id: str
:param aws_conn_id: aws connection to uses
:type aws_conn_id: str
"""
template_fields = ['job_flow_id']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
job_flow_id,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Terminating JobFlow %s', self.job_flow_id)
response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('JobFlow termination failed: %s' % response)
else:
self.log.info('JobFlow with id %s terminated', self.job_flow_id)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.",
DeprecationWarning, stacklevel=2
)

View file

@@ -16,81 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`."""
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator # noqa
class S3CopyObjectOperator(BaseOperator):
"""
Creates a copy of an object that is already stored in S3.
Note: the S3 connection used here needs to have access to both
source and destination bucket/key.
:param source_bucket_key: The key of the source object. (templated)
It can be either full s3:// style url or relative path from root level.
When it's specified as a full s3:// url, please omit source_bucket_name.
:type source_bucket_key: str
:param dest_bucket_key: The key of the object to copy to. (templated)
The convention to specify `dest_bucket_key` is the same as `source_bucket_key`.
:type dest_bucket_key: str
:param source_bucket_name: Name of the S3 bucket where the source object is in. (templated)
It should be omitted when `source_bucket_key` is provided as a full s3:// url.
:type source_bucket_name: str
:param dest_bucket_name: Name of the S3 bucket to where the object is copied. (templated)
It should be omitted when `dest_bucket_key` is provided as a full s3:// url.
:type dest_bucket_name: str
:param source_version_id: Version ID of the source object (OPTIONAL)
:type source_version_id: str
:param aws_conn_id: Connection id of the S3 connection to use
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- False: do not validate SSL certificates. SSL will still be used,
but SSL certificates will not be
verified.
- path/to/cert/bundle.pem: A filename of the CA cert bundle to uses.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('source_bucket_key', 'dest_bucket_key',
'source_bucket_name', 'dest_bucket_name')
@apply_defaults
def __init__(
self,
source_bucket_key,
dest_bucket_key,
source_bucket_name=None,
dest_bucket_name=None,
source_version_id=None,
aws_conn_id='aws_default',
verify=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.source_bucket_key = source_bucket_key
self.dest_bucket_key = dest_bucket_key
self.source_bucket_name = source_bucket_name
self.dest_bucket_name = dest_bucket_name
self.source_version_id = source_version_id
self.aws_conn_id = aws_conn_id
self.verify = verify
def execute(self, context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key,
self.source_bucket_name, self.dest_bucket_name,
self.source_version_id)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`.",
DeprecationWarning, stacklevel=2
)
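source_bucket_key and dest_bucket_key accept either a bare key plus an explicit bucket name or a full s3:// URL, in which case the bucket arguments must be omitted. A sketch of both spellings (bucket and key names are placeholders):

from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator

# Bare keys with explicit bucket names.
copy_by_key = S3CopyObjectOperator(
    task_id="copy_by_key",
    source_bucket_name="example-src-bucket",
    source_bucket_key="incoming/data.csv",
    dest_bucket_name="example-dst-bucket",
    dest_bucket_key="archive/data.csv",
    aws_conn_id="aws_default",
)

# Full s3:// URLs; the *_bucket_name arguments are omitted.
copy_by_url = S3CopyObjectOperator(
    task_id="copy_by_url",
    source_bucket_key="s3://example-src-bucket/incoming/data.csv",
    dest_bucket_key="s3://example-dst-bucket/archive/data.csv",
    aws_conn_id="aws_default",
)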

View file

@@ -16,72 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`."""
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator # noqa
class S3DeleteObjectsOperator(BaseOperator):
"""
Enables users to delete a single object or multiple objects from
a bucket using a single HTTP request.
Users may specify up to 1000 keys to delete.
:param bucket: Name of the bucket in which you are going to delete object(s). (templated)
:type bucket: str
:param keys: The key(s) to delete from S3 bucket. (templated)
When ``keys`` is a string, it's supposed to be the key name of
the single object to delete.
When ``keys`` is a list, it's supposed to be the list of the
keys to delete.
You may specify up to 1000 keys.
:type keys: str or list
:param aws_conn_id: Connection id of the S3 connection to use
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used,
but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('keys', 'bucket')
@apply_defaults
def __init__(
self,
bucket,
keys,
aws_conn_id='aws_default',
verify=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.bucket = bucket
self.keys = keys
self.aws_conn_id = aws_conn_id
self.verify = verify
def execute(self, context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
response = s3_hook.delete_objects(bucket=self.bucket, keys=self.keys)
deleted_keys = [x['Key'] for x in response.get("Deleted", [])]
self.log.info("Deleted: %s", deleted_keys)
if "Errors" in response:
errors_keys = [x['Key'] for x in response.get("Errors", [])]
raise AirflowException("Errors when deleting: {}".format(errors_keys))
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`.",
DeprecationWarning, stacklevel=2
)


@ -16,84 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`."""
from typing import Iterable
import warnings
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator # noqa
class S3ListOperator(BaseOperator):
"""
List all objects from the bucket with the given string prefix in name.
This operator returns a python list with the name of objects which can be
used by `xcom` in the downstream task.
:param bucket: The S3 bucket where to find the objects. (templated)
:type bucket: str
:param prefix: Prefix string that filters the objects whose names begin with
this prefix. (templated)
:type prefix: str
:param delimiter: the delimiter marks key hierarchy. (templated)
:type delimiter: str
:param aws_conn_id: The connection ID to use when connecting to S3 storage.
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
**Example**:
The following operator would list all the files
(excluding subfolders) from the S3
``customers/2018/04/`` key in the ``data`` bucket. ::
s3_file = S3ListOperator(
task_id='list_s3_files',
bucket='data',
prefix='customers/2018/04/',
delimiter='/',
aws_conn_id='aws_customers_conn'
)
"""
template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str]
ui_color = '#ffd700'
@apply_defaults
def __init__(self,
bucket,
prefix='',
delimiter='',
aws_conn_id='aws_default',
verify=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.bucket = bucket
self.prefix = prefix
self.delimiter = delimiter
self.aws_conn_id = aws_conn_id
self.verify = verify
def execute(self, context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
self.log.info(
'Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)',
self.bucket, self.prefix, self.delimiter
)
return hook.list_keys(
bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`.",
DeprecationWarning, stacklevel=2
)


@ -19,10 +19,10 @@
import warnings
from tempfile import NamedTemporaryFile
from airflow.contrib.operators.s3_list_operator import S3ListOperator
from airflow.exceptions import AirflowException
from airflow.gcp.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator
from airflow.utils.decorators import apply_defaults


@ -16,86 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`."""
import json
from typing import Iterable
import warnings
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator # noqa
class SageMakerBaseOperator(BaseOperator):
"""
This is the base operator for all SageMaker operators.
:param config: The configuration necessary to start a training job (templated)
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""
template_fields = ['config']
template_ext = ()
ui_color = '#ededed'
integer_fields = [] # type: Iterable[Iterable[str]]
@apply_defaults
def __init__(self,
config,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
self.config = config
self.hook = None
def parse_integer(self, config, field):
if len(field) == 1:
if isinstance(config, list):
for sub_config in config:
self.parse_integer(sub_config, field)
return
head = field[0]
if head in config:
config[head] = int(config[head])
return
if isinstance(config, list):
for sub_config in config:
self.parse_integer(sub_config, field)
return
head, tail = field[0], field[1:]
if head in config:
self.parse_integer(config[head], tail)
return
def parse_config_integers(self):
# Parse the integer fields of training config to integers
# in case the config is rendered by Jinja and all fields are str
for field in self.integer_fields:
self.parse_integer(self.config, field)
def expand_role(self):
pass
def preprocess_config(self):
self.log.info(
'Preprocessing the config and doing required s3_operations'
)
self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.hook.configure_s3_resources(self.config)
self.parse_config_integers()
self.expand_role()
self.log.info(
'After preprocessing the config is:\n {}'.format(
json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')))
)
def execute(self, context):
raise NotImplementedError('Please implement execute() in sub class!')
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`.",
DeprecationWarning, stacklevel=2
)


@ -16,51 +16,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module is deprecated.
Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.
"""
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import ( # noqa
SageMakerEndpointConfigOperator,
)
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
"""
Create a SageMaker endpoint config.
This operator returns the ARN of the endpoint config created in Amazon SageMaker
:param config: The configuration necessary to create an endpoint config.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""
integer_fields = [
['ProductionVariants', 'InitialInstanceCount']
]
@apply_defaults
def __init__(self,
config,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
def execute(self, context):
self.preprocess_config()
self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
response = self.hook.create_endpoint_config(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(
'Sagemaker endpoint config creation failed: %s' % response)
else:
return {
'EndpointConfig': self.hook.describe_endpoint_config(
self.config['EndpointConfigName']
)
}
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.",
DeprecationWarning, stacklevel=2
)


@ -16,135 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator # noqa
class SageMakerEndpointOperator(SageMakerBaseOperator):
"""
Create a SageMaker endpoint.
This operator returns the ARN of the endpoint created in Amazon SageMaker
:param config:
The configuration necessary to create an endpoint.
If you need to create a SageMaker endpoint based on an existing
SageMaker model and an existing SageMaker endpoint config::
config = endpoint_configuration;
If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
config = {
'Model': model_configuration,
'EndpointConfig': endpoint_config_configuration,
'Endpoint': endpoint_configuration
}
For details of the configuration parameter of model_configuration see
:py:meth:`SageMaker.Client.create_model`
For details of the configuration parameter of endpoint_config_configuration see
:py:meth:`SageMaker.Client.create_endpoint_config`
For details of the configuration parameter of endpoint_configuration see
:py:meth:`SageMaker.Client.create_endpoint`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
waits before polling the status of the endpoint creation.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
:type max_ingestion_time: int
:param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
:type operation: str
"""
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
check_interval=30,
max_ingestion_time=None,
operation='create',
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.operation = operation.lower()
if self.operation not in ['create', 'update']:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
self.create_integer_fields()
def create_integer_fields(self):
if 'EndpointConfig' in self.config:
self.integer_fields = [
['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
]
def expand_role(self):
if 'Model' not in self.config:
return
hook = AwsHook(self.aws_conn_id)
config = self.config['Model']
if 'ExecutionRoleArn' in config:
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def execute(self, context):
self.preprocess_config()
model_info = self.config.get('Model')
endpoint_config_info = self.config.get('EndpointConfig')
endpoint_info = self.config.get('Endpoint', self.config)
if model_info:
self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
self.hook.create_model(model_info)
if endpoint_config_info:
self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
self.hook.create_endpoint_config(endpoint_config_info)
if self.operation == 'create':
sagemaker_operation = self.hook.create_endpoint
log_str = 'Creating'
elif self.operation == 'update':
sagemaker_operation = self.hook.update_endpoint
log_str = 'Updating'
else:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(
'Sagemaker endpoint creation failed: %s' % response)
else:
return {
'EndpointConfig': self.hook.describe_endpoint_config(
endpoint_info['EndpointConfigName']
),
'Endpoint': self.hook.describe_endpoint(
endpoint_info['EndpointName']
)
}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`.",
DeprecationWarning, stacklevel=2
)


@ -16,52 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator # noqa
class SageMakerModelOperator(SageMakerBaseOperator):
"""
Create a SageMaker model.
This operator returns the ARN of the model created in Amazon SageMaker
:param config: The configuration necessary to create a model.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""
@apply_defaults
def __init__(self,
config,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
def expand_role(self):
if 'ExecutionRoleArn' in self.config:
hook = AwsHook(self.aws_conn_id)
self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
def execute(self, context):
self.preprocess_config()
self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
response = self.hook.create_model(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker model creation failed: %s' % response)
else:
return {
'Model': self.hook.describe_model(
self.config['ModelName']
)
}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`.",
DeprecationWarning, stacklevel=2
)


@ -16,83 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator # noqa
class SageMakerTrainingOperator(SageMakerBaseOperator):
"""
Initiate a SageMaker training job.
This operator returns the ARN of the training job created in Amazon SageMaker.
:param config: The configuration necessary to start a training job (templated).
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Whether the operator should wait until the
training job finishes.
:type wait_for_completion: bool
:param print_log: if the operator should print the cloudwatch log during training
:type print_log: bool
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the training job
:type check_interval: int
:param max_ingestion_time: If wait is set to True, the operation fails if the training job
doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
the operation does not timeout.
:type max_ingestion_time: int
"""
integer_fields = [
['ResourceConfig', 'InstanceCount'],
['ResourceConfig', 'VolumeSizeInGB'],
['StoppingCondition', 'MaxRuntimeInSeconds']
]
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
print_log=True,
check_interval=30,
max_ingestion_time=None,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
def expand_role(self):
if 'RoleArn' in self.config:
hook = AwsHook(self.aws_conn_id)
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
def execute(self, context):
self.preprocess_config()
self.log.info('Creating SageMaker Training Job %s.', self.config['TrainingJobName'])
response = self.hook.create_training_job(
self.config,
wait_for_completion=self.wait_for_completion,
print_log=self.print_log,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Training Job creation failed: %s' % response)
else:
return {
'Training': self.hook.describe_training_job(
self.config['TrainingJobName']
)
}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`.",
DeprecationWarning, stacklevel=2
)


@ -16,109 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator # noqa
class SageMakerTransformOperator(SageMakerBaseOperator):
"""
Initiate a SageMaker transform job.
This operator returns the ARN of the transform job created in Amazon SageMaker.
:param config: The configuration necessary to start a transform job (templated).
If you need to create a SageMaker transform job based on an existing SageMaker model::
config = transform_config
If you need to create both SageMaker model and SageMaker Transform job::
config = {
'Model': model_config,
'Transform': transform_config
}
For details of the configuration parameter of transform_config see
:py:meth:`SageMaker.Client.create_transform_job`
For details of the configuration parameter of model_config, See:
:py:meth:`SageMaker.Client.create_model`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Set to True to wait until the transform job finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, the time interval, in seconds,
that this operation waits to check the status of the transform job.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, the operation fails
if the transform job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
:type max_ingestion_time: int
"""
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
check_interval=30,
max_ingestion_time=None,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.create_integer_fields()
def create_integer_fields(self):
self.integer_fields = [
['Transform', 'TransformResources', 'InstanceCount'],
['Transform', 'MaxConcurrentTransforms'],
['Transform', 'MaxPayloadInMB']
]
if 'Transform' not in self.config:
for field in self.integer_fields:
field.pop(0)
def expand_role(self):
if 'Model' not in self.config:
return
config = self.config['Model']
if 'ExecutionRoleArn' in config:
hook = AwsHook(self.aws_conn_id)
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def execute(self, context):
self.preprocess_config()
model_config = self.config.get('Model')
transform_config = self.config.get('Transform', self.config)
if model_config:
self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
self.hook.create_model(model_config)
self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
response = self.hook.create_transform_job(
transform_config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker transform Job creation failed: %s' % response)
else:
return {
'Model': self.hook.describe_model(
transform_config['ModelName']
),
'Transform': self.hook.describe_transform_job(
transform_config['TransformJobName']
)
}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`.",
DeprecationWarning, stacklevel=2
)


@ -16,84 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`."""
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator # noqa
class SageMakerTuningOperator(SageMakerBaseOperator):
"""
Initiate a SageMaker hyperparameter tuning job.
This operator returns the ARN of the tuning job created in Amazon SageMaker.
:param config: The configuration necessary to start a tuning job (templated).
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Set to True to wait until the tuning job finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, the time interval, in seconds,
that this operation waits to check the status of the tuning job.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, the operation fails
if the tuning job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
:type max_ingestion_time: int
"""
integer_fields = [
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds']
]
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
check_interval=30,
max_ingestion_time=None,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
def expand_role(self):
if 'TrainingJobDefinition' in self.config:
config = self.config['TrainingJobDefinition']
if 'RoleArn' in config:
hook = AwsHook(self.aws_conn_id)
config['RoleArn'] = hook.expand_role(config['RoleArn'])
def execute(self, context):
self.preprocess_config()
self.log.info(
'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
)
response = self.hook.create_tuning_job(
self.config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Tuning Job creation failed: %s' % response)
else:
return {
'Tuning': self.hook.describe_tuning_job(
self.config['HyperParameterTuningJobName']
)
}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`.",
DeprecationWarning, stacklevel=2
)


@ -16,78 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`."""
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor # noqa
class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
"""
Waits for a partition to show up in AWS Glue Catalog.
:param table_name: The name of the table to wait for, supports the dot
notation (my_database.my_table)
:type table_name: str
:param expression: The partition clause to wait for. This is passed as
is to the AWS Glue Catalog API's get_partitions function,
and supports SQL like notation as in ``ds='2015-01-01'
AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
#aws-glue-api-catalog-partitions-GetPartitions
:type expression: str
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:type aws_conn_id: str
:param region_name: Optional aws region name (example: us-east-1). Uses region from connection
if not specified.
:type region_name: str
:param database_name: The name of the catalog database where the partitions reside.
:type database_name: str
:param poke_interval: Time in seconds that the job should wait in
between each try
:type poke_interval: int
"""
template_fields = ('database_name', 'table_name', 'expression',)
ui_color = '#C5CAE9'
@apply_defaults
def __init__(self,
table_name, expression="ds='{{ ds }}'",
aws_conn_id='aws_default',
region_name=None,
database_name='default',
poke_interval=60 * 3,
*args,
**kwargs):
super().__init__(
poke_interval=poke_interval, *args, **kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.table_name = table_name
self.expression = expression
self.database_name = database_name
def poke(self, context):
"""
Checks for existence of the partition in the AWS Glue Catalog table
"""
if '.' in self.table_name:
self.database_name, self.table_name = self.table_name.split('.')
self.log.info(
'Poking for table %s. %s, expression %s', self.database_name, self.table_name, self.expression
)
return self.get_hook().check_for_partition(
self.database_name, self.table_name, self.expression)
def get_hook(self):
"""
Gets the AwsGlueCatalogHook
"""
if not hasattr(self, 'hook'):
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
self.hook = AwsGlueCatalogHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
return self.hook
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`.",
DeprecationWarning, stacklevel=2
)


@ -16,45 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`."""
import warnings
class EmrBaseSensor(BaseSensorOperator):
"""
Contains general sensor behavior for EMR.
Subclasses should implement get_emr_response() and state_from_response() methods.
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE constants.
"""
ui_color = '#66c3ff'
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor # noqa
@apply_defaults
def __init__(
self,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
def poke(self, context):
response = self.get_emr_response()
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
self.log.info('Bad HTTP response: %s', response)
return False
state = self.state_from_response(response)
self.log.info('Job flow currently %s', state)
if state in self.NON_TERMINAL_STATES:
return False
if state in self.FAILED_STATE:
final_message = 'EMR job failed'
failure_message = self.failure_message_from_response(response)
if failure_message:
final_message += ' ' + failure_message
raise AirflowException(final_message)
return True
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`.",
DeprecationWarning, stacklevel=2
)


@ -16,48 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`."""
import warnings
class EmrJobFlowSensor(EmrBaseSensor):
"""
Asks for the state of the JobFlow until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor # noqa
:param job_flow_id: job_flow_id to check the state of
:type job_flow_id: str
"""
NON_TERMINAL_STATES = ['STARTING', 'BOOTSTRAPPING', 'RUNNING',
'WAITING', 'TERMINATING']
FAILED_STATE = ['TERMINATED_WITH_ERRORS']
template_fields = ['job_flow_id']
template_ext = ()
@apply_defaults
def __init__(self,
job_flow_id,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_flow_id = job_flow_id
def get_emr_response(self):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Poking cluster %s', self.job_flow_id)
return emr.describe_cluster(ClusterId=self.job_flow_id)
@staticmethod
def state_from_response(response):
return response['Cluster']['Status']['State']
@staticmethod
def failure_message_from_response(response):
state_change_reason = response['Cluster']['Status'].get('StateChangeReason')
if state_change_reason:
return 'for code: {} with message {}'.format(state_change_reason.get('Code', 'No code'),
state_change_reason.get('Message', 'Unknown'))
return None
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`.",
DeprecationWarning, stacklevel=2
)


@ -16,52 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`."""
import warnings
class EmrStepSensor(EmrBaseSensor):
"""
Asks for the state of the step until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor # noqa
:param job_flow_id: job_flow_id which contains the step to check the state of
:type job_flow_id: str
:param step_id: step to check the state of
:type step_id: str
"""
NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE', 'CANCEL_PENDING']
FAILED_STATE = ['CANCELLED', 'FAILED', 'INTERRUPTED']
template_fields = ['job_flow_id', 'step_id']
template_ext = ()
@apply_defaults
def __init__(self,
job_flow_id,
step_id,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_flow_id = job_flow_id
self.step_id = step_id
def get_emr_response(self):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id)
return emr.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
@staticmethod
def state_from_response(response):
return response['Step']['Status']['State']
@staticmethod
def failure_message_from_response(response):
fail_details = response['Step']['Status'].get('FailureDetails')
if fail_details:
return 'for reason {} with message {} and log file {}'.format(fail_details.get('Reason'),
fail_details.get('Message'),
fail_details.get('LogFile'))
return None
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`.",
DeprecationWarning, stacklevel=2
)


@ -16,59 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`."""
import warnings
class SageMakerBaseSensor(BaseSensorOperator):
"""
Contains general sensor behavior for SageMaker.
Subclasses should implement get_sagemaker_response()
and state_from_response() methods.
Subclasses should also implement the non_terminal_states() and failed_states() methods.
"""
ui_color = '#ededed'
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor # noqa
@apply_defaults
def __init__(
self,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
def poke(self, context):
response = self.get_sagemaker_response()
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
self.log.info('Bad HTTP response: %s', response)
return False
state = self.state_from_response(response)
self.log.info('Job currently %s', state)
if state in self.non_terminal_states():
return False
if state in self.failed_states():
failed_reason = self.get_failed_reason_from_response(response)
raise AirflowException('Sagemaker job failed for the following reason: %s'
% failed_reason)
return True
def non_terminal_states(self):
raise NotImplementedError('Please implement non_terminal_states() in subclass')
def failed_states(self):
raise NotImplementedError('Please implement failed_states() in subclass')
def get_sagemaker_response(self):
raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
def get_failed_reason_from_response(self, response):
return 'Unknown'
def state_from_response(self, response):
raise NotImplementedError('Please implement state_from_response() in subclass')
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`.",
DeprecationWarning, stacklevel=2
)


@ -16,46 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`."""
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor # noqa
class SageMakerEndpointSensor(SageMakerBaseSensor):
"""
Asks for the state of the endpoint until it reaches a terminal state.
If it fails, the sensor errors and the task fails.
:param endpoint_name: Name of the endpoint instance to check the state of
:type endpoint_name: str
"""
template_fields = ['endpoint_name']
template_ext = ()
@apply_defaults
def __init__(self,
endpoint_name,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.endpoint_name = endpoint_name
def non_terminal_states(self):
return SageMakerHook.endpoint_non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
return sagemaker.describe_endpoint(self.endpoint_name)
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['EndpointStatus']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.",
DeprecationWarning, stacklevel=2
)


@ -16,87 +16,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_training`."""
import time
import warnings
from airflow.contrib.hooks.sagemaker_hook import LogState, SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.sagemaker_training import ( # noqa
SageMakerHook, SageMakerTrainingSensor,
)
class SageMakerTrainingSensor(SageMakerBaseSensor):
"""
Asks for the state of the training job until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
:param job_name: name of the SageMaker training job to check the state of
:type job_name: str
:param print_log: if the operator should print the cloudwatch log
:type print_log: bool
"""
template_fields = ['job_name']
template_ext = ()
@apply_defaults
def __init__(self,
job_name,
print_log=True,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_name = job_name
self.print_log = print_log
self.positions = {}
self.stream_names = []
self.instance_count = None
self.state = None
self.last_description = None
self.last_describe_job_call = None
self.log_resource_inited = False
def init_log_resource(self, hook):
description = hook.describe_training_job(self.job_name)
self.instance_count = description['ResourceConfig']['InstanceCount']
status = description['TrainingJobStatus']
job_already_completed = status not in self.non_terminal_states()
self.state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
self.last_description = description
self.last_describe_job_call = time.time()
self.log_resource_inited = True
def non_terminal_states(self):
return SageMakerHook.non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
if self.print_log:
if not self.log_resource_inited:
self.init_log_resource(sagemaker_hook)
self.state, self.last_description, self.last_describe_job_call = \
sagemaker_hook.describe_training_job_with_log(self.job_name,
self.positions, self.stream_names,
self.instance_count, self.state,
self.last_description,
self.last_describe_job_call)
else:
self.last_description = sagemaker_hook.describe_training_job(self.job_name)
status = self.state_from_response(self.last_description)
if status not in self.non_terminal_states() and status not in self.failed_states():
billable_time = \
(self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \
self.last_description['ResourceConfig']['InstanceCount']
self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)
return self.last_description
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['TrainingJobStatus']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_training`.",
DeprecationWarning, stacklevel=2
)


@ -16,47 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`."""
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor # noqa
class SageMakerTransformSensor(SageMakerBaseSensor):
"""
Asks for the state of the transform job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
containing the failure reason.
:param job_name: job_name of the transform job instance to check the state of
:type job_name: str
"""
template_fields = ['job_name']
template_ext = ()
@apply_defaults
def __init__(self,
job_name,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_name = job_name
def non_terminal_states(self):
return SageMakerHook.non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
return sagemaker.describe_transform_job(self.job_name)
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['TransformJobStatus']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`.",
DeprecationWarning, stacklevel=2
)


@ -16,47 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`."""
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor # noqa
class SageMakerTuningSensor(SageMakerBaseSensor):
"""
Asks for the state of the tuning job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
containing the failure reason.
:param job_name: job_name of the tuning instance to check the state of
:type job_name: str
"""
template_fields = ['job_name']
template_ext = ()
@apply_defaults
def __init__(self,
job_name,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_name = job_name
def non_terminal_states(self):
return SageMakerHook.non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
return sagemaker.describe_tuning_job(self.job_name)
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['HyperParameterTuningJobStatus']
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`.",
DeprecationWarning, stacklevel=2
)


@ -16,155 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`."""
import subprocess
import sys
from tempfile import NamedTemporaryFile
from typing import Optional, Union
import warnings
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator # noqa
class S3FileTransformOperator(BaseOperator):
"""
Copies data from a source S3 location to a temporary location on the
local filesystem. Runs a transformation on this file as specified by
the transformation script and uploads the output to a destination S3
location.
The locations of the source and the destination files in the local
filesystem are provided as the first and second arguments to the
transformation script.
data from source, transform it and write the output to the local
destination file. The operator then takes over control and uploads the
local destination file to S3.
S3 Select is also available to filter the source contents. Users can
omit the transformation script if an S3 Select expression is specified.
:param source_s3_key: The key to be retrieved from S3. (templated)
:type source_s3_key: str
:param dest_s3_key: The key to be written to S3. (templated)
:type dest_s3_key: str
:param transform_script: location of the executable transformation script
:type transform_script: str
:param select_expression: S3 Select expression
:type select_expression: str
:param source_aws_conn_id: source s3 connection
:type source_aws_conn_id: str
:param source_verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
This is also applicable to ``dest_verify``.
:type source_verify: bool or str
:param dest_aws_conn_id: destination s3 connection
:type dest_aws_conn_id: str
:param dest_verify: Whether or not to verify SSL certificates for S3 connection.
See: ``source_verify``
:type dest_verify: bool or str
:param replace: Replace dest S3 key if it already exists
:type replace: bool
"""
template_fields = ('source_s3_key', 'dest_s3_key')
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
source_s3_key: str,
dest_s3_key: str,
transform_script: Optional[str] = None,
select_expression=None,
source_aws_conn_id: str = 'aws_default',
source_verify: Optional[Union[bool, str]] = None,
dest_aws_conn_id: str = 'aws_default',
dest_verify: Optional[Union[bool, str]] = None,
replace: bool = False,
*args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.source_s3_key = source_s3_key
self.source_aws_conn_id = source_aws_conn_id
self.source_verify = source_verify
self.dest_s3_key = dest_s3_key
self.dest_aws_conn_id = dest_aws_conn_id
self.dest_verify = dest_verify
self.replace = replace
self.transform_script = transform_script
self.select_expression = select_expression
self.output_encoding = sys.getdefaultencoding()
def execute(self, context):
if self.transform_script is None and self.select_expression is None:
raise AirflowException(
"Either transform_script or select_expression must be specified")
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify)
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
self.log.info("Downloading source S3 file %s", self.source_s3_key)
if not source_s3.check_for_key(self.source_s3_key):
raise AirflowException(
"The source key {0} does not exist".format(self.source_s3_key))
source_s3_key_object = source_s3.get_key(self.source_s3_key)
with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest:
self.log.info(
"Dumping S3 file %s contents to local file %s",
self.source_s3_key, f_source.name
)
if self.select_expression is not None:
content = source_s3.select_key(
key=self.source_s3_key,
expression=self.select_expression
)
f_source.write(content.encode("utf-8"))
else:
source_s3_key_object.download_fileobj(Fileobj=f_source)
f_source.flush()
if self.transform_script is not None:
process = subprocess.Popen(
[self.transform_script, f_source.name, f_dest.name],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True
)
self.log.info("Output:")
for line in iter(process.stdout.readline, b''):
self.log.info(line.decode(self.output_encoding).rstrip())
process.wait()
if process.returncode > 0:
raise AirflowException(
"Transform script failed: {0}".format(process.returncode)
)
else:
self.log.info(
"Transform script successful. Output temporarily located at %s",
f_dest.name
)
self.log.info("Uploading transformed file to S3")
f_dest.flush()
dest_s3.load_file(
filename=f_dest.name,
key=self.dest_s3_key,
replace=self.replace
)
self.log.info("Upload successful")
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`.",
DeprecationWarning, stacklevel=2
)


@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
class EmrHook(AwsHook):
"""
Interact with AWS EMR. emr_conn_id is only necessary for using the
create_job_flow method.
"""
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
self.emr_conn_id = emr_conn_id
self.region_name = region_name
self.conn = None
super().__init__(*args, **kwargs)
def get_conn(self):
if not self.conn:
self.conn = self.get_client_type('emr', self.region_name)
return self.conn
def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
conn = self.get_conn()
response = conn.list_clusters(
ClusterStates=cluster_states
)
matching_clusters = list(
filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
)
if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]['Id']
self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException('More than one cluster found for name %s' % emr_cluster_name)
else:
self.log.info('No cluster found for name %s', emr_cluster_name)
return None
def create_job_flow(self, job_flow_overrides):
"""
Creates a job flow using the config from the EMR connection.
The keys of the connection's extra JSON may supply arguments to the boto3
run_job_flow method.
Overrides for this config may be passed as the job_flow_overrides.
"""
if not self.emr_conn_id:
raise AirflowException('emr_conn_id must be present to use create_job_flow')
emr_conn = self.get_connection(self.emr_conn_id)
config = emr_conn.extra_dejson.copy()
config.update(job_flow_overrides)
response = self.get_conn().run_job_flow(**config)
return response
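For reference, a hedged sketch of how this hook might be used once importable from its providers location (the module path ``airflow.providers.amazon.aws.hooks.emr``, the connection ids and the cluster names are assumptions)::

    from airflow.providers.amazon.aws.hooks.emr import EmrHook

    emr_hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')

    # Look up a running cluster by name; returns None if no match is found
    cluster_id = emr_hook.get_cluster_id_by_name('my-cluster', ['RUNNING', 'WAITING'])

    # Start a job flow from the emr_default connection extras, overriding only the name
    response = emr_hook.create_job_flow({'Name': 'ad-hoc-cluster'})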


@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains AWS Glue Catalog Hook
"""
from airflow.contrib.hooks.aws_hook import AwsHook
class AwsGlueCatalogHook(AwsHook):
"""
Interact with AWS Glue Catalog
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:type aws_conn_id: str
:param region_name: aws region name (example: us-east-1)
:type region_name: str
"""
def __init__(self,
aws_conn_id='aws_default',
region_name=None,
*args,
**kwargs):
self.region_name = region_name
self.conn = None
super().__init__(aws_conn_id=aws_conn_id, *args, **kwargs)
def get_conn(self):
"""
Returns glue connection object.
"""
self.conn = self.get_client_type('glue', self.region_name)
return self.conn
def get_partitions(self,
database_name,
table_name,
expression='',
page_size=None,
max_items=None):
"""
Retrieves the partition values for a table.
:param database_name: The name of the catalog database where the partitions reside.
:type database_name: str
:param table_name: The name of the partitions' table.
:type table_name: str
:param expression: An expression filtering the partitions to be returned.
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
:type expression: str
:param page_size: pagination size
:type page_size: int
:param max_items: maximum items to return
:type max_items: int
:return: set of partition values where each value is a tuple since
a partition may be composed of multiple columns. For example:
``{('2018-01-01','1'), ('2018-01-01','2')}``
"""
config = {
'PageSize': page_size,
'MaxItems': max_items,
}
paginator = self.get_conn().get_paginator('get_partitions')
response = paginator.paginate(
DatabaseName=database_name,
TableName=table_name,
Expression=expression,
PaginationConfig=config
)
partitions = set()
for page in response:
for partition in page['Partitions']:
partitions.add(tuple(partition['Values']))
return partitions
def check_for_partition(self, database_name, table_name, expression):
"""
Checks whether a partition exists
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table @partition belongs to
:type table_name: str
:expression: Expression that matches the partitions to check for
(eg `a = 'b' AND c = 'd'`)
:type expression: str
:rtype: bool
>>> hook = AwsGlueCatalogHook()
>>> t = 'static_babynames_partitioned'
>>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
True
"""
partitions = self.get_partitions(database_name, table_name, expression, max_items=1)
return bool(partitions)
def get_table(self, database_name, table_name):
"""
Get the information of the table
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table
:type table_name: str
:rtype: dict
>>> hook = AwsGlueCatalogHook()
>>> r = hook.get_table('db', 'table_foo')
>>> r['Name'] = 'table_foo'
"""
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name)
return result['Table']
def get_table_location(self, database_name, table_name):
"""
Get the physical location of the table
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table
:type table_name: str
:return: str
"""
table = self.get_table(database_name, table_name)
return table['StorageDescriptor']['Location']
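A short sketch of calling the catalog hook above directly; the database, table and partition expression are placeholders, and the import path is assumed to follow the new providers layout.
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook  # assumed module path

hook = AwsGlueCatalogHook(aws_conn_id='aws_default', region_name='us-east-1')

# up to 10 partition value tuples matching the expression
partitions = hook.get_partitions('my_database', 'my_table',
                                 expression="ds='2020-01-01'", max_items=10)

# True when at least one partition matches
exists = hook.check_for_partition('my_database', 'my_table', "ds='2020-01-01'")

# physical (S3) location backing the table
location = hook.get_table_location('my_database', 'my_table')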


@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains a hook (AwsLogsHook) with some very basic
functionality for interacting with AWS CloudWatch.
"""
from airflow.contrib.hooks.aws_hook import AwsHook
class AwsLogsHook(AwsHook):
"""
Interact with AWS CloudWatch Logs
:param region_name: AWS Region Name (example: us-west-2)
:type region_name: str
"""
def __init__(self, region_name=None, *args, **kwargs):
self.region_name = region_name
super().__init__(*args, **kwargs)
def get_conn(self):
"""
Establish an AWS connection for retrieving logs.
:rtype: CloudWatchLogs.Client
"""
return self.get_client_type('logs', region_name=self.region_name)
def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start_from_head=True):
"""
A generator for log items in a single stream. This will yield all the
items that are available at the current moment.
:param log_group: The name of the log group.
:type log_group: str
:param log_stream_name: The name of the specific stream.
:type log_stream_name: str
:param start_time: The time stamp value to start reading the logs from (default: 0).
:type start_time: int
:param skip: The number of log entries to skip at the start (default: 0).
This is for when there are multiple entries at the same timestamp.
:type skip: int
:param start_from_head: whether to start from the beginning (True) of the log or
at the end of the log (False).
:type start_from_head: bool
:rtype: dict
:return: | A CloudWatch log event with the following key-value pairs:
| 'timestamp' (int): The time in milliseconds of the event.
| 'message' (str): The log event data.
| 'ingestionTime' (int): The time in milliseconds the event was ingested.
"""
next_token = None
event_count = 1
while event_count > 0:
if next_token is not None:
token_arg = {'nextToken': next_token}
else:
token_arg = {}
response = self.get_conn().get_log_events(logGroupName=log_group,
logStreamName=log_stream_name,
startTime=start_time,
startFromHead=start_from_head,
**token_arg)
events = response['events']
event_count = len(events)
if event_count > skip:
events = events[skip:]
skip = 0
else:
skip = skip - event_count
events = []
yield from events
if 'nextForwardToken' in response:
next_token = response['nextForwardToken']
else:
return
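A minimal sketch of consuming the generator above; the log group and stream names are placeholders.
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

hook = AwsLogsHook(aws_conn_id='aws_default', region_name='us-west-2')

# yields events in order until the stream is exhausted at the moment of the call
for event in hook.get_log_events('/my/log-group', 'my-stream', start_time=0, skip=0):
    print(event['timestamp'], event['message'])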


@ -0,0 +1,741 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import collections
import os
import tarfile
import tempfile
import time
import warnings
from botocore.exceptions import ClientError
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils import timezone
class LogState:
STARTING = 1
WAIT_IN_PROGRESS = 2
TAILING = 3
JOB_COMPLETE = 4
COMPLETE = 5
# Position is a tuple that includes the last read timestamp and the number of items that were read
# at that time. This is used to figure out which event to start with on the next read.
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
def argmin(arr, f):
"""Return the index, i, in arr that minimizes f(arr[i])"""
m = None
i = None
for idx, item in enumerate(arr):
if item is not None:
if m is None or f(item) < m:
m = f(item)
i = idx
return i
def secondary_training_status_changed(current_job_description, prev_job_description):
"""
Returns True if the training job's secondary status message has changed.
:param current_job_description: Current job description, returned from DescribeTrainingJob call.
:type current_job_description: dict
:param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
:type prev_job_description: dict
:return: Whether the secondary status message of a training job changed or not.
"""
current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
return False
prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \
if prev_job_description is not None else None
last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \
if prev_job_secondary_status_transitions is not None \
and len(prev_job_secondary_status_transitions) > 0 else ''
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
return message != last_message
def secondary_training_status_message(job_description, prev_description):
"""
Returns a string containing the start time and the secondary training job status message.
:param job_description: Returned response from DescribeTrainingJob call
:type job_description: dict
:param prev_description: Previous job description from DescribeTrainingJob call
:type prev_description: dict
:return: Job status string to be printed.
"""
if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
or len(job_description.get('SecondaryStatusTransitions')) == 0:
return ''
prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
if prev_description is not None else None
prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
if prev_description_secondary_transitions is not None else 0
current_transitions = job_description['SecondaryStatusTransitions']
transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
current_transitions[prev_transitions_num - len(current_transitions):]
status_strs = []
for transition in transitions_to_print:
message = transition['StatusMessage']
time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))
return '\n'.join(status_strs)
class SageMakerHook(AwsHook):
"""
Interact with Amazon SageMaker.
"""
non_terminal_states = {'InProgress', 'Stopping'}
endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating',
'RollingBack', 'Deleting'}
failed_states = {'Failed'}
def __init__(self,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
def tar_and_s3_upload(self, path, key, bucket):
"""
Tar the local file or directory and upload to s3
:param path: local file or directory
:type path: str
:param key: s3 key
:type key: str
:param bucket: s3 bucket
:type bucket: str
:return: None
"""
with tempfile.TemporaryFile() as temp_file:
if os.path.isdir(path):
files = [os.path.join(path, name) for name in os.listdir(path)]
else:
files = [path]
with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file:
for f in files:
tar_file.add(f, arcname=os.path.basename(f))
temp_file.seek(0)
self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
def configure_s3_resources(self, config):
"""
Extract the S3 operations from the configuration and execute them.
:param config: config of SageMaker operation
:type config: dict
:rtype: dict
"""
s3_operations = config.pop('S3Operations', None)
if s3_operations is not None:
create_bucket_ops = s3_operations.get('S3CreateBucket', [])
upload_ops = s3_operations.get('S3Upload', [])
for op in create_bucket_ops:
self.s3_hook.create_bucket(bucket_name=op['Bucket'])
for op in upload_ops:
if op['Tar']:
self.tar_and_s3_upload(op['Path'], op['Key'],
op['Bucket'])
else:
self.s3_hook.load_file(op['Path'], op['Key'],
op['Bucket'])
def check_s3_url(self, s3url):
"""
Check if an S3 URL exists
:param s3url: S3 url
:type s3url: str
:rtype: bool
"""
bucket, key = S3Hook.parse_s3_url(s3url)
if not self.s3_hook.check_for_bucket(bucket_name=bucket):
raise AirflowException(
"The input S3 Bucket {} does not exist ".format(bucket))
if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
and not self.s3_hook.check_for_prefix(
prefix=key, bucket_name=bucket, delimiter='/'):
# check if s3 key exists in the case user provides a single file
# or if s3 prefix exists in the case user provides multiple files in
# a prefix
raise AirflowException("The input S3 Key "
"or Prefix {} does not exist in the Bucket {}"
.format(s3url, bucket))
return True
def check_training_config(self, training_config):
"""
Check if a training configuration is valid
:param training_config: training_config
:type training_config: dict
:return: None
"""
if "InputDataConfig" in training_config:
for channel in training_config['InputDataConfig']:
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
def check_tuning_config(self, tuning_config):
"""
Check if a tuning configuration is valid
:param tuning_config: tuning_config
:type tuning_config: dict
:return: None
"""
for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
def get_conn(self):
"""
Establish an AWS connection for SageMaker
:rtype: :py:class:`SageMaker.Client`
"""
return self.get_client_type('sagemaker')
def get_log_conn(self):
"""
This method is deprecated.
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_conn` instead.
"""
warnings.warn("Method `get_log_conn` has been deprecated. "
"Please use `airflow.contrib.hooks.AwsLogsHook.get_conn` instead.",
category=DeprecationWarning,
stacklevel=2)
return self.logs_hook.get_conn()
def log_stream(self, log_group, stream_name, start_time=0, skip=0):
"""
This method is deprecated.
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.
"""
warnings.warn("Method `log_stream` has been deprecated. "
"Please use `airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.",
category=DeprecationWarning,
stacklevel=2)
return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)
def multi_stream_iter(self, log_group, streams, positions=None):
"""
Iterate over the available events coming from a set of log streams in a single log group
interleaving the events from each stream so they're yielded in timestamp order.
:param log_group: The name of the log group.
:type log_group: str
:param streams: A list of the log stream names. The position of the stream in this list is
the stream number.
:type streams: list
:param positions: A list of pairs of (timestamp, skip) which represents the last record
read from each stream.
:type positions: list
:return: A tuple of (stream number, cloudwatch log event).
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
event_iters = [self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
for s in streams]
events = []
for s in event_iters:
if not s:
events.append(None)
continue
try:
events.append(next(s))
except StopIteration:
events.append(None)
while any(events):
i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
yield (i, events[i])
try:
events[i] = next(event_iters[i])
except StopIteration:
events[i] = None
def create_training_job(self, config, wait_for_completion=True, print_log=True,
check_interval=30, max_ingestion_time=None):
"""
Create a training job
:param config: the config for training
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to training job creation
"""
self.check_training_config(config)
response = self.get_conn().create_training_job(**config)
if print_log:
self.check_training_status_with_log(config['TrainingJobName'],
self.non_terminal_states,
self.failed_states,
wait_for_completion,
check_interval, max_ingestion_time
)
elif wait_for_completion:
describe_response = self.check_status(config['TrainingJobName'],
'TrainingJobStatus',
self.describe_training_job,
check_interval, max_ingestion_time
)
billable_time = \
(describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \
describe_response['ResourceConfig']['InstanceCount']
self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
return response
def create_tuning_job(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Create a tuning job
:param config: the config for tuning
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to tuning job creation
"""
self.check_tuning_config(config)
response = self.get_conn().create_hyper_parameter_tuning_job(**config)
if wait_for_completion:
self.check_status(config['HyperParameterTuningJobName'],
'HyperParameterTuningJobStatus',
self.describe_tuning_job,
check_interval, max_ingestion_time
)
return response
def create_transform_job(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Create a transform job
:param config: the config for transform job
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to transform job creation
"""
self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
response = self.get_conn().create_transform_job(**config)
if wait_for_completion:
self.check_status(config['TransformJobName'],
'TransformJobStatus',
self.describe_transform_job,
check_interval, max_ingestion_time
)
return response
def create_model(self, config):
"""
Create a model job
:param config: the config for model
:type config: dict
:return: A response to model creation
"""
return self.get_conn().create_model(**config)
def create_endpoint_config(self, config):
"""
Create an endpoint config
:param config: the config for endpoint-config
:type config: dict
:return: A response to endpoint config creation
"""
return self.get_conn().create_endpoint_config(**config)
def create_endpoint(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Create an endpoint
:param config: the config for endpoint
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to endpoint creation
"""
response = self.get_conn().create_endpoint(**config)
if wait_for_completion:
self.check_status(config['EndpointName'],
'EndpointStatus',
self.describe_endpoint,
check_interval, max_ingestion_time,
non_terminal_states=self.endpoint_non_terminal_states
)
return response
def update_endpoint(self, config, wait_for_completion=True,
check_interval=30, max_ingestion_time=None):
"""
Update an endpoint
:param config: the config for endpoint
:type config: dict
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to endpoint update
"""
response = self.get_conn().update_endpoint(**config)
if wait_for_completion:
self.check_status(config['EndpointName'],
'EndpointStatus',
self.describe_endpoint,
check_interval, max_ingestion_time,
non_terminal_states=self.endpoint_non_terminal_states
)
return response
def describe_training_job(self, name):
"""
Return the training job info associated with the name
:param name: the name of the training job
:type name: str
:return: A dict containing all the training job info
"""
return self.get_conn().describe_training_job(TrainingJobName=name)
def describe_training_job_with_log(self, job_name, positions, stream_names,
instance_count, state, last_description,
last_describe_job_call):
"""
Return the training job info associated with job_name and print CloudWatch logs
"""
log_group = '/aws/sagemaker/TrainingJobs'
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
# may be dynamic until we have a stream for every instance.
logs_conn = self.logs_hook.get_conn()
try:
streams = logs_conn.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + '/',
orderBy='LogStreamName',
limit=instance_count
)
stream_names = [s['logStreamName'] for s in streams['logStreams']]
positions.update([(s, Position(timestamp=0, skip=0))
for s in stream_names if s not in positions])
except logs_conn.exceptions.ResourceNotFoundException:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
pass
if len(stream_names) > 0:
for idx, event in self.multi_stream_iter(log_group, stream_names, positions):
self.log.info(event['message'])
ts, count = positions[stream_names[idx]]
if event['timestamp'] == ts:
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1)
if state == LogState.COMPLETE:
return state, last_description, last_describe_job_call
if state == LogState.JOB_COMPLETE:
state = LogState.COMPLETE
elif time.time() - last_describe_job_call >= 30:
description = self.describe_training_job(job_name)
last_describe_job_call = time.time()
if secondary_training_status_changed(description, last_description):
self.log.info(secondary_training_status_message(description, last_description))
last_description = description
status = description['TrainingJobStatus']
if status not in self.non_terminal_states:
state = LogState.JOB_COMPLETE
return state, last_description, last_describe_job_call
def describe_tuning_job(self, name):
"""
Return the tuning job info associated with the name
:param name: the name of the tuning job
:type name: str
:return: A dict containing all the tuning job info
"""
return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
def describe_model(self, name):
"""
Return the SageMaker model info associated with the name
:param name: the name of the SageMaker model
:type name: str
:return: A dict containing all the model info
"""
return self.get_conn().describe_model(ModelName=name)
def describe_transform_job(self, name):
"""
Return the transform job info associated with the name
:param name: the name of the transform job
:type name: str
:return: A dict containing all the transform job info
"""
return self.get_conn().describe_transform_job(TransformJobName=name)
def describe_endpoint_config(self, name):
"""
Return the endpoint config info associated with the name
:param name: the name of the endpoint config
:type name: str
:return: A dict containing all the endpoint config info
"""
return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
def describe_endpoint(self, name):
"""
:param name: the name of the endpoint
:type name: str
:return: A dict containing all the endpoint info
"""
return self.get_conn().describe_endpoint(EndpointName=name)
def check_status(self, job_name, key,
describe_function, check_interval,
max_ingestion_time,
non_terminal_states=None):
"""
Check status of a SageMaker job
:param job_name: name of the job to check status
:type job_name: str
:param key: the key of the response dict
that points to the state
:type key: str
:param describe_function: the function used to retrieve the status
:type describe_function: python callable
:param args: the arguments for the function
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:param non_terminal_states: the set of nonterminal states
:type non_terminal_states: set
:return: response of describe call after job is done
"""
if not non_terminal_states:
non_terminal_states = self.non_terminal_states
sec = 0
running = True
while running:
time.sleep(check_interval)
sec = sec + check_interval
try:
response = describe_function(job_name)
status = response[key]
self.log.info('Job still running for %s seconds... '
'current status is %s' % (sec, status))
except KeyError:
raise AirflowException('Could not get status of the SageMaker job')
except ClientError:
raise AirflowException('AWS request failed, check logs for more info')
if status in non_terminal_states:
running = True
elif status in self.failed_states:
raise AirflowException('SageMaker job failed because %s' % response['FailureReason'])
else:
running = False
if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)
self.log.info('SageMaker Job Completed')
response = describe_function(job_name)
return response
def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
wait_for_completion, check_interval, max_ingestion_time):
"""
Display the logs for a given training job, optionally tailing them until the
job is complete.
:param job_name: name of the training job to check status and display logs for
:type job_name: str
:param non_terminal_states: the set of non_terminal states
:type non_terminal_states: set
:param failed_states: the set of failed states
:type failed_states: set
:param wait_for_completion: Whether to keep looking for new log entries
until the job completes
:type wait_for_completion: bool
:param check_interval: The interval in seconds between polling for new log entries and job completion
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: None
"""
sec = 0
description = self.describe_training_job(job_name)
self.log.info(secondary_training_status_message(description, None))
instance_count = description['ResourceConfig']['InstanceCount']
status = description['TrainingJobStatus']
stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position
job_already_completed = status not in non_terminal_states
state = LogState.TAILING if wait_for_completion and not job_already_completed else LogState.COMPLETE
# The loop below implements a state machine that alternates between checking the job status and
# reading whatever is available in the logs at this point. Note, that if we were called with
# wait_for_completion == False, we never check the job status.
#
# If wait_for_completion == TRUE and job is not completed, the initial state is TAILING
# If wait_for_completion == FALSE, the initial state is COMPLETE
# (doesn't matter if the job really is complete).
#
# The state table:
#
# STATE ACTIONS CONDITION NEW STATE
# ---------------- ---------------- ----------------- ----------------
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
# Else TAILING
# JOB_COMPLETE Read logs, Pause Any COMPLETE
# COMPLETE Read logs, Exit N/A
#
# Notes:
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that
# got to Cloudwatch after the job was marked complete.
last_describe_job_call = time.time()
last_description = description
while True:
time.sleep(check_interval)
sec = sec + check_interval
state, last_description, last_describe_job_call = \
self.describe_training_job_with_log(job_name, positions, stream_names,
instance_count, state, last_description,
last_describe_job_call)
if state == LogState.COMPLETE:
break
if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)
if wait_for_completion:
status = last_description['TrainingJobStatus']
if status in failed_states:
reason = last_description.get('FailureReason', '(No reason provided)')
raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
* instance_count
self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
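A compressed sketch of driving the hook above directly; the training config below is only a skeleton of the boto3 create_training_job payload and will not run against AWS as-is, and the import path is assumed to follow the new providers layout.
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook  # assumed module path

hook = SageMakerHook(aws_conn_id='aws_default')

training_config = {
    'TrainingJobName': 'example-training-job',
    # AlgorithmSpecification, RoleArn, InputDataConfig, OutputDataConfig,
    # ResourceConfig and StoppingCondition go here, exactly as boto3 expects them.
}

# validates any S3 inputs, starts the job and tails its CloudWatch logs until it finishes
response = hook.create_training_job(training_config, wait_for_completion=True,
                                    print_log=True, check_interval=30)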


@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import sys
from datetime import datetime
from typing import Optional
from typing_extensions import runtime_checkable
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.typing_compat import Protocol
from airflow.utils.decorators import apply_defaults
@runtime_checkable
class ECSProtocol(Protocol):
def run_task(self, **kwargs):
...
def get_waiter(self, x: str):
...
def describe_tasks(self, cluster, tasks):
...
def stop_task(self, cluster, task, reason: str):
...
class ECSOperator(BaseOperator):
"""
Execute a task on AWS EC2 Container Service
:param task_definition: the task definition name on EC2 Container Service
:type task_definition: str
:param cluster: the cluster name on EC2 Container Service
:type cluster: str
:param overrides: the same parameter that boto3 will receive (templated):
http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
:type overrides: dict
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
:type aws_conn_id: str
:param region_name: region name to use in AWS Hook.
Override the region_name in connection (if provided)
:type region_name: str
:param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
:type launch_type: str
:param group: the name of the task group associated with the task
:type group: str
:param placement_constraints: an array of placement constraint objects to use for
the task
:type placement_constraints: list
:param platform_version: the platform version on which your task is running
:type platform_version: str
:param network_configuration: the network configuration for the task
:type network_configuration: dict
:param tags: a dictionary of tags in the form of {'tagKey': 'tagValue'}.
:type tags: dict
:param awslogs_group: the CloudWatch group where your ECS container logs are stored.
Only required if you want logs to be shown in the Airflow UI after your job has
finished.
:type awslogs_group: str
:param awslogs_region: the region in which your CloudWatch logs are stored.
If None, this is the same as the `region_name` parameter. If that is also None,
this is the default AWS region based on your connection settings.
:type awslogs_region: str
:param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs.
This is usually based on some custom name combined with the name of the container.
Only required if you want logs to be shown in the Airflow UI after your job has
finished.
:type awslogs_stream_prefix: str
"""
ui_color = '#f0ede4'
client = None # type: Optional[ECSProtocol]
arn = None # type: Optional[str]
template_fields = ('overrides',)
@apply_defaults
def __init__(self, task_definition, cluster, overrides,
aws_conn_id=None, region_name=None, launch_type='EC2',
group=None, placement_constraints=None, platform_version='LATEST',
network_configuration=None, tags=None, awslogs_group=None,
awslogs_region=None, awslogs_stream_prefix=None, **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.task_definition = task_definition
self.cluster = cluster
self.overrides = overrides
self.launch_type = launch_type
self.group = group
self.placement_constraints = placement_constraints
self.platform_version = platform_version
self.network_configuration = network_configuration
self.tags = tags
self.awslogs_group = awslogs_group
self.awslogs_stream_prefix = awslogs_stream_prefix
self.awslogs_region = awslogs_region
if self.awslogs_region is None:
self.awslogs_region = region_name
self.hook = self.get_hook()
def execute(self, context):
self.log.info(
'Running ECS Task - Task definition: %s - on cluster %s',
self.task_definition, self.cluster
)
self.log.info('ECSOperator overrides: %s', self.overrides)
self.client = self.hook.get_client_type(
'ecs',
region_name=self.region_name
)
run_opts = {
'cluster': self.cluster,
'taskDefinition': self.task_definition,
'overrides': self.overrides,
'startedBy': self.owner,
'launchType': self.launch_type,
}
if self.launch_type == 'FARGATE':
run_opts['platformVersion'] = self.platform_version
if self.group is not None:
run_opts['group'] = self.group
if self.placement_constraints is not None:
run_opts['placementConstraints'] = self.placement_constraints
if self.network_configuration is not None:
run_opts['networkConfiguration'] = self.network_configuration
if self.tags is not None:
run_opts['tags'] = [{'key': k, 'value': v} for (k, v) in self.tags.items()]
response = self.client.run_task(**run_opts)
failures = response['failures']
if len(failures) > 0:
raise AirflowException(response)
self.log.info('ECS Task started: %s', response)
self.arn = response['tasks'][0]['taskArn']
self._wait_for_task_ended()
self._check_success_task()
self.log.info('ECS Task has been successfully executed: %s', response)
def _wait_for_task_ended(self):
waiter = self.client.get_waiter('tasks_stopped')
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(
cluster=self.cluster,
tasks=[self.arn]
)
def _check_success_task(self):
response = self.client.describe_tasks(
cluster=self.cluster,
tasks=[self.arn]
)
self.log.info('ECS Task stopped, check status: %s', response)
# Get logs from CloudWatch if the awslogs log driver was used
if self.awslogs_group and self.awslogs_stream_prefix:
self.log.info('ECS Task logs output:')
task_id = self.arn.split("/")[-1]
stream_name = "{}/{}".format(self.awslogs_stream_prefix, task_id)
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
self.log.info("[{}] {}".format(dt.isoformat(), event['message']))
if len(response.get('failures', [])) > 0:
raise AirflowException(response)
for task in response['tasks']:
# This is a `stoppedReason` that indicates a task has not
# successfully finished, but there is no other indication of failure
# in the response.
# See, https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html # noqa E501
if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.',
task.get('stoppedReason', '')):
raise AirflowException(
'The task was stopped because the host instance terminated: {}'.
format(task.get('stoppedReason', '')))
containers = task['containers']
for container in containers:
if container.get('lastStatus') == 'STOPPED' and \
container['exitCode'] != 0:
raise AirflowException(
'This task is not in success state {}'.format(task))
elif container.get('lastStatus') == 'PENDING':
raise AirflowException('This task is still pending {}'.format(task))
elif 'error' in container.get('reason', '').lower():
raise AirflowException(
'This container encountered an error during launch: {}'.
format(container.get('reason', '').lower()))
def get_hook(self):
return AwsHook(
aws_conn_id=self.aws_conn_id
)
def get_logs_hook(self):
return AwsLogsHook(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region
)
def on_kill(self):
response = self.client.stop_task(
cluster=self.cluster,
task=self.arn,
reason='Task killed by the user')
self.log.info(response)
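A minimal sketch of wiring the operator above into a DAG; the cluster, task definition, subnet and log names are placeholders, and the import path is assumed to follow the new providers layout.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import ECSOperator  # assumed module path

with DAG('ecs_example', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    run_task = ECSOperator(
        task_id='run_hello_world',
        task_definition='hello-world',      # an existing ECS task definition
        cluster='my-ecs-cluster',
        overrides={'containerOverrides': []},
        launch_type='FARGATE',
        platform_version='LATEST',
        network_configuration={'awsvpcConfiguration': {'subnets': ['subnet-abc123']}},
        awslogs_group='/ecs/hello-world',           # optional: surfaces container logs in the task log
        awslogs_stream_prefix='ecs/hello-world',    # prefix + container name, per the awslogs driver
    )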


@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.utils.decorators import apply_defaults
class EmrAddStepsOperator(BaseOperator):
"""
An operator that adds steps to an existing EMR job_flow.
:param job_flow_id: id of the JobFlow to add steps to. (templated)
:type job_flow_id: Optional[str]
:param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing
job_flow_id. The operator searches for the id of a JobFlow whose name matches, restricted to
the states given in cluster_states; exactly one such cluster must exist or the task fails. (templated)
:type job_flow_name: Optional[str]
:param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
(templated)
:type cluster_states: list
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
:param steps: boto3 style steps to be added to the jobflow. (templated)
:type steps: list
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
:type do_xcom_push: bool
"""
template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
job_flow_id=None,
job_flow_name=None,
cluster_states=None,
aws_conn_id='aws_default',
steps=None,
*args, **kwargs):
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
if not ((job_flow_id is None) ^ (job_flow_name is None)):
raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
super().__init__(*args, **kwargs)
steps = steps or []
self.aws_conn_id = aws_conn_id
self.job_flow_id = job_flow_id
self.job_flow_name = job_flow_name
self.cluster_states = cluster_states
self.steps = steps
def execute(self, context):
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
emr = emr_hook.get_conn()
job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
self.cluster_states)
if not job_flow_id:
raise AirflowException(f'No cluster found for name: {self.job_flow_name}')
if self.do_xcom_push:
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
self.log.info('Adding steps to %s', job_flow_id)
response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=self.steps)
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('Adding steps failed: %s' % response)
else:
self.log.info('Steps %s added to JobFlow', response['StepIds'])
return response['StepIds']
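A short sketch of the operator above looking up a cluster by name; the step definition, names and states are placeholders, the snippet is assumed to sit inside a `with DAG(...)` block, and the import path is assumed to follow the new providers layout.
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator  # assumed module path

SPARK_STEP = {
    'Name': 'example_step',
    'ActionOnFailure': 'CONTINUE',
    'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['spark-submit', 'example.py']},
}

add_steps = EmrAddStepsOperator(
    task_id='add_steps',
    job_flow_name='example-cluster',            # alternative to passing job_flow_id
    cluster_states=['RUNNING', 'WAITING'],
    steps=[SPARK_STEP],
    aws_conn_id='aws_default',
)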


@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.utils.decorators import apply_defaults
class EmrCreateJobFlowOperator(BaseOperator):
"""
Creates an EMR JobFlow, reading the config from the EMR connection.
A dictionary of JobFlow overrides can be passed that override
the config from the connection.
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
:param emr_conn_id: emr connection to use
:type emr_conn_id: str
:param job_flow_overrides: boto3 style arguments to override
emr_connection extra. (templated)
:type job_flow_overrides: dict
"""
template_fields = ['job_flow_overrides']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
aws_conn_id='aws_default',
emr_conn_id='emr_default',
job_flow_overrides=None,
region_name=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
if job_flow_overrides is None:
job_flow_overrides = {}
self.job_flow_overrides = job_flow_overrides
self.region_name = region_name
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id,
emr_conn_id=self.emr_conn_id,
region_name=self.region_name)
self.log.info(
'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',
self.aws_conn_id, self.emr_conn_id
)
response = emr.create_job_flow(self.job_flow_overrides)
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('JobFlow creation failed: %s' % response)
else:
self.log.info('JobFlow with id %s created', response['JobFlowId'])
return response['JobFlowId']
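A minimal sketch of the operator above; the overrides are placeholders merged over whatever is stored in the emr_default connection extras, and the import path is assumed to follow the new providers layout.
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator  # assumed module path

create_job_flow = EmrCreateJobFlowOperator(
    task_id='create_job_flow',
    aws_conn_id='aws_default',
    emr_conn_id='emr_default',
    job_flow_overrides={'Name': 'example-cluster'},  # templated; takes precedence over the connection config
)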


@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.utils.decorators import apply_defaults
class EmrTerminateJobFlowOperator(BaseOperator):
"""
Operator to terminate EMR JobFlows.
:param job_flow_id: id of the JobFlow to terminate. (templated)
:type job_flow_id: str
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
"""
template_fields = ['job_flow_id']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
job_flow_id,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Terminating JobFlow %s', self.job_flow_id)
response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('JobFlow termination failed: %s' % response)
else:
self.log.info('JobFlow with id %s terminated', self.job_flow_id)
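A short sketch of the operator above pulling the cluster id from an upstream task via templating; the upstream task id is a placeholder, and the import path is assumed to follow the new providers layout.
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator  # assumed module path

terminate_job_flow = EmrTerminateJobFlowOperator(
    task_id='terminate_job_flow',
    # job_flow_id is a templated field, so a Jinja xcom_pull expression works here
    job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
    aws_conn_id='aws_default',
)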


@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
class S3CopyObjectOperator(BaseOperator):
"""
Creates a copy of an object that is already stored in S3.
Note: the S3 connection used here needs to have access to both
source and destination bucket/key.
:param source_bucket_key: The key of the source object. (templated)
It can be either a full s3:// style url or a relative path from the root level.
When it's specified as a full s3:// url, please omit source_bucket_name.
:type source_bucket_key: str
:param dest_bucket_key: The key of the object to copy to. (templated)
The convention to specify `dest_bucket_key` is the same as `source_bucket_key`.
:type dest_bucket_key: str
:param source_bucket_name: Name of the S3 bucket in which the source object is stored. (templated)
It should be omitted when `source_bucket_key` is provided as a full s3:// url.
:type source_bucket_name: str
:param dest_bucket_name: Name of the S3 bucket to which the object is copied. (templated)
It should be omitted when `dest_bucket_key` is provided as a full s3:// url.
:type dest_bucket_name: str
:param source_version_id: Version ID of the source object (OPTIONAL)
:type source_version_id: str
:param aws_conn_id: Connection id of the S3 connection to use
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- False: do not validate SSL certificates. SSL will still be used,
but SSL certificates will not be
verified.
- path/to/cert/bundle.pem: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('source_bucket_key', 'dest_bucket_key',
'source_bucket_name', 'dest_bucket_name')
@apply_defaults
def __init__(
self,
source_bucket_key,
dest_bucket_key,
source_bucket_name=None,
dest_bucket_name=None,
source_version_id=None,
aws_conn_id='aws_default',
verify=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.source_bucket_key = source_bucket_key
self.dest_bucket_key = dest_bucket_key
self.source_bucket_name = source_bucket_name
self.dest_bucket_name = dest_bucket_name
self.source_version_id = source_version_id
self.aws_conn_id = aws_conn_id
self.verify = verify
def execute(self, context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key,
self.source_bucket_name, self.dest_bucket_name,
self.source_version_id)
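A minimal sketch of the operator above using full s3:// style keys, so the bucket name parameters are omitted; bucket and key names are placeholders, and the import path is assumed to follow the new providers layout.
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator  # assumed module path

copy_report = S3CopyObjectOperator(
    task_id='copy_report',
    source_bucket_key='s3://source-bucket/reports/2020-01-01.csv',
    dest_bucket_key='s3://dest-bucket/reports/2020-01-01.csv',
    aws_conn_id='aws_default',
)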


@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
class S3DeleteObjectsOperator(BaseOperator):
"""
Deletes a single object or multiple objects from
a bucket using a single HTTP request.
Users may specify up to 1000 keys to delete.
:param bucket: Name of the bucket in which you are going to delete object(s). (templated)
:type bucket: str
:param keys: The key(s) to delete from S3 bucket. (templated)
When ``keys`` is a string, it's supposed to be the key name of
the single object to delete.
When ``keys`` is a list, it's supposed to be the list of the
keys to delete.
You may specify up to 1000 keys.
:type keys: str or list
:param aws_conn_id: Connection id of the S3 connection to use
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used,
but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('keys', 'bucket')
@apply_defaults
def __init__(
self,
bucket,
keys,
aws_conn_id='aws_default',
verify=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.bucket = bucket
self.keys = keys
self.aws_conn_id = aws_conn_id
self.verify = verify
def execute(self, context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
response = s3_hook.delete_objects(bucket=self.bucket, keys=self.keys)
deleted_keys = [x['Key'] for x in response.get("Deleted", [])]
self.log.info("Deleted: %s", deleted_keys)
if "Errors" in response:
errors_keys = [x['Key'] for x in response.get("Errors", [])]
raise AirflowException("Errors when deleting: {}".format(errors_keys))


@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import subprocess
import sys
from tempfile import NamedTemporaryFile
from typing import Optional, Union
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
class S3FileTransformOperator(BaseOperator):
"""
Copies data from a source S3 location to a temporary location on the
local filesystem. Runs a transformation on this file as specified by
the transformation script and uploads the output to a destination S3
location.
The locations of the source and the destination files in the local
filesystem are provided as the first and second arguments to the
transformation script. The transformation script is expected to read the
data from source, transform it and write the output to the local
destination file. The operator then takes over control and uploads the
local destination file to S3.
S3 Select is also available to filter the source contents. Users can
omit the transformation script if an S3 Select expression is specified.
:param source_s3_key: The key to be retrieved from S3. (templated)
:type source_s3_key: str
:param dest_s3_key: The key to be written to S3. (templated)
:type dest_s3_key: str
:param transform_script: location of the executable transformation script
:type transform_script: str
:param select_expression: S3 Select expression
:type select_expression: str
:param source_aws_conn_id: source s3 connection
:type source_aws_conn_id: str
:param source_verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
This is also applicable to ``dest_verify``.
:type source_verify: bool or str
:param dest_aws_conn_id: destination s3 connection
:type dest_aws_conn_id: str
:param dest_verify: Whether or not to verify SSL certificates for S3 connection.
See: ``source_verify``
:type dest_verify: bool or str
:param replace: Replace dest S3 key if it already exists
:type replace: bool
"""
template_fields = ('source_s3_key', 'dest_s3_key')
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
source_s3_key: str,
dest_s3_key: str,
transform_script: Optional[str] = None,
select_expression=None,
source_aws_conn_id: str = 'aws_default',
source_verify: Optional[Union[bool, str]] = None,
dest_aws_conn_id: str = 'aws_default',
dest_verify: Optional[Union[bool, str]] = None,
replace: bool = False,
*args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.source_s3_key = source_s3_key
self.source_aws_conn_id = source_aws_conn_id
self.source_verify = source_verify
self.dest_s3_key = dest_s3_key
self.dest_aws_conn_id = dest_aws_conn_id
self.dest_verify = dest_verify
self.replace = replace
self.transform_script = transform_script
self.select_expression = select_expression
self.output_encoding = sys.getdefaultencoding()
def execute(self, context):
if self.transform_script is None and self.select_expression is None:
raise AirflowException(
"Either transform_script or select_expression must be specified")
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify)
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
self.log.info("Downloading source S3 file %s", self.source_s3_key)
if not source_s3.check_for_key(self.source_s3_key):
raise AirflowException(
"The source key {0} does not exist".format(self.source_s3_key))
source_s3_key_object = source_s3.get_key(self.source_s3_key)
with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest:
self.log.info(
"Dumping S3 file %s contents to local file %s",
self.source_s3_key, f_source.name
)
if self.select_expression is not None:
content = source_s3.select_key(
key=self.source_s3_key,
expression=self.select_expression
)
f_source.write(content.encode("utf-8"))
else:
source_s3_key_object.download_fileobj(Fileobj=f_source)
f_source.flush()
if self.transform_script is not None:
process = subprocess.Popen(
[self.transform_script, f_source.name, f_dest.name],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True
)
self.log.info("Output:")
for line in iter(process.stdout.readline, b''):
self.log.info(line.decode(self.output_encoding).rstrip())
process.wait()
if process.returncode > 0:
raise AirflowException(
"Transform script failed: {0}".format(process.returncode)
)
else:
self.log.info(
"Transform script successful. Output temporarily located at %s",
f_dest.name
)
self.log.info("Uploading transformed file to S3")
f_dest.flush()
dest_s3.load_file(
filename=f_dest.name,
key=self.dest_s3_key,
replace=self.replace
)
self.log.info("Upload successful")


@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterable
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
class S3ListOperator(BaseOperator):
"""
List all objects from the bucket with the given string prefix in name.
This operator returns a Python list with the names of objects which can be
used by `xcom` in the downstream task.
:param bucket: The S3 bucket where to find the objects. (templated)
:type bucket: str
:param prefix: Prefix string to filter the objects whose names begin with
this prefix. (templated)
:type prefix: str
:param delimiter: the delimiter that marks the key hierarchy. (templated)
:type delimiter: str
:param aws_conn_id: The connection ID to use when connecting to S3 storage.
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
**Example**:
The following operator would list all the files
(excluding subfolders) from the S3
``customers/2018/04/`` key in the ``data`` bucket. ::
s3_file = S3ListOperator(
task_id='list_s3_files',
bucket='data',
prefix='customers/2018/04/',
delimiter='/',
aws_conn_id='aws_customers_conn'
)
"""
template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str]
ui_color = '#ffd700'
@apply_defaults
def __init__(self,
bucket,
prefix='',
delimiter='',
aws_conn_id='aws_default',
verify=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.bucket = bucket
self.prefix = prefix
self.delimiter = delimiter
self.aws_conn_id = aws_conn_id
self.verify = verify
def execute(self, context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
self.log.info(
'Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)',
self.bucket, self.prefix, self.delimiter
)
return hook.list_keys(
bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter)
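As a sketch of how the returned list can be consumed downstream (the operator pushes its return value to XCom automatically; the task ids, the callable and the classic ``PythonOperator`` import path are assumptions for illustration)::

    from airflow.operators.python_operator import PythonOperator

    list_keys = S3ListOperator(
        task_id='list_s3_files',
        bucket='data',
        prefix='customers/2018/04/',
    )

    def print_keys(**context):
        # Pull the list of keys returned by the upstream S3ListOperator.
        for key in context['ti'].xcom_pull(task_ids='list_s3_files'):
            print(key)

    print_task = PythonOperator(
        task_id='print_s3_keys',
        python_callable=print_keys,
        provide_context=True,
    )

    list_keys >> print_task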


@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Iterable
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.utils.decorators import apply_defaults
class SageMakerBaseOperator(BaseOperator):
"""
This is the base operator for all SageMaker operators.
:param config: The configuration necessary to start a training job (templated)
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""
template_fields = ['config']
template_ext = ()
ui_color = '#ededed'
integer_fields = [] # type: Iterable[Iterable[str]]
@apply_defaults
def __init__(self,
config,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
self.config = config
self.hook = None
def parse_integer(self, config, field):
if len(field) == 1:
if isinstance(config, list):
for sub_config in config:
self.parse_integer(sub_config, field)
return
head = field[0]
if head in config:
config[head] = int(config[head])
return
if isinstance(config, list):
for sub_config in config:
self.parse_integer(sub_config, field)
return
head, tail = field[0], field[1:]
if head in config:
self.parse_integer(config[head], tail)
return
def parse_config_integers(self):
# Parse the integer fields of training config to integers
# in case the config is rendered by Jinja and all fields are str
for field in self.integer_fields:
self.parse_integer(self.config, field)
def expand_role(self):
pass
def preprocess_config(self):
self.log.info(
'Preprocessing the config and doing required s3_operations'
)
self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.hook.configure_s3_resources(self.config)
self.parse_config_integers()
self.expand_role()
self.log.info(
'After preprocessing the config is:\n {}'.format(
json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')))
)
def execute(self, context):
raise NotImplementedError('Please implement execute() in sub class!')
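To make the ``integer_fields`` handling concrete: a subclass declares a list of key paths, and ``parse_config_integers()`` casts the value at the end of each path to ``int``, descending into lists along the way. This matters when the config arrives via Jinja templating and every leaf is a string. A small sketch with made-up field names::

    class ExampleSageMakerOperator(SageMakerBaseOperator):
        # Hypothetical fields, for illustration only.
        integer_fields = [
            ['ResourceConfig', 'InstanceCount'],   # nested dict
            ['Steps', 'MaxAttempts'],              # applied to every item when 'Steps' is a list
        ]

    op = ExampleSageMakerOperator(
        task_id='example',
        config={
            'ResourceConfig': {'InstanceCount': '2'},
            'Steps': [{'MaxAttempts': '3'}],
        },
    )
    op.parse_config_integers()
    # op.config == {'ResourceConfig': {'InstanceCount': 2}, 'Steps': [{'MaxAttempts': 3}]}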


@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from airflow.utils.decorators import apply_defaults
class SageMakerEndpointOperator(SageMakerBaseOperator):
"""
Create a SageMaker endpoint.
This operator returns the ARN of the endpoint created in Amazon SageMaker.
:param config:
The configuration necessary to create an endpoint.
If you need to create a SageMaker endpoint based on an existing
SageMaker model and an existing SageMaker endpoint config::
config = endpoint_configuration
If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
config = {
'Model': model_configuration,
'EndpointConfig': endpoint_config_configuration,
'Endpoint': endpoint_configuration
}
For details of the configuration parameter of model_configuration see
:py:meth:`SageMaker.Client.create_model`
For details of the configuration parameter of endpoint_config_configuration see
:py:meth:`SageMaker.Client.create_endpoint_config`
For details of the configuration parameter of endpoint_configuration see
:py:meth:`SageMaker.Client.create_endpoint`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
waits before polling the status of the endpoint creation.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
:type max_ingestion_time: int
:param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
:type operation: str
"""
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
check_interval=30,
max_ingestion_time=None,
operation='create',
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.operation = operation.lower()
if self.operation not in ['create', 'update']:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
self.create_integer_fields()
def create_integer_fields(self):
if 'EndpointConfig' in self.config:
self.integer_fields = [
['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
]
def expand_role(self):
if 'Model' not in self.config:
return
hook = AwsHook(self.aws_conn_id)
config = self.config['Model']
if 'ExecutionRoleArn' in config:
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def execute(self, context):
self.preprocess_config()
model_info = self.config.get('Model')
endpoint_config_info = self.config.get('EndpointConfig')
endpoint_info = self.config.get('Endpoint', self.config)
if model_info:
self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
self.hook.create_model(model_info)
if endpoint_config_info:
self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
self.hook.create_endpoint_config(endpoint_config_info)
if self.operation == 'create':
sagemaker_operation = self.hook.create_endpoint
log_str = 'Creating'
elif self.operation == 'update':
sagemaker_operation = self.hook.update_endpoint
log_str = 'Updating'
else:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(
'Sagemaker endpoint creation failed: %s' % response)
else:
return {
'EndpointConfig': self.hook.describe_endpoint_config(
endpoint_info['EndpointConfigName']
),
'Endpoint': self.hook.describe_endpoint(
endpoint_info['EndpointName']
)
}
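A sketch of the combined three-resource config described above; every name, ARN, image URI and S3 path is a placeholder, and the keys follow the boto3 ``create_model``, ``create_endpoint_config`` and ``create_endpoint`` calls that the hook wraps::

    config = {
        'Model': {
            'ModelName': 'my-model',
            'PrimaryContainer': {
                'Image': '123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:latest',
                'ModelDataUrl': 's3://my-bucket/model.tar.gz',
            },
            'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/SageMakerRole',
        },
        'EndpointConfig': {
            'EndpointConfigName': 'my-endpoint-config',
            'ProductionVariants': [{
                'VariantName': 'AllTraffic',
                'ModelName': 'my-model',
                'InitialInstanceCount': '1',   # string is fine, coerced via integer_fields
                'InstanceType': 'ml.m4.xlarge',
            }],
        },
        'Endpoint': {
            'EndpointName': 'my-endpoint',
            'EndpointConfigName': 'my-endpoint-config',
        },
    }

    endpoint_task = SageMakerEndpointOperator(
        task_id='create_endpoint',
        config=config,
        operation='create',
    )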


@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from airflow.utils.decorators import apply_defaults
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
"""
Create a SageMaker endpoint config.
This operator returns the ARN of the endpoint config created in Amazon SageMaker.
:param config: The configuration necessary to create an endpoint config.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""
integer_fields = [
['ProductionVariants', 'InitialInstanceCount']
]
@apply_defaults
def __init__(self,
config,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
def execute(self, context):
self.preprocess_config()
self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
response = self.hook.create_endpoint_config(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(
'Sagemaker endpoint config creation failed: %s' % response)
else:
return {
'EndpointConfig': self.hook.describe_endpoint_config(
self.config['EndpointConfigName']
)
}


@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from airflow.utils.decorators import apply_defaults
class SageMakerModelOperator(SageMakerBaseOperator):
"""
Create a SageMaker model.
This operator returns the ARN of the model created in Amazon SageMaker.
:param config: The configuration necessary to create a model.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""
@apply_defaults
def __init__(self,
config,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
def expand_role(self):
if 'ExecutionRoleArn' in self.config:
hook = AwsHook(self.aws_conn_id)
self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
def execute(self, context):
self.preprocess_config()
self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
response = self.hook.create_model(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker model creation failed: %s' % response)
else:
return {
'Model': self.hook.describe_model(
self.config['ModelName']
)
}


@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from airflow.utils.decorators import apply_defaults
class SageMakerTrainingOperator(SageMakerBaseOperator):
"""
Initiate a SageMaker training job.
This operator returns the ARN of the training job created in Amazon SageMaker.
:param config: The configuration necessary to start a training job (templated).
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Whether the operator should wait until the training
job finishes.
:type wait_for_completion: bool
:param print_log: if the operator should print the cloudwatch log during training
:type print_log: bool
:param check_interval: if wait is set to True, this is the time interval,
in seconds, at which the operator checks the status of the training job
:type check_interval: int
:param max_ingestion_time: If wait is set to True, the operation fails if the training job
doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
the operation does not timeout.
:type max_ingestion_time: int
"""
integer_fields = [
['ResourceConfig', 'InstanceCount'],
['ResourceConfig', 'VolumeSizeInGB'],
['StoppingCondition', 'MaxRuntimeInSeconds']
]
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
print_log=True,
check_interval=30,
max_ingestion_time=None,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
def expand_role(self):
if 'RoleArn' in self.config:
hook = AwsHook(self.aws_conn_id)
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
def execute(self, context):
self.preprocess_config()
self.log.info('Creating SageMaker Training Job %s.', self.config['TrainingJobName'])
response = self.hook.create_training_job(
self.config,
wait_for_completion=self.wait_for_completion,
print_log=self.print_log,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Training Job creation failed: %s' % response)
else:
return {
'Training': self.hook.describe_training_job(
self.config['TrainingJobName']
)
}
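A sketch of a training config (keys mirror boto3 ``create_training_job``; the job name, image, role ARN and S3 paths are placeholders, and since ``config`` is templated the name can embed Jinja)::

    training_config = {
        'TrainingJobName': 'my-training-job-{{ ds_nodash }}',
        'AlgorithmSpecification': {
            'TrainingImage': '123456789012.dkr.ecr.us-east-1.amazonaws.com/my-algo:latest',
            'TrainingInputMode': 'File',
        },
        'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerRole',
        'InputDataConfig': [{
            'ChannelName': 'train',
            'DataSource': {'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://my-bucket/train/',
            }},
        }],
        'OutputDataConfig': {'S3OutputPath': 's3://my-bucket/output/'},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.m4.xlarge',
            'VolumeSizeInGB': 10,
        },
        'StoppingCondition': {'MaxRuntimeInSeconds': 3600},
    }

    training_task = SageMakerTrainingOperator(
        task_id='sagemaker_train',
        config=training_config,
    )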


@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from airflow.utils.decorators import apply_defaults
class SageMakerTransformOperator(SageMakerBaseOperator):
"""
Initiate a SageMaker transform job.
This operator returns the ARN of the model created in Amazon SageMaker.
:param config: The configuration necessary to start a transform job (templated).
If you need to create a SageMaker transform job based on an existing SageMaker model::
config = transform_config
If you need to create both a SageMaker model and a SageMaker transform job::
config = {
'Model': model_config,
'Transform': transform_config
}
For details of the configuration parameter of transform_config see
:py:meth:`SageMaker.Client.create_transform_job`
For details of the configuration parameter of model_config see
:py:meth:`SageMaker.Client.create_model`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Set to True to wait until the transform job finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, the time interval, in seconds,
that this operation waits to check the status of the transform job.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, the operation fails
if the transform job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
:type max_ingestion_time: int
"""
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
check_interval=30,
max_ingestion_time=None,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.create_integer_fields()
def create_integer_fields(self):
self.integer_fields = [
['Transform', 'TransformResources', 'InstanceCount'],
['Transform', 'MaxConcurrentTransforms'],
['Transform', 'MaxPayloadInMB']
]
if 'Transform' not in self.config:
for field in self.integer_fields:
field.pop(0)
def expand_role(self):
if 'Model' not in self.config:
return
config = self.config['Model']
if 'ExecutionRoleArn' in config:
hook = AwsHook(self.aws_conn_id)
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def execute(self, context):
self.preprocess_config()
model_config = self.config.get('Model')
transform_config = self.config.get('Transform', self.config)
if model_config:
self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
self.hook.create_model(model_config)
self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
response = self.hook.create_transform_job(
transform_config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker transform Job creation failed: %s' % response)
else:
return {
'Model': self.hook.describe_model(
transform_config['ModelName']
),
'Transform': self.hook.describe_transform_job(
transform_config['TransformJobName']
)
}
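A sketch of the bare transform config case (placeholder names throughout, keys as in boto3 ``create_transform_job``); note that ``create_integer_fields()`` drops the ``Transform`` prefix in this layout, so the ``InstanceCount`` string below is still coerced to an integer::

    transform_config = {
        'TransformJobName': 'my-transform-job',
        'ModelName': 'my-model',
        'TransformInput': {'DataSource': {'S3DataSource': {
            'S3DataType': 'S3Prefix',
            'S3Uri': 's3://my-bucket/input/',
        }}},
        'TransformOutput': {'S3OutputPath': 's3://my-bucket/transform-output/'},
        'TransformResources': {
            'InstanceCount': '1',   # coerced to int by parse_config_integers()
            'InstanceType': 'ml.m4.xlarge',
        },
    }

    transform_task = SageMakerTransformOperator(
        task_id='sagemaker_transform',
        config=transform_config,
        # or config={'Model': model_config, 'Transform': transform_config}
        # to create the model in the same task.
    )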


@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from airflow.utils.decorators import apply_defaults
class SageMakerTuningOperator(SageMakerBaseOperator):
"""
Initiate a SageMaker hyperparameter tuning job.
This operator returns the ARN of the tuning job created in Amazon SageMaker.
:param config: The configuration necessary to start a tuning job (templated).
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Set to True to wait until the tuning job finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, the time interval, in seconds,
that this operation waits to check the status of the tuning job.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, the operation fails
if the tuning job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
:type max_ingestion_time: int
"""
integer_fields = [
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds']
]
@apply_defaults
def __init__(self,
config,
wait_for_completion=True,
check_interval=30,
max_ingestion_time=None,
*args, **kwargs):
super().__init__(config=config,
*args, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
def expand_role(self):
if 'TrainingJobDefinition' in self.config:
config = self.config['TrainingJobDefinition']
if 'RoleArn' in config:
hook = AwsHook(self.aws_conn_id)
config['RoleArn'] = hook.expand_role(config['RoleArn'])
def execute(self, context):
self.preprocess_config()
self.log.info(
'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
)
response = self.hook.create_tuning_job(
self.config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Tuning Job creation failed: %s' % response)
else:
return {
'Tuning': self.hook.describe_tuning_job(
self.config['HyperParameterTuningJobName']
)
}
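A compact, heavily abridged sketch of the expected structure (keys as in boto3 ``create_hyper_parameter_tuning_job``; all values are placeholders)::

    tuning_config = {
        'HyperParameterTuningJobName': 'my-tuning-job',
        'HyperParameterTuningJobConfig': {
            'Strategy': 'Bayesian',
            'HyperParameterTuningJobObjective': {
                'Type': 'Minimize',
                'MetricName': 'validation:error',
            },
            'ResourceLimits': {
                'MaxNumberOfTrainingJobs': 10,
                'MaxParallelTrainingJobs': 2,
            },
            'ParameterRanges': {'IntegerParameterRanges': [
                {'Name': 'num_round', 'MinValue': '10', 'MaxValue': '100'},
            ]},
        },
        'TrainingJobDefinition': {
            # Same shape as a training job config: AlgorithmSpecification,
            # RoleArn, InputDataConfig, OutputDataConfig, ResourceConfig,
            # StoppingCondition, StaticHyperParameters, ...
        },
    }

    tuning_task = SageMakerTuningOperator(task_id='sagemaker_tune', config=tuning_config)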


@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class EmrBaseSensor(BaseSensorOperator):
"""
Contains general sensor behavior for EMR.
Subclasses should implement get_emr_response() and state_from_response() methods.
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE constants.
"""
ui_color = '#66c3ff'
@apply_defaults
def __init__(
self,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
def poke(self, context):
response = self.get_emr_response()
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
self.log.info('Bad HTTP response: %s', response)
return False
state = self.state_from_response(response)
self.log.info('Job flow currently %s', state)
if state in self.NON_TERMINAL_STATES:
return False
if state in self.FAILED_STATE:
final_message = 'EMR job failed'
failure_message = self.failure_message_from_response(response)
if failure_message:
final_message += ' ' + failure_message
raise AirflowException(final_message)
return True
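A minimal sketch of what a concrete subclass has to provide; the class name, states and the canned response below are illustrative rather than a real EMR API call::

    class ExampleEmrSensor(EmrBaseSensor):
        NON_TERMINAL_STATES = ['PENDING', 'RUNNING']
        FAILED_STATE = ['FAILED']

        def get_emr_response(self):
            # A real subclass would call the EMR API, e.g. via
            # EmrHook(aws_conn_id=self.aws_conn_id).get_conn().
            return {
                'ResponseMetadata': {'HTTPStatusCode': 200},
                'Status': {'State': 'RUNNING'},
            }

        @staticmethod
        def state_from_response(response):
            return response['Status']['State']

        @staticmethod
        def failure_message_from_response(response):
            return response['Status'].get('Reason')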


@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
from airflow.utils.decorators import apply_defaults
class EmrJobFlowSensor(EmrBaseSensor):
"""
Asks for the state of the JobFlow until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
:param job_flow_id: job_flow_id to check the state of
:type job_flow_id: str
"""
NON_TERMINAL_STATES = ['STARTING', 'BOOTSTRAPPING', 'RUNNING',
'WAITING', 'TERMINATING']
FAILED_STATE = ['TERMINATED_WITH_ERRORS']
template_fields = ['job_flow_id']
template_ext = ()
@apply_defaults
def __init__(self,
job_flow_id,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_flow_id = job_flow_id
def get_emr_response(self):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Poking cluster %s', self.job_flow_id)
return emr.describe_cluster(ClusterId=self.job_flow_id)
@staticmethod
def state_from_response(response):
return response['Cluster']['Status']['State']
@staticmethod
def failure_message_from_response(response):
state_change_reason = response['Cluster']['Status'].get('StateChangeReason')
if state_change_reason:
return 'for code: {} with message {}'.format(state_change_reason.get('Code', 'No code'),
state_change_reason.get('Message', 'Unknown'))
return None
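In a DAG the job flow id is typically templated from the XCom pushed by the task that created the cluster; the task ids here are illustrative::

    job_flow_sensor = EmrJobFlowSensor(
        task_id='check_job_flow',
        job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
        aws_conn_id='aws_default',
    )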


@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
from airflow.utils.decorators import apply_defaults
class EmrStepSensor(EmrBaseSensor):
"""
Asks for the state of the step until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
:param job_flow_id: job_flow_id which contains the step to check the state of
:type job_flow_id: str
:param step_id: step to check the state of
:type step_id: str
"""
NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE', 'CANCEL_PENDING']
FAILED_STATE = ['CANCELLED', 'FAILED', 'INTERRUPTED']
template_fields = ['job_flow_id', 'step_id']
template_ext = ()
@apply_defaults
def __init__(self,
job_flow_id,
step_id,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_flow_id = job_flow_id
self.step_id = step_id
def get_emr_response(self):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id)
return emr.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
@staticmethod
def state_from_response(response):
return response['Step']['Status']['State']
@staticmethod
def failure_message_from_response(response):
fail_details = response['Step']['Status'].get('FailureDetails')
if fail_details:
return 'for reason {} with message {} and log file {}'.format(fail_details.get('Reason'),
fail_details.get('Message'),
fail_details.get('LogFile'))
return None
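As with the job flow sensor, both ids are usually templated from upstream XComs; the task ids are placeholders, and the index assumes the upstream add-steps task returned a list of step ids::

    step_sensor = EmrStepSensor(
        task_id='watch_step',
        job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
        step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
        aws_conn_id='aws_default',
    )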


@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
"""
Waits for a partition to show up in AWS Glue Catalog.
:param table_name: The name of the table to wait for, supports the dot
notation (my_database.my_table)
:type table_name: str
:param expression: The partition clause to wait for. This is passed as
is to the AWS Glue Catalog API's get_partitions function,
and supports SQL like notation as in ``ds='2015-01-01'
AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
#aws-glue-api-catalog-partitions-GetPartitions
:type expression: str
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:type aws_conn_id: str
:param region_name: Optional aws region name (example: us-east-1). Uses region from connection
if not specified.
:type region_name: str
:param database_name: The name of the catalog database where the partitions reside.
:type database_name: str
:param poke_interval: Time in seconds that the job should wait in
between each try
:type poke_interval: int
"""
template_fields = ('database_name', 'table_name', 'expression',)
ui_color = '#C5CAE9'
@apply_defaults
def __init__(self,
table_name, expression="ds='{{ ds }}'",
aws_conn_id='aws_default',
region_name=None,
database_name='default',
poke_interval=60 * 3,
*args,
**kwargs):
super().__init__(
poke_interval=poke_interval, *args, **kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.table_name = table_name
self.expression = expression
self.database_name = database_name
def poke(self, context):
"""
Checks for existence of the partition in the AWS Glue Catalog table
"""
if '.' in self.table_name:
self.database_name, self.table_name = self.table_name.split('.')
self.log.info(
'Poking for table %s.%s, expression %s', self.database_name, self.table_name, self.expression
)
return self.get_hook().check_for_partition(
self.database_name, self.table_name, self.expression)
def get_hook(self):
"""
Gets the AwsGlueCatalogHook
"""
if not hasattr(self, 'hook'):
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
self.hook = AwsGlueCatalogHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
return self.hook
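For example, to wait for today's partition of a hypothetical table (database, table and partition column are placeholders)::

    partition_sensor = AwsGlueCatalogPartitionSensor(
        task_id='wait_for_partition',
        table_name='my_database.my_table',
        expression="ds='{{ ds }}'",
        aws_conn_id='aws_default',
        poke_interval=60,
    )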


@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from urllib.parse import urlparse
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class S3KeySensor(BaseSensorOperator):
"""
Waits for a key (a file-like instance on S3) to be present in an S3 bucket.
S3 being a key/value store, it does not support folders. The path is just
a key, i.e. a resource.
:param bucket_key: The key being waited on. Supports full s3:// style url
or relative path from root level. When it's specified as a full s3://
url, please leave bucket_name as `None`.
:type bucket_key: str
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
is not provided as a full s3:// url.
:type bucket_name: str
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:type wildcard_match: bool
:param aws_conn_id: a reference to the s3 connection
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('bucket_key', 'bucket_name')
@apply_defaults
def __init__(self,
bucket_key,
bucket_name=None,
wildcard_match=False,
aws_conn_id='aws_default',
verify=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
# Parse
if bucket_name is None:
parsed_url = urlparse(bucket_key)
if parsed_url.netloc == '':
raise AirflowException('Please provide a bucket_name')
else:
bucket_name = parsed_url.netloc
bucket_key = parsed_url.path.lstrip('/')
else:
parsed_url = urlparse(bucket_key)
if parsed_url.scheme != '' or parsed_url.netloc != '':
raise AirflowException('If bucket_name is provided, bucket_key' +
' should be relative path from root' +
' level, rather than a full s3:// url')
self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.aws_conn_id = aws_conn_id
self.verify = verify
def poke(self, context):
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
if self.wildcard_match:
return hook.check_for_wildcard_key(self.bucket_key,
self.bucket_name)
return hook.check_for_key(self.bucket_key, self.bucket_name)
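The two accepted ways of pointing at the key, as a sketch (bucket and key names are placeholders; both fields are templated)::

    # Full s3:// URL, bucket_name left as None.
    key_sensor_url = S3KeySensor(
        task_id='wait_for_key_url',
        bucket_key='s3://my-bucket/data/{{ ds }}/_SUCCESS',
    )

    # Relative key plus explicit bucket_name; wildcards need wildcard_match=True.
    key_sensor_split = S3KeySensor(
        task_id='wait_for_key',
        bucket_name='my-bucket',
        bucket_key='data/{{ ds }}/part-*',
        wildcard_match=True,
    )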


@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class S3PrefixSensor(BaseSensorOperator):
"""
Waits for a prefix to exist. A prefix is the first part of a key,
thus enabling checking of constructs similar to glob airfl* or
SQL LIKE 'airfl%'. There is the possibility to specify a delimiter to
indicate the hierarchy of keys, meaning that the match will stop at that
delimiter. Current code accepts sane delimiters, i.e. characters that
are NOT special characters in the Python regex engine.
:param bucket_name: Name of the S3 bucket
:type bucket_name: str
:param prefix: The prefix being waited on. Relative path from bucket root level.
:type prefix: str
:param delimiter: The delimiter intended to show hierarchy.
Defaults to '/'.
:type delimiter: str
:param aws_conn_id: a reference to the s3 connection
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('prefix', 'bucket_name')
@apply_defaults
def __init__(self,
bucket_name,
prefix,
delimiter='/',
aws_conn_id='aws_default',
verify=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
# Parse
self.bucket_name = bucket_name
self.prefix = prefix
self.delimiter = delimiter
self.full_url = "s3://" + bucket_name + '/' + prefix
self.aws_conn_id = aws_conn_id
self.verify = verify
def poke(self, context):
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
return hook.check_for_prefix(
prefix=self.prefix,
delimiter=self.delimiter,
bucket_name=self.bucket_name)
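For example, with a placeholder bucket and a templated prefix::

    prefix_sensor = S3PrefixSensor(
        task_id='wait_for_prefix',
        bucket_name='my-bucket',
        prefix='exports/{{ ds }}/',
    )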


@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class SageMakerBaseSensor(BaseSensorOperator):
"""
Contains general sensor behavior for SageMaker.
Subclasses should implement get_sagemaker_response()
and state_from_response() methods.
Subclasses should also implement the non_terminal_states() and failed_states() methods.
"""
ui_color = '#ededed'
@apply_defaults
def __init__(
self,
aws_conn_id='aws_default',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
def poke(self, context):
response = self.get_sagemaker_response()
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
self.log.info('Bad HTTP response: %s', response)
return False
state = self.state_from_response(response)
self.log.info('Job currently %s', state)
if state in self.non_terminal_states():
return False
if state in self.failed_states():
failed_reason = self.get_failed_reason_from_response(response)
raise AirflowException('Sagemaker job failed for the following reason: %s'
% failed_reason)
return True
def non_terminal_states(self):
raise NotImplementedError('Please implement non_terminal_states() in subclass')
def failed_states(self):
raise NotImplementedError('Please implement failed_states() in subclass')
def get_sagemaker_response(self):
raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
def get_failed_reason_from_response(self, response):
return 'Unknown'
def state_from_response(self, response):
raise NotImplementedError('Please implement state_from_response() in subclass')
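A minimal sketch of a concrete subclass; the job name, states and the describe call mirror the concrete sensors in this package, but the class itself is illustrative::

    from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

    class ExampleSageMakerSensor(SageMakerBaseSensor):
        @apply_defaults
        def __init__(self, job_name, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.job_name = job_name

        def non_terminal_states(self):
            return ['InProgress', 'Stopping']

        def failed_states(self):
            return ['Failed']

        def get_sagemaker_response(self):
            hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
            return hook.describe_training_job(self.job_name)

        def state_from_response(self, response):
            return response['TrainingJobStatus']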


@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
class SageMakerEndpointSensor(SageMakerBaseSensor):
"""
Asks for the state of the endpoint until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
:param endpoint_name: Name of the endpoint instance to check the state of
:type endpoint_name: str
"""
template_fields = ['endpoint_name']
template_ext = ()
@apply_defaults
def __init__(self,
endpoint_name,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.endpoint_name = endpoint_name
def non_terminal_states(self):
return SageMakerHook.endpoint_non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
return sagemaker.describe_endpoint(self.endpoint_name)
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['EndpointStatus']


@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
class SageMakerTrainingSensor(SageMakerBaseSensor):
"""
Asks for the state of the training job until it reaches a terminal state.
If it fails, the sensor errors, failing the task.
:param job_name: name of the SageMaker training job to check the state of
:type job_name: str
:param print_log: if the sensor should print the CloudWatch log
:type print_log: bool
"""
template_fields = ['job_name']
template_ext = ()
@apply_defaults
def __init__(self,
job_name,
print_log=True,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_name = job_name
self.print_log = print_log
self.positions = {}
self.stream_names = []
self.instance_count = None
self.state = None
self.last_description = None
self.last_describe_job_call = None
self.log_resource_inited = False
def init_log_resource(self, hook):
description = hook.describe_training_job(self.job_name)
self.instance_count = description['ResourceConfig']['InstanceCount']
status = description['TrainingJobStatus']
job_already_completed = status not in self.non_terminal_states()
self.state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
self.last_description = description
self.last_describe_job_call = time.time()
self.log_resource_inited = True
def non_terminal_states(self):
return SageMakerHook.non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
if self.print_log:
if not self.log_resource_inited:
self.init_log_resource(sagemaker_hook)
self.state, self.last_description, self.last_describe_job_call = \
sagemaker_hook.describe_training_job_with_log(self.job_name,
self.positions, self.stream_names,
self.instance_count, self.state,
self.last_description,
self.last_describe_job_call)
else:
self.last_description = sagemaker_hook.describe_training_job(self.job_name)
status = self.state_from_response(self.last_description)
if status not in self.non_terminal_states() and status not in self.failed_states():
billable_time = \
(self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \
self.last_description['ResourceConfig']['InstanceCount']
self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)
return self.last_description
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['TrainingJobStatus']


@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
class SageMakerTransformSensor(SageMakerBaseSensor):
"""
Asks for the state of the transform job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
containing the failure reason.
:param job_name: job_name of the transform job instance to check the state of
:type job_name: str
"""
template_fields = ['job_name']
template_ext = ()
@apply_defaults
def __init__(self,
job_name,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_name = job_name
def non_terminal_states(self):
return SageMakerHook.non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
return sagemaker.describe_transform_job(self.job_name)
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['TransformJobStatus']


@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults
class SageMakerTuningSensor(SageMakerBaseSensor):
"""
Asks for the state of the tuning job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
containing the failure reason.
:param job_name: job_name of the tuning instance to check the state of
:type job_name: str
"""
template_fields = ['job_name']
template_ext = ()
@apply_defaults
def __init__(self,
job_name,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.job_name = job_name
def non_terminal_states(self):
return SageMakerHook.non_terminal_states
def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
return sagemaker.describe_tuning_job(self.job_name)
def get_failed_reason_from_response(self, response):
return response['FailureReason']
def state_from_response(self, response):
return response['HyperParameterTuningJobStatus']


@ -16,82 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`."""
import warnings
from urllib.parse import urlparse
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor # noqa
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class S3KeySensor(BaseSensorOperator):
"""
Waits for a key (a file-like instance on S3) to be present in an S3 bucket.
S3 being a key/value store, it does not support folders. The path is just
a key, i.e. a resource.
:param bucket_key: The key being waited on. Supports full s3:// style url
or relative path from root level. When it's specified as a full s3://
url, please leave bucket_name as `None`.
:type bucket_key: str
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
is not provided as a full s3:// url.
:type bucket_name: str
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:type wildcard_match: bool
:param aws_conn_id: a reference to the s3 connection
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('bucket_key', 'bucket_name')
@apply_defaults
def __init__(self,
bucket_key,
bucket_name=None,
wildcard_match=False,
aws_conn_id='aws_default',
verify=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
# Parse
if bucket_name is None:
parsed_url = urlparse(bucket_key)
if parsed_url.netloc == '':
raise AirflowException('Please provide a bucket_name')
else:
bucket_name = parsed_url.netloc
bucket_key = parsed_url.path.lstrip('/')
else:
parsed_url = urlparse(bucket_key)
if parsed_url.scheme != '' or parsed_url.netloc != '':
raise AirflowException('If bucket_name is provided, bucket_key' +
' should be relative path from root' +
' level, rather than a full s3:// url')
self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.aws_conn_id = aws_conn_id
self.verify = verify
def poke(self, context):
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
if self.wildcard_match:
return hook.check_for_wildcard_key(self.bucket_key,
self.bucket_name)
return hook.check_for_key(self.bucket_key, self.bucket_name)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`.",
DeprecationWarning, stacklevel=2
)
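The shim keeps old DAG code importable while new code should switch to the provider package; a small sketch of what that means in practice (nothing below is part of the diff itself):
# The old location still works, but importing it triggers the DeprecationWarning above.
from airflow.sensors.s3_key_sensor import S3KeySensor as DeprecatedS3KeySensor

# Preferred location after this change.
from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor

# The shim simply re-exports the provider class, so both names refer to the same object.
assert DeprecatedS3KeySensor is S3KeySensor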


@ -16,66 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`."""
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor # noqa
class S3PrefixSensor(BaseSensorOperator):
"""
Waits for a prefix to exist. A prefix is the first part of a key,
thus enabling checking of constructs similar to glob airfl* or
SQL LIKE 'airfl%'. A delimiter can be specified to indicate the
hierarchy of keys, meaning that the match will stop at that
delimiter. Current code accepts sane delimiters, i.e. characters that
are NOT special characters in the Python regex engine.
:param bucket_name: Name of the S3 bucket
:type bucket_name: str
:param prefix: The prefix being waited on. Relative path from bucket root level.
:type prefix: str
:param delimiter: The delimiter intended to show hierarchy.
Defaults to '/'.
:type delimiter: str
:param aws_conn_id: a reference to the s3 connection
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
template_fields = ('prefix', 'bucket_name')
@apply_defaults
def __init__(self,
bucket_name,
prefix,
delimiter='/',
aws_conn_id='aws_default',
verify=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
# Parse
self.bucket_name = bucket_name
self.prefix = prefix
self.delimiter = delimiter
self.full_url = "s3://" + bucket_name + '/' + prefix
self.aws_conn_id = aws_conn_id
self.verify = verify
def poke(self, context):
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
return hook.check_for_prefix(
prefix=self.prefix,
delimiter=self.delimiter,
bucket_name=self.bucket_name)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`.",
DeprecationWarning, stacklevel=2
)
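For completeness, a hypothetical DAG snippet using the relocated prefix sensor (DAG id, bucket, prefix and schedule are invented placeholders):
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor

with DAG(dag_id='example_s3_prefix_wait',
         start_date=datetime(2020, 1, 1),
         schedule_interval=None) as dag:
    wait_for_partition = S3PrefixSensor(
        task_id='wait_for_partition',
        bucket_name='my-bucket',         # placeholder bucket
        prefix='exports/2020-01-17/',    # placeholder prefix; delimiter defaults to '/'
        aws_conn_id='aws_default',
    )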


@ -277,7 +277,7 @@ Airflow provides operators for many common tasks, including:
In addition to these basic building blocks, there are many more specific
operators: :class:`~airflow.operators.docker_operator.DockerOperator`,
:class:`~airflow.providers.apache.hive.operators.hive.HiveOperator`, :class:`~airflow.operators.s3_file_transform_operator.S3FileTransformOperator`,
:class:`~airflow.providers.apache.hive.operators.hive.HiveOperator`, :class:`~airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator`,
:class:`~airflow.operators.presto_to_mysql.PrestoToMySqlTransfer`,
:class:`~airflow.operators.slack_operator.SlackAPIOperator`... you get the idea!
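As an illustration of the renamed operator in that list, a hypothetical task using the new import path (the bucket keys and script path are placeholders, not part of this change):
from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator

transform_report = S3FileTransformOperator(
    task_id='transform_report',
    source_s3_key='s3://my-bucket/raw/report.csv',        # placeholder source object
    dest_s3_key='s3://my-bucket/processed/report.csv',    # placeholder destination object
    transform_script='/usr/local/bin/transform.py',       # placeholder local script
    replace=True,
)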


@ -315,9 +315,9 @@ These integrations allow you to perform various operations within the Amazon Web
* - `AWS Glue Catalog <https://aws.amazon.com/glue/>`__
-
- :mod:`airflow.contrib.hooks.aws_glue_catalog_hook`
- :mod:`airflow.providers.amazon.aws.hooks.glue_catalog`
-
- :mod:`airflow.contrib.sensors.aws_glue_catalog_partition_sensor`
- :mod:`airflow.providers.amazon.aws.sensors.glue_catalog_partition`
* - `AWS Lambda <https://aws.amazon.com/lambda/>`__
-
@ -333,7 +333,7 @@ These integrations allow you to perform various operations within the Amazon Web
* - `Amazon CloudWatch Logs <https://aws.amazon.com/cloudwatch/>`__
-
- :mod:`airflow.contrib.hooks.aws_logs_hook`
- :mod:`airflow.providers.amazon.aws.hooks.logs`
-
-
@ -346,18 +346,18 @@ These integrations allow you to perform various operations within the Amazon Web
* - `Amazon EC2 <https://aws.amazon.com/ec2/>`__
-
-
- :mod:`airflow.contrib.operators.ecs_operator`
- :mod:`airflow.providers.amazon.aws.operators.ecs`
-
* - `Amazon EMR <https://aws.amazon.com/emr/>`__
-
- :mod:`airflow.contrib.hooks.emr_hook`
- :mod:`airflow.contrib.operators.emr_add_steps_operator`,
:mod:`airflow.contrib.operators.emr_create_job_flow_operator`,
:mod:`airflow.contrib.operators.emr_terminate_job_flow_operator`
- :mod:`airflow.contrib.sensors.emr_base_sensor`,
:mod:`airflow.contrib.sensors.emr_job_flow_sensor`,
:mod:`airflow.contrib.sensors.emr_step_sensor`
- :mod:`airflow.providers.amazon.aws.hooks.emr`
- :mod:`airflow.providers.amazon.aws.operators.emr_add_steps`,
:mod:`airflow.providers.amazon.aws.operators.emr_create_job_flow`,
:mod:`airflow.providers.amazon.aws.operators.emr_terminate_job_flow`
- :mod:`airflow.providers.amazon.aws.sensors.emr_base`,
:mod:`airflow.providers.amazon.aws.sensors.emr_job_flow`,
:mod:`airflow.providers.amazon.aws.sensors.emr_step`
* - `Amazon Kinesis Data Firehose <https://aws.amazon.com/kinesis/data-firehose/>`__
-
@ -373,19 +373,19 @@ These integrations allow you to perform various operations within the Amazon Web
* - `Amazon SageMaker <https://aws.amazon.com/sagemaker/>`__
-
- :mod:`airflow.contrib.hooks.sagemaker_hook`
- :mod:`airflow.contrib.operators.sagemaker_base_operator`,
:mod:`airflow.contrib.operators.sagemaker_endpoint_config_operator`,
:mod:`airflow.contrib.operators.sagemaker_endpoint_operator`,
:mod:`airflow.contrib.operators.sagemaker_model_operator`,
:mod:`airflow.contrib.operators.sagemaker_training_operator`,
:mod:`airflow.contrib.operators.sagemaker_transform_operator`,
:mod:`airflow.contrib.operators.sagemaker_tuning_operator`
- :mod:`airflow.contrib.sensors.sagemaker_base_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_endpoint_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_training_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_transform_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_tuning_sensor`
- :mod:`airflow.providers.amazon.aws.hooks.sagemaker`
- :mod:`airflow.providers.amazon.aws.operators.sagemaker_base`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_endpoint`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_model`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_training`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_transform`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_tuning`
- :mod:`airflow.providers.amazon.aws.sensors.sagemaker_base`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_endpoint`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_training`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_transform`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_tuning`
* - `Amazon Simple Notification Service (SNS) <https://aws.amazon.com/sns/>`__
-
@ -402,12 +402,12 @@ These integrations allow you to perform various operations within the Amazon Web
* - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`__
-
- :mod:`airflow.providers.amazon.aws.hooks.s3`
- :mod:`airflow.operators.s3_file_transform_operator`,
:mod:`airflow.contrib.operators.s3_copy_object_operator`,
:mod:`airflow.contrib.operators.s3_delete_objects_operator`,
:mod:`airflow.contrib.operators.s3_list_operator`
- :mod:`airflow.sensors.s3_key_sensor`,
:mod:`airflow.sensors.s3_prefix_sensor`
- :mod:`airflow.providers.amazon.aws.operators.s3_file_transform`,
:mod:`airflow.providers.amazon.aws.operators.s3_copy_object`,
:mod:`airflow.providers.amazon.aws.operators.s3_delete_objects`,
:mod:`airflow.providers.amazon.aws.operators.s3_list`
- :mod:`airflow.providers.amazon.aws.sensors.s3_key`,
:mod:`airflow.providers.amazon.aws.sensors.s3_prefix`
Transfer operators and hooks
''''''''''''''''''''''''''''


@ -11,7 +11,7 @@
./airflow/contrib/hooks/datadog_hook.py
./airflow/contrib/hooks/dingding_hook.py
./airflow/contrib/hooks/discord_webhook_hook.py
./airflow/contrib/hooks/emr_hook.py
./airflow/providers/amazon/aws/hooks/emr.py
./airflow/hooks/filesystem.py
./airflow/contrib/hooks/ftp_hook.py
./airflow/contrib/hooks/jenkins_hook.py
@ -19,7 +19,7 @@
./airflow/contrib/hooks/opsgenie_alert_hook.py
./airflow/providers/apache/pinot/hooks/pinot.py
./airflow/contrib/hooks/qubole_check_hook.py
./airflow/contrib/hooks/sagemaker_hook.py
./airflow/providers/amazon/aws/hooks/sagemaker.py
./airflow/contrib/hooks/segment_hook.py
./airflow/contrib/hooks/slack_webhook_hook.py
./airflow/contrib/hooks/snowflake_hook.py
@ -36,10 +36,10 @@
./airflow/contrib/operators/dingding_operator.py
./airflow/contrib/operators/discord_webhook_operator.py
./airflow/providers/apache/druid/operators/druid.py
./airflow/contrib/operators/ecs_operator.py
./airflow/contrib/operators/emr_add_steps_operator.py
./airflow/contrib/operators/emr_create_job_flow_operator.py
./airflow/contrib/operators/emr_terminate_job_flow_operator.py
./airflow/providers/amazon/aws/operators/ecs.py
./airflow/providers/amazon/aws/operators/emr_add_steps.py
./airflow/providers/amazon/aws/operators/emr_create_job_flow.py
./airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
./airflow/contrib/operators/file_to_wasb.py
./airflow/contrib/operators/grpc_operator.py
./airflow/contrib/operators/jenkins_job_trigger_operator.py
@ -50,18 +50,18 @@
./airflow/contrib/operators/oracle_to_oracle_transfer.py
./airflow/contrib/operators/qubole_check_operator.py
./airflow/contrib/operators/redis_publish_operator.py
./airflow/contrib/operators/s3_copy_object_operator.py
./airflow/contrib/operators/s3_delete_objects_operator.py
./airflow/contrib/operators/s3_list_operator.py
./airflow/providers/amazon/aws/operators/s3_copy_object.py
./airflow/providers/amazon/aws/operators/s3_delete_objects.py
./airflow/providers/amazon/aws/operators/s3_list.py
./airflow/contrib/operators/s3_to_gcs_operator.py
./airflow/contrib/operators/s3_to_sftp_operator.py
./airflow/contrib/operators/sagemaker_base_operator.py
./airflow/contrib/operators/sagemaker_endpoint_config_operator.py
./airflow/contrib/operators/sagemaker_endpoint_operator.py
./airflow/contrib/operators/sagemaker_model_operator.py
./airflow/contrib/operators/sagemaker_training_operator.py
./airflow/contrib/operators/sagemaker_transform_operator.py
./airflow/contrib/operators/sagemaker_tuning_operator.py
./airflow/providers/amazon/aws/operators/sagemaker_base.py
./airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
./airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
./airflow/providers/amazon/aws/operators/sagemaker_model.py
./airflow/providers/amazon/aws/operators/sagemaker_training.py
./airflow/providers/amazon/aws/operators/sagemaker_transform.py
./airflow/providers/amazon/aws/operators/sagemaker_tuning.py
./airflow/contrib/operators/segment_track_event_operator.py
./airflow/contrib/operators/sftp_to_s3_operator.py
./airflow/contrib/operators/slack_webhook_operator.py
@ -75,14 +75,14 @@
./airflow/contrib/operators/vertica_to_mysql.py
./airflow/providers/microsoft/azure/operators/wasb_delete_blob.py
./airflow/contrib/operators/winrm_operator.py
./airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
./airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
./airflow/providers/microsoft/azure/sensors/azure_cosmos.py
./airflow/contrib/sensors/bash_sensor.py
./airflow/contrib/sensors/celery_queue_sensor.py
./airflow/contrib/sensors/datadog_sensor.py
./airflow/contrib/sensors/emr_base_sensor.py
./airflow/contrib/sensors/emr_job_flow_sensor.py
./airflow/contrib/sensors/emr_step_sensor.py
./airflow/providers/amazon/aws/sensors/emr_base.py
./airflow/providers/amazon/aws/sensors/emr_job_flow.py
./airflow/providers/amazon/aws/sensors/emr_step.py
./airflow/sensors/filesystem.py
./airflow/contrib/sensors/ftp_sensor.py
./airflow/providers/apache/hdfs/sensors/hdfs.py
@ -92,11 +92,11 @@
./airflow/contrib/sensors/qubole_sensor.py
./airflow/contrib/sensors/redis_key_sensor.py
./airflow/contrib/sensors/redis_pub_sub_sensor.py
./airflow/contrib/sensors/sagemaker_base_sensor.py
./airflow/contrib/sensors/sagemaker_endpoint_sensor.py
./airflow/contrib/sensors/sagemaker_training_sensor.py
./airflow/contrib/sensors/sagemaker_transform_sensor.py
./airflow/contrib/sensors/sagemaker_tuning_sensor.py
./airflow/providers/amazon/aws/sensors/sagemaker_base.py
./airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
./airflow/providers/amazon/aws/sensors/sagemaker_training.py
./airflow/providers/amazon/aws/sensors/sagemaker_transform.py
./airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
./airflow/providers/microsoft/azure/sensors/wasb.py
./airflow/sensors/weekday_sensor.py
./airflow/hooks/dbapi_hook.py
@ -160,7 +160,7 @@
./airflow/operators/presto_check_operator.py
./airflow/operators/presto_to_mysql.py
./airflow/operators/python_operator.py
./airflow/operators/s3_file_transform_operator.py
./airflow/providers/amazon/aws/operators/s3_file_transform.py
./airflow/operators/s3_to_redshift_operator.py
./airflow/operators/slack_operator.py
./airflow/operators/sqlite_operator.py
@ -175,8 +175,8 @@
./airflow/sensors/http_sensor.py
./airflow/providers/apache/hive/sensors/metastore_partition.py
./airflow/providers/apache/hive/sensors/named_hive_partition.py
./airflow/sensors/s3_key_sensor.py
./airflow/sensors/s3_prefix_sensor.py
./airflow/providers/amazon/aws/sensors/s3_key.py
./airflow/providers/amazon/aws/sensors/s3_prefix.py
./airflow/sensors/sql_sensor.py
./airflow/sensors/time_delta_sensor.py
./airflow/sensors/time_sensor.py


@ -53,7 +53,7 @@ class TestS3ToGoogleCloudStorageOperator(unittest.TestCase):
self.assertEqual(operator.dest_gcs, GCS_PATH_PREFIX)
@mock.patch('airflow.contrib.operators.s3_to_gcs_operator.S3Hook')
@mock.patch('airflow.contrib.operators.s3_list_operator.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
@mock.patch(
'airflow.contrib.operators.s3_to_gcs_operator.GCSHook')
def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
@ -88,7 +88,7 @@ class TestS3ToGoogleCloudStorageOperator(unittest.TestCase):
self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
@mock.patch('airflow.contrib.operators.s3_to_gcs_operator.S3Hook')
@mock.patch('airflow.contrib.operators.s3_list_operator.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
@mock.patch(
'airflow.contrib.operators.s3_to_gcs_operator.GCSHook')
def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
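The patch targets above change because mocks must name the module where an object is looked up; a minimal hypothetical test showing the pattern against the relocated s3_list module (the test body is invented for illustration, not copied from this change):
import unittest
from unittest import mock

from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator


class TestS3ListPatchTarget(unittest.TestCase):
    # S3Hook must be patched where the new provider module imports it,
    # not at the old airflow.contrib.operators.s3_list_operator path.
    @mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
    def test_execute_returns_keys(self, mock_hook):
        mock_hook.return_value.list_keys.return_value = ['TEST1.csv', 'TEST2.csv']
        operator = S3ListOperator(task_id='list-keys', bucket='test-bucket')
        self.assertEqual(operator.execute(None), ['TEST1.csv', 'TEST2.csv'])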


@ -22,7 +22,7 @@ import unittest
import boto3
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.providers.amazon.aws.hooks.emr import EmrHook
try:
from moto import mock_emr


@ -21,7 +21,7 @@ import unittest
import boto3
import mock
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
try:
from moto import mock_glue


@ -20,7 +20,7 @@
import unittest
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
try:
from moto import mock_logs


@ -25,12 +25,12 @@ from datetime import datetime
import mock
from tzlocal import get_localzone
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.contrib.hooks.sagemaker_hook import (
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState, SageMakerHook, secondary_training_status_changed, secondary_training_status_message,
)
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
role = 'arn:aws:iam:role/test-role'


@ -25,8 +25,8 @@ from copy import deepcopy
import mock
from parameterized import parameterized
from airflow.contrib.operators.ecs_operator import ECSOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
RESPONSE_WITHOUT_FAILURES = {
"failures": [],
@ -52,7 +52,7 @@ RESPONSE_WITHOUT_FAILURES = {
class TestECSOperator(unittest.TestCase):
@mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
@mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook')
def setUp(self, aws_hook_mock):
self.aws_hook_mock = aws_hook_mock
self.ecs_operator_args = {
@ -96,7 +96,7 @@ class TestECSOperator(unittest.TestCase):
])
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
@mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook')
def test_execute_without_failures(self, launch_type, tags, aws_hook_mock,
check_mock, wait_mock):
client_mock = aws_hook_mock.return_value.get_client_type.return_value


@ -22,9 +22,9 @@ from datetime import timedelta
from unittest.mock import MagicMock, patch
from airflow import DAG
from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
@ -111,7 +111,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
with patch('boto3.session.Session', self.boto3_session_mock):
with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
with patch('airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = expected_job_flow_id
@ -132,7 +132,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
def test_init_with_nonexistent_cluster_name(self):
cluster_name = 'test_cluster'
with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
with patch('airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = None


@ -23,8 +23,8 @@ from datetime import timedelta
from unittest.mock import MagicMock, patch
from airflow import DAG
from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)


@ -20,7 +20,7 @@
import unittest
from unittest.mock import MagicMock, patch
from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator
TERMINATE_SUCCESS_RETURN = {
'ResponseMetadata': {


@ -23,7 +23,7 @@ import unittest
import boto3
from moto import mock_s3
from airflow.contrib.operators.s3_copy_object_operator import S3CopyObjectOperator
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator
class TestS3CopyObjectOperator(unittest.TestCase):


@ -23,7 +23,7 @@ import unittest
import boto3
from moto import mock_s3
from airflow.contrib.operators.s3_delete_objects_operator import S3DeleteObjectsOperator
from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator
class TestS3DeleteObjectsOperator(unittest.TestCase):


@ -31,7 +31,7 @@ import boto3
from moto import mock_s3
from airflow.exceptions import AirflowException
from airflow.operators.s3_file_transform_operator import S3FileTransformOperator
from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator
class TestS3FileTransformOperator(unittest.TestCase):


@ -21,7 +21,7 @@ import unittest
import mock
from airflow.contrib.operators.s3_list_operator import S3ListOperator
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator
TASK_ID = 'test-s3-list-operator'
BUCKET = 'test-bucket'
@ -31,7 +31,7 @@ MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
class TestS3ListOperator(unittest.TestCase):
@mock.patch('airflow.contrib.operators.s3_list_operator.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
def test_execute(self, mock_hook):
mock_hook.return_value.list_keys.return_value = MOCK_FILES


@ -19,7 +19,7 @@
import unittest
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
config = {
'key1': '1',


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_endpoint_operator import SageMakerEndpointOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator
role = 'arn:aws:iam:role/test-role'
bucket = 'test-bucket'


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_endpoint_config_operator import SageMakerEndpointConfigOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import SageMakerEndpointConfigOperator
model_name = 'test-model-name'
config_name = 'test-config-name'


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_model_operator import SageMakerModelOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator
role = 'arn:aws:iam:role/test-role'


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator
role = 'arn:aws:iam:role/test-role'


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator
role = 'arn:aws:iam:role/test-role'


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_tuning_operator import SageMakerTuningOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator
role = 'arn:aws:iam:role/test-role'


@ -19,8 +19,8 @@
import unittest
from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
class TestEmrBaseSensor(unittest.TestCase):


@ -24,7 +24,7 @@ from unittest.mock import MagicMock, patch
from dateutil.tz import tzlocal
from airflow import AirflowException
from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor
DESCRIBE_CLUSTER_RUNNING_RETURN = {
'Cluster': {


@ -24,7 +24,7 @@ from unittest.mock import MagicMock, patch
from dateutil.tz import tzlocal
from airflow import AirflowException
from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
DESCRIBE_JOB_STEP_RUNNING_RETURN = {
'ResponseMetadata': {


@ -21,8 +21,8 @@ import unittest
import mock
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import AwsGlueCatalogPartitionSensor
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor
try:
from moto import mock_glue


@ -23,7 +23,7 @@ from unittest import mock
from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.sensors.s3_key_sensor import S3KeySensor
from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor
class TestS3KeySensor(unittest.TestCase):


@ -20,7 +20,7 @@
import unittest
from unittest import mock
from airflow.sensors.s3_prefix_sensor import S3PrefixSensor
from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor
class TestS3PrefixSensor(unittest.TestCase):


@ -19,8 +19,8 @@
import unittest
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
class TestSagemakerBaseSensor(unittest.TestCase):


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_endpoint_sensor import SageMakerEndpointSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor
DESCRIBE_ENDPOINT_CREATING_RESPONSE = {
'EndpointStatus': 'Creating',


@ -22,10 +22,10 @@ from datetime import datetime
import mock
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.contrib.hooks.sagemaker_hook import LogState, SageMakerHook
from airflow.contrib.sensors.sagemaker_training_sensor import SageMakerTrainingSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_training import SageMakerTrainingSensor
DESCRIBE_TRAINING_COMPELETED_RESPONSE = {
'TrainingJobStatus': 'Completed',


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_transform_sensor import SageMakerTransformSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor
DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE = {
'TransformJobStatus': 'InProgress',


@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_tuning_sensor import SageMakerTuningSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor
DESCRIBE_TUNING_INPROGRESS_RESPONSE = {
'HyperParameterTuningJobStatus': 'InProgress',


@ -240,7 +240,22 @@ HOOK = [
'airflow.providers.microsoft.azure.hooks.wasb.WasbHook',
'airflow.contrib.hooks.wasb_hook.WasbHook',
),
(
'airflow.providers.amazon.aws.hooks.glue_catalog.AwsGlueCatalogHook',
'airflow.contrib.hooks.aws_glue_catalog_hook.AwsGlueCatalogHook',
),
(
'airflow.providers.amazon.aws.hooks.logs.AwsLogsHook',
'airflow.contrib.hooks.aws_logs_hook.AwsLogsHook',
),
(
'airflow.providers.amazon.aws.hooks.emr.EmrHook',
'airflow.contrib.hooks.emr_hook.EmrHook',
),
(
'airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook',
'airflow.contrib.hooks.sagemaker_hook.SageMakerHook',
),
]
OPERATOR = [
@ -911,6 +926,62 @@ OPERATOR = [
'airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbDeleteBlobOperator',
'airflow.contrib.operators.wasb_delete_blob_operator.WasbDeleteBlobOperator',
),
(
'airflow.providers.amazon.aws.operators.ecs.ECSOperator',
'airflow.contrib.operators.ecs_operator.ECSOperator',
),
(
'airflow.providers.amazon.aws.operators.emr_add_steps.EmrAddStepsOperator',
'airflow.contrib.operators.emr_add_steps_operator.EmrAddStepsOperator',
),
(
'airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrCreateJobFlowOperator',
'airflow.contrib.operators.emr_create_job_flow_operator.EmrCreateJobFlowOperator',
),
(
'airflow.providers.amazon.aws.operators.emr_terminate_job_flow.EmrTerminateJobFlowOperator',
'airflow.contrib.operators.emr_terminate_job_flow_operator.EmrTerminateJobFlowOperator',
),
(
'airflow.providers.amazon.aws.operators.s3_copy_object.S3CopyObjectOperator',
'airflow.contrib.operators.s3_copy_object_operator.S3CopyObjectOperator',
),
(
'airflow.providers.amazon.aws.operators.s3_delete_objects.S3DeleteObjectsOperator',
'airflow.contrib.operators.s3_delete_objects_operator.S3DeleteObjectsOperator',
),
(
'airflow.providers.amazon.aws.operators.s3_list.S3ListOperator',
'airflow.contrib.operators.s3_list_operator.S3ListOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_base.SageMakerBaseOperator',
'airflow.contrib.operators.sagemaker_base_operator.SageMakerBaseOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_endpoint_config.SageMakerEndpointConfigOperator',
'airflow.contrib.operators.sagemaker_endpoint_config_operator.SageMakerEndpointConfigOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_endpoint.SageMakerEndpointOperator',
'airflow.contrib.operators.sagemaker_endpoint_operator.SageMakerEndpointOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_model.SageMakerModelOperator',
'airflow.contrib.operators.sagemaker_model_operator.SageMakerModelOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_training.SageMakerTrainingOperator',
'airflow.contrib.operators.sagemaker_training_operator.SageMakerTrainingOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_transform.SageMakerTransformOperator',
'airflow.contrib.operators.sagemaker_transform_operator.SageMakerTransformOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_tuning.SageMakerTuningOperator',
'airflow.contrib.operators.sagemaker_tuning_operator.SageMakerTuningOperator',
),
]
SENSOR = [
@ -1001,6 +1072,50 @@ SENSOR = [
'airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor',
'airflow.contrib.sensors.wasb_sensor.WasbPrefixSensor',
),
(
'airflow.providers.amazon.aws.sensors.glue_catalog_partition.AwsGlueCatalogPartitionSensor',
'airflow.contrib.sensors.aws_glue_catalog_partition_sensor.AwsGlueCatalogPartitionSensor',
),
(
'airflow.providers.amazon.aws.sensors.emr_base.EmrBaseSensor',
'airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor',
),
(
'airflow.providers.amazon.aws.sensors.emr_job_flow.EmrJobFlowSensor',
'airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor',
),
(
'airflow.providers.amazon.aws.sensors.emr_step.EmrStepSensor',
'airflow.contrib.sensors.emr_step_sensor.EmrStepSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_base.SageMakerBaseSensor',
'airflow.contrib.sensors.sagemaker_base_sensor.SageMakerBaseSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_endpoint.SageMakerEndpointSensor',
'airflow.contrib.sensors.sagemaker_endpoint_sensor.SageMakerEndpointSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_transform.SageMakerTransformSensor',
'airflow.contrib.sensors.sagemaker_transform_sensor.SageMakerTransformSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_tuning.SageMakerTuningSensor',
'airflow.contrib.sensors.sagemaker_tuning_sensor.SageMakerTuningSensor',
),
(
'airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator',
'airflow.operators.s3_file_transform_operator.S3FileTransformOperator',
),
(
'airflow.providers.amazon.aws.sensors.s3_key.S3KeySensor',
'airflow.sensors.s3_key_sensor.S3KeySensor',
),
(
'airflow.providers.amazon.aws.sensors.s3_prefix.S3PrefixSensor',
'airflow.sensors.s3_prefix_sensor.S3PrefixSensor',
),
]
PROTOCOLS = [
@ -1008,6 +1123,10 @@ PROTOCOLS = [
"airflow.providers.amazon.aws.hooks.batch_client.AwsBatchProtocol",
"airflow.contrib.operators.awsbatch_operator.BatchProtocol",
),
(
'airflow.providers.amazon.aws.operators.ecs.ECSProtocol',
'airflow.contrib.operators.ecs_operator.ECSProtocol',
),
]
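The (new path, old path) pairs above drive the deprecation tests; the assertion helper itself is outside this hunk, so the following is only a sketch of the kind of check such a pair enables (the helper function below is hypothetical, not the project's test code):
import importlib


def assert_old_path_is_alias(new_path, old_path):
    """Hypothetical check: the deprecated path must still expose the relocated class."""
    new_module, _, new_name = new_path.rpartition('.')
    old_module, _, old_name = old_path.rpartition('.')
    old_cls = getattr(importlib.import_module(old_module), old_name)  # shim emits DeprecationWarning
    new_cls = getattr(importlib.import_module(new_module), new_name)
    assert old_cls is new_cls or issubclass(old_cls, new_cls)


# For example, with one of the HOOK pairs listed above:
assert_old_path_is_alias(
    'airflow.providers.amazon.aws.hooks.emr.EmrHook',
    'airflow.contrib.hooks.emr_hook.EmrHook',
)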