[AIRFLOW-6572] Move AWS classes to providers.amazon.aws package (#7178)
* [AIP-21] Move contrib.hooks.aws_glue_catalog_hook airflow.contrib.hooks.aws_glue_catalog_hook
* [AIP-21] Move contrib.hooks.aws_logs_hook airflow.contrib.hooks.aws_logs_hook
* [AIP-21] Move contrib.hooks.emr_hook airflow.contrib.hooks.emr_hook
* [AIP-21] Move contrib.operators.ecs_operator airflow.contrib.operators.ecs_operator
* [AIP-21] Move contrib.operators.emr_add_steps_operator airflow.contrib.operators.emr_add_steps_operator
* [AIP-21] Move contrib.operators.emr_create_job_flow_operator airflow.contrib.operators.emr_create_job_flow_operator
* [AIP-21] Move contrib.operators.emr_terminate_job_flow_operator airflow.contrib.operators.emr_terminate_job_flow_operator
* [AIP-21] Move contrib.operators.s3_copy_object_operator airflow.contrib.operators.s3_copy_object_operator
* [AIP-21] Move contrib.operators.s3_delete_objects_operator airflow.contrib.operators.s3_delete_objects_operator
* [AIP-21] Move contrib.operators.s3_list_operator airflow.contrib.operators.s3_list_operator
* [AIP-21] Move contrib.operators.sagemaker_base_operator airflow.contrib.operators.sagemaker_base_operator
* [AIP-21] Move contrib.operators.sagemaker_endpoint_config_operator airflow.contrib.operators.sagemaker_endpoint_config_operator
* [AIP-21] Move contrib.operators.sagemaker_endpoint_operator airflow.contrib.operators.sagemaker_endpoint_operator
* [AIP-21] Move contrib.operators.sagemaker_model_operator airflow.contrib.operators.sagemaker_model_operator
* [AIP-21] Move contrib.operators.sagemaker_training_operator airflow.contrib.operators.sagemaker_training_operator
* [AIP-21] Move contrib.operators.sagemaker_transform_operator airflow.contrib.operators.sagemaker_transform_operator
* [AIP-21] Move contrib.operators.sagemaker_tuning_operator airflow.contrib.operators.sagemaker_tuning_operator
* [AIP-21] Move contrib.sensors.aws_glue_catalog_partition_sensor airflow.contrib.sensors.aws_glue_catalog_partition_sensor
* [AIP-21] Move contrib.sensors.emr_base_sensor airflow.contrib.sensors.emr_base_sensor
* [AIP-21] Move contrib.sensors.emr_job_flow_sensor airflow.contrib.sensors.emr_job_flow_sensor
* [AIP-21] Move contrib.sensors.emr_step_sensor airflow.contrib.sensors.emr_step_sensor
* [AIP-21] Move contrib.sensors.sagemaker_base_sensor airflow.contrib.sensors.sagemaker_base_sensor
* [AIP-21] Move contrib.sensors.sagemaker_endpoint_sensor airflow.contrib.sensors.sagemaker_endpoint_sensor
* [AIP-21] Move contrib.sensors.sagemaker_training_sensor airflow.contrib.sensors.sagemaker_training_sensor
* [AIP-21] Move contrib.sensors.sagemaker_transform_sensor airflow.contrib.sensors.sagemaker_transform_sensor
* [AIP-21] Move contrib.sensors.sagemaker_tuning_sensor airflow.contrib.sensors.sagemaker_tuning_sensor
* [AIP-21] Move operators.s3_file_transform_operator airflow.operators.s3_file_transform_operator
* [AIP-21] Move sensors.s3_key_sensor airflow.sensors.s3_key_sensor
* [AIP-21] Move sensors.s3_prefix_sensor airflow.sensors.s3_prefix_sensor
* [AIP-21] Move contrib.hooks.sagemaker_hook providers.amazon.aws.hooks.sagemaker
Parent: 50efda5c69
Commit: c319e81cae
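For DAG authors, the practical change in this commit is the import path. A minimal before/after sketch for the EMR example DAG touched in the first hunk below (module paths are copied from the diff; the old paths keep working through the deprecation shims added later in this commit):

    # Before: deprecated airflow.contrib paths
    from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
    from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor

    # After: new airflow.providers.amazon.aws paths
    from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
    from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor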
@@ -22,8 +22,8 @@ This is an example dag for a AWS EMR Pipeline with auto steps.
 from datetime import timedelta
 
 from airflow import DAG
-from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
-from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
+from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
+from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor
 from airflow.utils.dates import days_ago
 
 DEFAULT_ARGS = {
@@ -25,10 +25,10 @@ terminating the cluster.
 from datetime import timedelta
 
 from airflow import DAG
-from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
-from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
-from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator
-from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
+from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
+from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
+from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator
+from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
 from airflow.utils.dates import days_ago
 
 DEFAULT_ARGS = {
@@ -16,137 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`."""
|
||||
|
||||
"""
|
||||
This module contains AWS Glue Catalog Hook
|
||||
"""
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook # noqa
|
||||
|
||||
class AwsGlueCatalogHook(AwsHook):
|
||||
"""
|
||||
Interact with AWS Glue Catalog
|
||||
|
||||
:param aws_conn_id: ID of the Airflow connection where
|
||||
credentials and extra configuration are stored
|
||||
:type aws_conn_id: str
|
||||
:param region_name: aws region name (example: us-east-1)
|
||||
:type region_name: str
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
aws_conn_id='aws_default',
|
||||
region_name=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
self.region_name = region_name
|
||||
self.conn = None
|
||||
super().__init__(aws_conn_id=aws_conn_id, *args, **kwargs)
|
||||
|
||||
def get_conn(self):
|
||||
"""
|
||||
Returns glue connection object.
|
||||
"""
|
||||
self.conn = self.get_client_type('glue', self.region_name)
|
||||
return self.conn
|
||||
|
||||
def get_partitions(self,
|
||||
database_name,
|
||||
table_name,
|
||||
expression='',
|
||||
page_size=None,
|
||||
max_items=None):
|
||||
"""
|
||||
Retrieves the partition values for a table.
|
||||
|
||||
:param database_name: The name of the catalog database where the partitions reside.
|
||||
:type database_name: str
|
||||
:param table_name: The name of the partitions' table.
|
||||
:type table_name: str
|
||||
:param expression: An expression filtering the partitions to be returned.
|
||||
Please see official AWS documentation for further information.
|
||||
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
|
||||
:type expression: str
|
||||
:param page_size: pagination size
|
||||
:type page_size: int
|
||||
:param max_items: maximum items to return
|
||||
:type max_items: int
|
||||
:return: set of partition values where each value is a tuple since
|
||||
a partition may be composed of multiple columns. For example:
|
||||
``{('2018-01-01','1'), ('2018-01-01','2')}``
|
||||
"""
|
||||
config = {
|
||||
'PageSize': page_size,
|
||||
'MaxItems': max_items,
|
||||
}
|
||||
|
||||
paginator = self.get_conn().get_paginator('get_partitions')
|
||||
response = paginator.paginate(
|
||||
DatabaseName=database_name,
|
||||
TableName=table_name,
|
||||
Expression=expression,
|
||||
PaginationConfig=config
|
||||
)
|
||||
|
||||
partitions = set()
|
||||
for page in response:
|
||||
for partition in page['Partitions']:
|
||||
partitions.add(tuple(partition['Values']))
|
||||
|
||||
return partitions
|
||||
|
||||
def check_for_partition(self, database_name, table_name, expression):
|
||||
"""
|
||||
Checks whether a partition exists
|
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to
|
||||
:type database_name: str
|
||||
:param table_name: Name of hive table @partition belongs to
|
||||
:type table_name: str
|
||||
:expression: Expression that matches the partitions to check for
|
||||
(eg `a = 'b' AND c = 'd'`)
|
||||
:type expression: str
|
||||
:rtype: bool
|
||||
|
||||
>>> hook = AwsGlueCatalogHook()
|
||||
>>> t = 'static_babynames_partitioned'
|
||||
>>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
|
||||
True
|
||||
"""
|
||||
partitions = self.get_partitions(database_name, table_name, expression, max_items=1)
|
||||
|
||||
return bool(partitions)
|
||||
|
||||
def get_table(self, database_name, table_name):
|
||||
"""
|
||||
Get the information of the table
|
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to
|
||||
:type database_name: str
|
||||
:param table_name: Name of hive table
|
||||
:type table_name: str
|
||||
:rtype: dict
|
||||
|
||||
>>> hook = AwsGlueCatalogHook()
|
||||
>>> r = hook.get_table('db', 'table_foo')
|
||||
>>> r['Name'] = 'table_foo'
|
||||
"""
|
||||
|
||||
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name)
|
||||
|
||||
return result['Table']
|
||||
|
||||
def get_table_location(self, database_name, table_name):
|
||||
"""
|
||||
Get the physical location of the table
|
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to
|
||||
:type database_name: str
|
||||
:param table_name: Name of hive table
|
||||
:type table_name: str
|
||||
:return: str
|
||||
"""
|
||||
|
||||
table = self.get_table(database_name, table_name)
|
||||
|
||||
return table['StorageDescriptor']['Location']
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@@ -16,87 +16,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""
-This module contains a hook (AwsLogsHook) with some very basic
-functionality for interacting with AWS CloudWatch.
-"""
-
-from airflow.contrib.hooks.aws_hook import AwsHook
-
-
-class AwsLogsHook(AwsHook):
-    """
-    Interact with AWS CloudWatch Logs
-
-    :param region_name: AWS Region Name (example: us-west-2)
-    :type region_name: str
-    """
-
-    def __init__(self, region_name=None, *args, **kwargs):
-        self.region_name = region_name
-        super().__init__(*args, **kwargs)
-
-    def get_conn(self):
-        """
-        Establish an AWS connection for retrieving logs.
-
-        :rtype: CloudWatchLogs.Client
-        """
-        return self.get_client_type('logs', region_name=self.region_name)
-
-    def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start_from_head=True):
-        """
-        A generator for log items in a single stream. This will yield all the
-        items that are available at the current moment.
-
-        :param log_group: The name of the log group.
-        :type log_group: str
-        :param log_stream_name: The name of the specific stream.
-        :type log_stream_name: str
-        :param start_time: The time stamp value to start reading the logs from (default: 0).
-        :type start_time: int
-        :param skip: The number of log entries to skip at the start (default: 0).
-            This is for when there are multiple entries at the same timestamp.
-        :type skip: int
-        :param start_from_head: whether to start from the beginning (True) of the log or
-            at the end of the log (False).
-        :type start_from_head: bool
-        :rtype: dict
-        :return: | A CloudWatch log event with the following key-value pairs:
-                 |   'timestamp' (int): The time in milliseconds of the event.
-                 |   'message' (str): The log event data.
-                 |   'ingestionTime' (int): The time in milliseconds the event was ingested.
-        """
-
-        next_token = None
-
-        event_count = 1
-        while event_count > 0:
-            if next_token is not None:
-                token_arg = {'nextToken': next_token}
-            else:
-                token_arg = {}
-
-            response = self.get_conn().get_log_events(logGroupName=log_group,
-                                                      logStreamName=log_stream_name,
-                                                      startTime=start_time,
-                                                      startFromHead=start_from_head,
-                                                      **token_arg)
-
-            events = response['events']
-            event_count = len(events)
-
-            if event_count > skip:
-                events = events[skip:]
-                skip = 0
-            else:
-                skip = skip - event_count
-                events = []
-
-            yield from events
-
-            if 'nextForwardToken' in response:
-                next_token = response['nextForwardToken']
-            else:
-                return
+"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`."""
+
+import warnings
+
+# pylint: disable=unused-import
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook  # noqa
+
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`.",
+    DeprecationWarning, stacklevel=2
+)
@@ -16,65 +16,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from airflow.contrib.hooks.aws_hook import AwsHook
-from airflow.exceptions import AirflowException
-
-
-class EmrHook(AwsHook):
-    """
-    Interact with AWS EMR. emr_conn_id is only necessary for using the
-    create_job_flow method.
-    """
-
-    def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
-        self.emr_conn_id = emr_conn_id
-        self.region_name = region_name
-        self.conn = None
-        super().__init__(*args, **kwargs)
-
-    def get_conn(self):
-        if not self.conn:
-            self.conn = self.get_client_type('emr', self.region_name)
-        return self.conn
-
-    def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
-        conn = self.get_conn()
-
-        response = conn.list_clusters(
-            ClusterStates=cluster_states
-        )
-
-        matching_clusters = list(
-            filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
-        )
-
-        if len(matching_clusters) == 1:
-            cluster_id = matching_clusters[0]['Id']
-            self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
-            return cluster_id
-        elif len(matching_clusters) > 1:
-            raise AirflowException('More than one cluster found for name %s', emr_cluster_name)
-        else:
-            self.log.info('No cluster found for name %s', emr_cluster_name)
-            return None
-
-    def create_job_flow(self, job_flow_overrides):
-        """
-        Creates a job flow using the config from the EMR connection.
-        Keys of the json extra hash may have the arguments of the boto3
-        run_job_flow method.
-        Overrides for this config may be passed as the job_flow_overrides.
-        """
-
-        if not self.emr_conn_id:
-            raise AirflowException('emr_conn_id must be present to use create_job_flow')
-
-        emr_conn = self.get_connection(self.emr_conn_id)
-
-        config = emr_conn.extra_dejson.copy()
-        config.update(job_flow_overrides)
-
-        response = self.get_conn().run_job_flow(**config)
-
-        return response
+"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`."""
+
+import warnings
+
+# pylint: disable=unused-import
+from airflow.providers.amazon.aws.hooks.emr import EmrHook  # noqa
+
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.",
+    DeprecationWarning, stacklevel=2
+)
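The old contrib modules replaced above are now thin shims: each re-exports its class from the new providers location and emits a DeprecationWarning at import time. A small illustrative check of that behaviour, assuming an Airflow installation that includes this commit (the snippet itself is not part of the change and assumes a fresh interpreter, since an already-imported module will not warn again):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old path still resolves: the shim re-exports EmrHook from
        # airflow.providers.amazon.aws.hooks.emr and warns on import.
        from airflow.contrib.hooks.emr_hook import EmrHook  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)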
@@ -16,726 +16,17 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import collections
|
||||
import os
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`."""
|
||||
|
||||
import warnings
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils import timezone
|
||||
|
||||
|
||||
class LogState:
|
||||
STARTING = 1
|
||||
WAIT_IN_PROGRESS = 2
|
||||
TAILING = 3
|
||||
JOB_COMPLETE = 4
|
||||
COMPLETE = 5
|
||||
|
||||
|
||||
# Position is a tuple that includes the last read timestamp and the number of items that were read
|
||||
# at that time. This is used to figure out which event to start with on the next read.
|
||||
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
|
||||
|
||||
|
||||
def argmin(arr, f):
|
||||
"""Return the index, i, in arr that minimizes f(arr[i])"""
|
||||
m = None
|
||||
i = None
|
||||
for idx, item in enumerate(arr):
|
||||
if item is not None:
|
||||
if m is None or f(item) < m:
|
||||
m = f(item)
|
||||
i = idx
|
||||
return i
|
||||
|
||||
|
||||
def secondary_training_status_changed(current_job_description, prev_job_description):
|
||||
"""
|
||||
Returns true if training job's secondary status message has changed.
|
||||
|
||||
:param current_job_description: Current job description, returned from DescribeTrainingJob call.
|
||||
:type current_job_description: dict
|
||||
:param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
|
||||
:type prev_job_description: dict
|
||||
|
||||
:return: Whether the secondary status message of a training job changed or not.
|
||||
"""
|
||||
current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
|
||||
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
|
||||
return False
|
||||
|
||||
prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \
|
||||
if prev_job_description is not None else None
|
||||
|
||||
last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \
|
||||
if prev_job_secondary_status_transitions is not None \
|
||||
and len(prev_job_secondary_status_transitions) > 0 else ''
|
||||
|
||||
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
|
||||
|
||||
return message != last_message
|
||||
|
||||
|
||||
def secondary_training_status_message(job_description, prev_description):
|
||||
"""
|
||||
Returns a string contains start time and the secondary training job status message.
|
||||
|
||||
:param job_description: Returned response from DescribeTrainingJob call
|
||||
:type job_description: dict
|
||||
:param prev_description: Previous job description from DescribeTrainingJob call
|
||||
:type prev_description: dict
|
||||
|
||||
:return: Job status string to be printed.
|
||||
"""
|
||||
|
||||
if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
|
||||
or len(job_description.get('SecondaryStatusTransitions')) == 0:
|
||||
return ''
|
||||
|
||||
prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
|
||||
if prev_description is not None else None
|
||||
prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
|
||||
if prev_description_secondary_transitions is not None else 0
|
||||
current_transitions = job_description['SecondaryStatusTransitions']
|
||||
|
||||
transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
|
||||
current_transitions[prev_transitions_num - len(current_transitions):]
|
||||
|
||||
status_strs = []
|
||||
for transition in transitions_to_print:
|
||||
message = transition['StatusMessage']
|
||||
time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
|
||||
status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))
|
||||
|
||||
return '\n'.join(status_strs)
|
||||
|
||||
|
||||
class SageMakerHook(AwsHook):
|
||||
"""
|
||||
Interact with Amazon SageMaker.
|
||||
"""
|
||||
non_terminal_states = {'InProgress', 'Stopping'}
|
||||
endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating',
|
||||
'RollingBack', 'Deleting'}
|
||||
failed_states = {'Failed'}
|
||||
|
||||
def __init__(self,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
|
||||
self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
def tar_and_s3_upload(self, path, key, bucket):
|
||||
"""
|
||||
Tar the local file or directory and upload to s3
|
||||
|
||||
:param path: local file or directory
|
||||
:type path: str
|
||||
:param key: s3 key
|
||||
:type key: str
|
||||
:param bucket: s3 bucket
|
||||
:type bucket: str
|
||||
:return: None
|
||||
"""
|
||||
with tempfile.TemporaryFile() as temp_file:
|
||||
if os.path.isdir(path):
|
||||
files = [os.path.join(path, name) for name in os.listdir(path)]
|
||||
else:
|
||||
files = [path]
|
||||
with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file:
|
||||
for f in files:
|
||||
tar_file.add(f, arcname=os.path.basename(f))
|
||||
temp_file.seek(0)
|
||||
self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
|
||||
|
||||
def configure_s3_resources(self, config):
|
||||
"""
|
||||
Extract the S3 operations from the configuration and execute them.
|
||||
|
||||
:param config: config of SageMaker operation
|
||||
:type config: dict
|
||||
:rtype: dict
|
||||
"""
|
||||
s3_operations = config.pop('S3Operations', None)
|
||||
|
||||
if s3_operations is not None:
|
||||
create_bucket_ops = s3_operations.get('S3CreateBucket', [])
|
||||
upload_ops = s3_operations.get('S3Upload', [])
|
||||
for op in create_bucket_ops:
|
||||
self.s3_hook.create_bucket(bucket_name=op['Bucket'])
|
||||
for op in upload_ops:
|
||||
if op['Tar']:
|
||||
self.tar_and_s3_upload(op['Path'], op['Key'],
|
||||
op['Bucket'])
|
||||
else:
|
||||
self.s3_hook.load_file(op['Path'], op['Key'],
|
||||
op['Bucket'])
|
||||
|
||||
def check_s3_url(self, s3url):
|
||||
"""
|
||||
Check if an S3 URL exists
|
||||
|
||||
:param s3url: S3 url
|
||||
:type s3url: str
|
||||
:rtype: bool
|
||||
"""
|
||||
bucket, key = S3Hook.parse_s3_url(s3url)
|
||||
if not self.s3_hook.check_for_bucket(bucket_name=bucket):
|
||||
raise AirflowException(
|
||||
"The input S3 Bucket {} does not exist ".format(bucket))
|
||||
if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
|
||||
and not self.s3_hook.check_for_prefix(
|
||||
prefix=key, bucket_name=bucket, delimiter='/'):
|
||||
# check if s3 key exists in the case user provides a single file
|
||||
# or if s3 prefix exists in the case user provides multiple files in
|
||||
# a prefix
|
||||
raise AirflowException("The input S3 Key "
|
||||
"or Prefix {} does not exist in the Bucket {}"
|
||||
.format(s3url, bucket))
|
||||
return True
|
||||
|
||||
def check_training_config(self, training_config):
|
||||
"""
|
||||
Check if a training configuration is valid
|
||||
|
||||
:param training_config: training_config
|
||||
:type training_config: dict
|
||||
:return: None
|
||||
"""
|
||||
if "InputDataConfig" in training_config:
|
||||
for channel in training_config['InputDataConfig']:
|
||||
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
|
||||
|
||||
def check_tuning_config(self, tuning_config):
|
||||
"""
|
||||
Check if a tuning configuration is valid
|
||||
|
||||
:param tuning_config: tuning_config
|
||||
:type tuning_config: dict
|
||||
:return: None
|
||||
"""
|
||||
for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
|
||||
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
|
||||
|
||||
def get_conn(self):
|
||||
"""
|
||||
Establish an AWS connection for SageMaker
|
||||
|
||||
:rtype: :py:class:`SageMaker.Client`
|
||||
"""
|
||||
return self.get_client_type('sagemaker')
|
||||
|
||||
def get_log_conn(self):
|
||||
"""
|
||||
This method is deprecated.
|
||||
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_conn` instead.
|
||||
"""
|
||||
warnings.warn("Method `get_log_conn` has been deprecated. "
|
||||
"Please use `airflow.contrib.hooks.AwsLogsHook.get_conn` instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2)
|
||||
|
||||
return self.logs_hook.get_conn()
|
||||
|
||||
def log_stream(self, log_group, stream_name, start_time=0, skip=0):
|
||||
"""
|
||||
This method is deprecated.
|
||||
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.
|
||||
"""
|
||||
warnings.warn("Method `log_stream` has been deprecated. "
|
||||
"Please use `airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2)
|
||||
|
||||
return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)
|
||||
|
||||
def multi_stream_iter(self, log_group, streams, positions=None):
|
||||
"""
|
||||
Iterate over the available events coming from a set of log streams in a single log group
|
||||
interleaving the events from each stream so they're yielded in timestamp order.
|
||||
|
||||
:param log_group: The name of the log group.
|
||||
:type log_group: str
|
||||
:param streams: A list of the log stream names. The position of the stream in this list is
|
||||
the stream number.
|
||||
:type streams: list
|
||||
:param positions: A list of pairs of (timestamp, skip) which represents the last record
|
||||
read from each stream.
|
||||
:type positions: list
|
||||
:return: A tuple of (stream number, cloudwatch log event).
|
||||
"""
|
||||
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
|
||||
event_iters = [self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
|
||||
for s in streams]
|
||||
events = []
|
||||
for s in event_iters:
|
||||
if not s:
|
||||
events.append(None)
|
||||
continue
|
||||
try:
|
||||
events.append(next(s))
|
||||
except StopIteration:
|
||||
events.append(None)
|
||||
|
||||
while any(events):
|
||||
i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
|
||||
yield (i, events[i])
|
||||
try:
|
||||
events[i] = next(event_iters[i])
|
||||
except StopIteration:
|
||||
events[i] = None
|
||||
|
||||
def create_training_job(self, config, wait_for_completion=True, print_log=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create a training job
|
||||
|
||||
:param config: the config for training
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to training job creation
|
||||
"""
|
||||
|
||||
self.check_training_config(config)
|
||||
|
||||
response = self.get_conn().create_training_job(**config)
|
||||
if print_log:
|
||||
self.check_training_status_with_log(config['TrainingJobName'],
|
||||
self.non_terminal_states,
|
||||
self.failed_states,
|
||||
wait_for_completion,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
elif wait_for_completion:
|
||||
describe_response = self.check_status(config['TrainingJobName'],
|
||||
'TrainingJobStatus',
|
||||
self.describe_training_job,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
|
||||
billable_time = \
|
||||
(describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \
|
||||
describe_response['ResourceConfig']['InstanceCount']
|
||||
self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
|
||||
|
||||
return response
|
||||
|
||||
def create_tuning_job(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create a tuning job
|
||||
|
||||
:param config: the config for tuning
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to tuning job creation
|
||||
"""
|
||||
|
||||
self.check_tuning_config(config)
|
||||
|
||||
response = self.get_conn().create_hyper_parameter_tuning_job(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['HyperParameterTuningJobName'],
|
||||
'HyperParameterTuningJobStatus',
|
||||
self.describe_tuning_job,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
return response
|
||||
|
||||
def create_transform_job(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create a transform job
|
||||
|
||||
:param config: the config for transform job
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to transform job creation
|
||||
"""
|
||||
|
||||
self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
|
||||
|
||||
response = self.get_conn().create_transform_job(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['TransformJobName'],
|
||||
'TransformJobStatus',
|
||||
self.describe_transform_job,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
return response
|
||||
|
||||
def create_model(self, config):
|
||||
"""
|
||||
Create a model job
|
||||
|
||||
:param config: the config for model
|
||||
:type config: dict
|
||||
:return: A response to model creation
|
||||
"""
|
||||
|
||||
return self.get_conn().create_model(**config)
|
||||
|
||||
def create_endpoint_config(self, config):
|
||||
"""
|
||||
Create an endpoint config
|
||||
|
||||
:param config: the config for endpoint-config
|
||||
:type config: dict
|
||||
:return: A response to endpoint config creation
|
||||
"""
|
||||
|
||||
return self.get_conn().create_endpoint_config(**config)
|
||||
|
||||
def create_endpoint(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create an endpoint
|
||||
|
||||
:param config: the config for endpoint
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to endpoint creation
|
||||
"""
|
||||
|
||||
response = self.get_conn().create_endpoint(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['EndpointName'],
|
||||
'EndpointStatus',
|
||||
self.describe_endpoint,
|
||||
check_interval, max_ingestion_time,
|
||||
non_terminal_states=self.endpoint_non_terminal_states
|
||||
)
|
||||
return response
|
||||
|
||||
def update_endpoint(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Update an endpoint
|
||||
|
||||
:param config: the config for endpoint
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to endpoint update
|
||||
"""
|
||||
|
||||
response = self.get_conn().update_endpoint(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['EndpointName'],
|
||||
'EndpointStatus',
|
||||
self.describe_endpoint,
|
||||
check_interval, max_ingestion_time,
|
||||
non_terminal_states=self.endpoint_non_terminal_states
|
||||
)
|
||||
return response
|
||||
|
||||
def describe_training_job(self, name):
|
||||
"""
|
||||
Return the training job info associated with the name
|
||||
|
||||
:param name: the name of the training job
|
||||
:type name: str
|
||||
:return: A dict contains all the training job info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_training_job(TrainingJobName=name)
|
||||
|
||||
def describe_training_job_with_log(self, job_name, positions, stream_names,
|
||||
instance_count, state, last_description,
|
||||
last_describe_job_call):
|
||||
"""
|
||||
Return the training job info associated with job_name and print CloudWatch logs
|
||||
"""
|
||||
log_group = '/aws/sagemaker/TrainingJobs'
|
||||
|
||||
if len(stream_names) < instance_count:
|
||||
# Log streams are created whenever a container starts writing to stdout/err, so this list
|
||||
# may be dynamic until we have a stream for every instance.
|
||||
logs_conn = self.logs_hook.get_conn()
|
||||
try:
|
||||
streams = logs_conn.describe_log_streams(
|
||||
logGroupName=log_group,
|
||||
logStreamNamePrefix=job_name + '/',
|
||||
orderBy='LogStreamName',
|
||||
limit=instance_count
|
||||
)
|
||||
stream_names = [s['logStreamName'] for s in streams['logStreams']]
|
||||
positions.update([(s, Position(timestamp=0, skip=0))
|
||||
for s in stream_names if s not in positions])
|
||||
except logs_conn.exceptions.ResourceNotFoundException:
|
||||
# On the very first training job run on an account, there's no log group until
|
||||
# the container starts logging, so ignore any errors thrown about that
|
||||
pass
|
||||
|
||||
if len(stream_names) > 0:
|
||||
for idx, event in self.multi_stream_iter(log_group, stream_names, positions):
|
||||
self.log.info(event['message'])
|
||||
ts, count = positions[stream_names[idx]]
|
||||
if event['timestamp'] == ts:
|
||||
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
|
||||
else:
|
||||
positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1)
|
||||
|
||||
if state == LogState.COMPLETE:
|
||||
return state, last_description, last_describe_job_call
|
||||
|
||||
if state == LogState.JOB_COMPLETE:
|
||||
state = LogState.COMPLETE
|
||||
elif time.time() - last_describe_job_call >= 30:
|
||||
description = self.describe_training_job(job_name)
|
||||
last_describe_job_call = time.time()
|
||||
|
||||
if secondary_training_status_changed(description, last_description):
|
||||
self.log.info(secondary_training_status_message(description, last_description))
|
||||
last_description = description
|
||||
|
||||
status = description['TrainingJobStatus']
|
||||
|
||||
if status not in self.non_terminal_states:
|
||||
state = LogState.JOB_COMPLETE
|
||||
return state, last_description, last_describe_job_call
|
||||
|
||||
def describe_tuning_job(self, name):
|
||||
"""
|
||||
Return the tuning job info associated with the name
|
||||
|
||||
:param name: the name of the tuning job
|
||||
:type name: str
|
||||
:return: A dict contains all the tuning job info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
|
||||
|
||||
def describe_model(self, name):
|
||||
"""
|
||||
Return the SageMaker model info associated with the name
|
||||
|
||||
:param name: the name of the SageMaker model
|
||||
:type name: str
|
||||
:return: A dict contains all the model info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_model(ModelName=name)
|
||||
|
||||
def describe_transform_job(self, name):
|
||||
"""
|
||||
Return the transform job info associated with the name
|
||||
|
||||
:param name: the name of the transform job
|
||||
:type name: str
|
||||
:return: A dict contains all the transform job info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_transform_job(TransformJobName=name)
|
||||
|
||||
def describe_endpoint_config(self, name):
|
||||
"""
|
||||
Return the endpoint config info associated with the name
|
||||
|
||||
:param name: the name of the endpoint config
|
||||
:type name: str
|
||||
:return: A dict contains all the endpoint config info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
|
||||
|
||||
def describe_endpoint(self, name):
|
||||
"""
|
||||
:param name: the name of the endpoint
|
||||
:type name: str
|
||||
:return: A dict contains all the endpoint info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_endpoint(EndpointName=name)
|
||||
|
||||
def check_status(self, job_name, key,
|
||||
describe_function, check_interval,
|
||||
max_ingestion_time,
|
||||
non_terminal_states=None):
|
||||
"""
|
||||
Check status of a SageMaker job
|
||||
|
||||
:param job_name: name of the job to check status
|
||||
:type job_name: str
|
||||
:param key: the key of the response dict
|
||||
that points to the state
|
||||
:type key: str
|
||||
:param describe_function: the function used to retrieve the status
|
||||
:type describe_function: python callable
|
||||
:param args: the arguments for the function
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:param non_terminal_states: the set of nonterminal states
|
||||
:type non_terminal_states: set
|
||||
:return: response of describe call after job is done
|
||||
"""
|
||||
if not non_terminal_states:
|
||||
non_terminal_states = self.non_terminal_states
|
||||
|
||||
sec = 0
|
||||
running = True
|
||||
|
||||
while running:
|
||||
time.sleep(check_interval)
|
||||
sec = sec + check_interval
|
||||
|
||||
try:
|
||||
response = describe_function(job_name)
|
||||
status = response[key]
|
||||
self.log.info('Job still running for %s seconds... '
|
||||
'current status is %s' % (sec, status))
|
||||
except KeyError:
|
||||
raise AirflowException('Could not get status of the SageMaker job')
|
||||
except ClientError:
|
||||
raise AirflowException('AWS request failed, check logs for more info')
|
||||
|
||||
if status in non_terminal_states:
|
||||
running = True
|
||||
elif status in self.failed_states:
|
||||
raise AirflowException('SageMaker job failed because %s' % response['FailureReason'])
|
||||
else:
|
||||
running = False
|
||||
|
||||
if max_ingestion_time and sec > max_ingestion_time:
|
||||
# ensure that the job gets killed if the max ingestion time is exceeded
|
||||
raise AirflowException('SageMaker job took more than %s seconds', max_ingestion_time)
|
||||
|
||||
self.log.info('SageMaker Job Compeleted')
|
||||
response = describe_function(job_name)
|
||||
return response
|
||||
|
||||
def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
|
||||
wait_for_completion, check_interval, max_ingestion_time):
|
||||
"""
|
||||
Display the logs for a given training job, optionally tailing them until the
|
||||
job is complete.
|
||||
|
||||
:param job_name: name of the training job to check status and display logs for
|
||||
:type job_name: str
|
||||
:param non_terminal_states: the set of non_terminal states
|
||||
:type non_terminal_states: set
|
||||
:param failed_states: the set of failed states
|
||||
:type failed_states: set
|
||||
:param wait_for_completion: Whether to keep looking for new log entries
|
||||
until the job completes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: The interval in seconds between polling for new log entries and job completion
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: None
|
||||
"""
|
||||
|
||||
sec = 0
|
||||
description = self.describe_training_job(job_name)
|
||||
self.log.info(secondary_training_status_message(description, None))
|
||||
instance_count = description['ResourceConfig']['InstanceCount']
|
||||
status = description['TrainingJobStatus']
|
||||
|
||||
stream_names = [] # The list of log streams
|
||||
positions = {} # The current position in each stream, map of stream name -> position
|
||||
|
||||
job_already_completed = status not in non_terminal_states
|
||||
|
||||
state = LogState.TAILING if wait_for_completion and not job_already_completed else LogState.COMPLETE
|
||||
|
||||
# The loop below implements a state machine that alternates between checking the job status and
|
||||
# reading whatever is available in the logs at this point. Note, that if we were called with
|
||||
# wait_for_completion == False, we never check the job status.
|
||||
#
|
||||
# If wait_for_completion == TRUE and job is not completed, the initial state is TAILING
|
||||
# If wait_for_completion == FALSE, the initial state is COMPLETE
|
||||
# (doesn't matter if the job really is complete).
|
||||
#
|
||||
# The state table:
|
||||
#
|
||||
# STATE ACTIONS CONDITION NEW STATE
|
||||
# ---------------- ---------------- ----------------- ----------------
|
||||
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
|
||||
# Else TAILING
|
||||
# JOB_COMPLETE Read logs, Pause Any COMPLETE
|
||||
# COMPLETE Read logs, Exit N/A
|
||||
#
|
||||
# Notes:
|
||||
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that
|
||||
# got to Cloudwatch after the job was marked complete.
|
||||
last_describe_job_call = time.time()
|
||||
last_description = description
|
||||
|
||||
while True:
|
||||
time.sleep(check_interval)
|
||||
sec = sec + check_interval
|
||||
|
||||
state, last_description, last_describe_job_call = \
|
||||
self.describe_training_job_with_log(job_name, positions, stream_names,
|
||||
instance_count, state, last_description,
|
||||
last_describe_job_call)
|
||||
if state == LogState.COMPLETE:
|
||||
break
|
||||
|
||||
if max_ingestion_time and sec > max_ingestion_time:
|
||||
# ensure that the job gets killed if the max ingestion time is exceeded
|
||||
raise AirflowException('SageMaker job took more than %s seconds', max_ingestion_time)
|
||||
|
||||
if wait_for_completion:
|
||||
status = last_description['TrainingJobStatus']
|
||||
if status in failed_states:
|
||||
reason = last_description.get('FailureReason', '(No reason provided)')
|
||||
raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
|
||||
billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
|
||||
* instance_count
|
||||
self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.hooks.sagemaker import ( # noqa
|
||||
LogState, Position, SageMakerHook, argmin, secondary_training_status_changed,
|
||||
secondary_training_status_message,
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@@ -16,221 +16,33 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`."""
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.typing_compat import Protocol
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSProtocol as NewECSProtocol # noqa
|
||||
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
class ECSProtocol(Protocol):
|
||||
def run_task(self, **kwargs):
|
||||
...
|
||||
|
||||
def get_waiter(self, x: str):
|
||||
...
|
||||
|
||||
def describe_tasks(self, cluster, tasks):
|
||||
...
|
||||
|
||||
def stop_task(self, cluster, task, reason: str):
|
||||
...
|
||||
|
||||
|
||||
class ECSOperator(BaseOperator):
|
||||
@runtime_checkable
|
||||
class ECSProtocol(NewECSProtocol, Protocol):
|
||||
"""
|
||||
Execute a task on AWS EC2 Container Service
|
||||
|
||||
:param task_definition: the task definition name on EC2 Container Service
|
||||
:type task_definition: str
|
||||
:param cluster: the cluster name on EC2 Container Service
|
||||
:type cluster: str
|
||||
:param overrides: the same parameter that boto3 will receive (templated):
|
||||
http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
|
||||
:type overrides: dict
|
||||
:param aws_conn_id: connection id of AWS credentials / region name. If None,
|
||||
credential boto3 strategy will be used
|
||||
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
|
||||
:type aws_conn_id: str
|
||||
:param region_name: region name to use in AWS Hook.
|
||||
Override the region_name in connection (if provided)
|
||||
:type region_name: str
|
||||
:param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
|
||||
:type launch_type: str
|
||||
:param group: the name of the task group associated with the task
|
||||
:type group: str
|
||||
:param placement_constraints: an array of placement constraint objects to use for
|
||||
the task
|
||||
:type placement_constraints: list
|
||||
:param platform_version: the platform version on which your task is running
|
||||
:type platform_version: str
|
||||
:param network_configuration: the network configuration for the task
|
||||
:type network_configuration: dict
|
||||
:param tags: a dictionary of tags in the form of {'tagKey': 'tagValue'}.
|
||||
:type tags: dict
|
||||
:param awslogs_group: the CloudWatch group where your ECS container logs are stored.
|
||||
Only required if you want logs to be shown in the Airflow UI after your job has
|
||||
finished.
|
||||
:type awslogs_group: str
|
||||
:param awslogs_region: the region in which your CloudWatch logs are stored.
|
||||
If None, this is the same as the `region_name` parameter. If that is also None,
|
||||
this is the default AWS region based on your connection settings.
|
||||
:type awslogs_region: str
|
||||
:param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs.
|
||||
This is usually based on some custom name combined with the name of the container.
|
||||
Only required if you want logs to be shown in the Airflow UI after your job has
|
||||
finished.
|
||||
:type awslogs_stream_prefix: str
|
||||
This class is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs.ECSProtocol`.
|
||||
"""
|
||||
|
||||
ui_color = '#f0ede4'
|
||||
client = None # type: Optional[ECSProtocol]
|
||||
arn = None # type: Optional[str]
|
||||
template_fields = ('overrides',)
|
||||
# A Protocol cannot be instantiated
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self, task_definition, cluster, overrides,
|
||||
aws_conn_id=None, region_name=None, launch_type='EC2',
|
||||
group=None, placement_constraints=None, platform_version='LATEST',
|
||||
network_configuration=None, tags=None, awslogs_group=None,
|
||||
awslogs_region=None, awslogs_stream_prefix=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.region_name = region_name
|
||||
self.task_definition = task_definition
|
||||
self.cluster = cluster
|
||||
self.overrides = overrides
|
||||
self.launch_type = launch_type
|
||||
self.group = group
|
||||
self.placement_constraints = placement_constraints
|
||||
self.platform_version = platform_version
|
||||
self.network_configuration = network_configuration
|
||||
|
||||
self.tags = tags
|
||||
self.awslogs_group = awslogs_group
|
||||
self.awslogs_stream_prefix = awslogs_stream_prefix
|
||||
self.awslogs_region = awslogs_region
|
||||
|
||||
if self.awslogs_region is None:
|
||||
self.awslogs_region = region_name
|
||||
|
||||
self.hook = self.get_hook()
|
||||
|
||||
def execute(self, context):
|
||||
self.log.info(
|
||||
'Running ECS Task - Task definition: %s - on cluster %s',
|
||||
self.task_definition, self.cluster
|
||||
def __new__(cls, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"""This class is deprecated.
|
||||
Please use `airflow.providers.amazon.aws.operators.ecs.ECSProtocol`.""",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.log.info('ECSOperator overrides: %s', self.overrides)
|
||||
|
||||
self.client = self.hook.get_client_type(
|
||||
'ecs',
|
||||
region_name=self.region_name
|
||||
)
|
||||
|
||||
run_opts = {
|
||||
'cluster': self.cluster,
|
||||
'taskDefinition': self.task_definition,
|
||||
'overrides': self.overrides,
|
||||
'startedBy': self.owner,
|
||||
'launchType': self.launch_type,
|
||||
}
|
||||
|
||||
if self.launch_type == 'FARGATE':
|
||||
run_opts['platformVersion'] = self.platform_version
|
||||
if self.group is not None:
|
||||
run_opts['group'] = self.group
|
||||
if self.placement_constraints is not None:
|
||||
run_opts['placementConstraints'] = self.placement_constraints
|
||||
if self.network_configuration is not None:
|
||||
run_opts['networkConfiguration'] = self.network_configuration
|
||||
if self.tags is not None:
|
||||
run_opts['tags'] = [{'key': k, 'value': v} for (k, v) in self.tags.items()]
|
||||
|
||||
response = self.client.run_task(**run_opts)
|
||||
|
||||
failures = response['failures']
|
||||
if len(failures) > 0:
|
||||
raise AirflowException(response)
|
||||
self.log.info('ECS Task started: %s', response)
|
||||
|
||||
self.arn = response['tasks'][0]['taskArn']
|
||||
self._wait_for_task_ended()
|
||||
|
||||
self._check_success_task()
|
||||
self.log.info('ECS Task has been successfully executed: %s', response)
|
||||
|
||||
def _wait_for_task_ended(self):
|
||||
waiter = self.client.get_waiter('tasks_stopped')
|
||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
|
||||
waiter.wait(
|
||||
cluster=self.cluster,
|
||||
tasks=[self.arn]
|
||||
)
|
||||
|
||||
def _check_success_task(self):
|
||||
response = self.client.describe_tasks(
|
||||
cluster=self.cluster,
|
||||
tasks=[self.arn]
|
||||
)
|
||||
self.log.info('ECS Task stopped, check status: %s', response)
|
||||
|
||||
# Get logs from CloudWatch if the awslogs log driver was used
|
||||
if self.awslogs_group and self.awslogs_stream_prefix:
|
||||
self.log.info('ECS Task logs output:')
|
||||
task_id = self.arn.split("/")[-1]
|
||||
stream_name = "{}/{}".format(self.awslogs_stream_prefix, task_id)
|
||||
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
|
||||
dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
|
||||
self.log.info("[{}] {}".format(dt.isoformat(), event['message']))
|
||||
|
||||
if len(response.get('failures', [])) > 0:
|
||||
raise AirflowException(response)
|
||||
|
||||
for task in response['tasks']:
|
||||
# This is a `stoppedReason` that indicates a task has not
|
||||
# successfully finished, but there is no other indication of failure
|
||||
# in the response.
|
||||
# See, https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html # noqa E501
|
||||
if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.',
|
||||
task.get('stoppedReason', '')):
|
||||
raise AirflowException(
|
||||
'The task was stopped because the host instance terminated: {}'.
|
||||
format(task.get('stoppedReason', '')))
|
||||
containers = task['containers']
|
||||
for container in containers:
|
||||
if container.get('lastStatus') == 'STOPPED' and \
|
||||
container['exitCode'] != 0:
|
||||
raise AirflowException(
|
||||
'This task is not in success state {}'.format(task))
|
||||
elif container.get('lastStatus') == 'PENDING':
|
||||
raise AirflowException('This task is still pending {}'.format(task))
|
||||
elif 'error' in container.get('reason', '').lower():
|
||||
raise AirflowException(
|
||||
'This containers encounter an error during launching : {}'.
|
||||
format(container.get('reason', '').lower()))
|
||||
|
||||
def get_hook(self):
|
||||
return AwsHook(
|
||||
aws_conn_id=self.aws_conn_id
|
||||
)
|
||||
|
||||
def get_logs_hook(self):
|
||||
return AwsLogsHook(
|
||||
aws_conn_id=self.aws_conn_id,
|
||||
region_name=self.awslogs_region
|
||||
)
|
||||
|
||||
def on_kill(self):
|
||||
response = self.client.stop_task(
|
||||
cluster=self.cluster,
|
||||
task=self.arn,
|
||||
reason='Task killed by the user')
|
||||
self.log.info(response)
|
||||
|
|
|
@@ -16,75 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`."""

import warnings


class EmrAddStepsOperator(BaseOperator):
    """
    An operator that adds steps to an existing EMR job_flow.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator  # noqa

    :param job_flow_id: id of the JobFlow to add steps to. (templated)
    :type job_flow_id: Optional[str]
    :param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing
        job_flow_id. Will search for the id of a JobFlow with a matching name in one of the states in
        ``cluster_states``. Exactly one such cluster should exist or the operator will fail. (templated)
    :type job_flow_name: Optional[str]
    :param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
        (templated)
    :type cluster_states: list
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param steps: boto3 style steps to be added to the jobflow. (templated)
    :type steps: list
    :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
    :type do_xcom_push: bool
    """
    template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
    template_ext = ()
    ui_color = '#f9c915'

    @apply_defaults
    def __init__(
            self,
            job_flow_id=None,
            job_flow_name=None,
            cluster_states=None,
            aws_conn_id='aws_default',
            steps=None,
            *args, **kwargs):
        if kwargs.get('xcom_push') is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
        if not ((job_flow_id is None) ^ (job_flow_name is None)):
            raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
        super().__init__(*args, **kwargs)
        steps = steps or []
        self.aws_conn_id = aws_conn_id
        self.job_flow_id = job_flow_id
        self.job_flow_name = job_flow_name
        self.cluster_states = cluster_states
        self.steps = steps

    def execute(self, context):
        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)

        emr = emr_hook.get_conn()

        job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
                                                                          self.cluster_states)
        if not job_flow_id:
            raise AirflowException(f'No cluster found for name: {self.job_flow_name}')

        if self.do_xcom_push:
            context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)

        self.log.info('Adding steps to %s', job_flow_id)
        response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=self.steps)

        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
            raise AirflowException('Adding steps failed: %s' % response)
        else:
            self.log.info('Steps %s added to JobFlow', response['StepIds'])
            return response['StepIds']
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`.",
    DeprecationWarning, stacklevel=2
)

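As a hedged illustration of the docstring above (not part of this diff), a minimal task using the new import path might look like the following; the job flow id and the step definition are made-up placeholders::

    from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator

    # hypothetical boto3-style step definition
    SPARK_STEP = {
        'Name': 'calculate_pi',
        'ActionOnFailure': 'CONTINUE',
        'HadoopJarStep': {'Jar': 'command-runner.jar',
                          'Args': ['spark-example', 'SparkPi', '10']},
    }

    add_steps = EmrAddStepsOperator(
        task_id='add_steps',
        job_flow_id='j-EXAMPLE1234567',   # placeholder cluster id
        aws_conn_id='aws_default',
        steps=[SPARK_STEP],
    )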
@@ -16,59 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`."""

import warnings


class EmrCreateJobFlowOperator(BaseOperator):
    """
    Creates an EMR JobFlow, reading the config from the EMR connection.
    A dictionary of JobFlow overrides can be passed that override
    the config from the connection.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator  # noqa

    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param emr_conn_id: emr connection to use
    :type emr_conn_id: str
    :param job_flow_overrides: boto3 style arguments to override
        emr_connection extra. (templated)
    :type job_flow_overrides: dict
    """
    template_fields = ['job_flow_overrides']
    template_ext = ()
    ui_color = '#f9c915'

    @apply_defaults
    def __init__(
            self,
            aws_conn_id='aws_default',
            emr_conn_id='emr_default',
            job_flow_overrides=None,
            region_name=None,
            *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.emr_conn_id = emr_conn_id
        if job_flow_overrides is None:
            job_flow_overrides = {}
        self.job_flow_overrides = job_flow_overrides
        self.region_name = region_name

    def execute(self, context):
        emr = EmrHook(aws_conn_id=self.aws_conn_id,
                      emr_conn_id=self.emr_conn_id,
                      region_name=self.region_name)

        self.log.info(
            'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',
            self.aws_conn_id, self.emr_conn_id
        )
        response = emr.create_job_flow(self.job_flow_overrides)

        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
            raise AirflowException('JobFlow creation failed: %s' % response)
        else:
            self.log.info('JobFlow with id %s created', response['JobFlowId'])
            return response['JobFlowId']
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`.",
    DeprecationWarning, stacklevel=2
)

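A minimal sketch of how the override mechanism is typically used (the overrides shown are placeholder values, not part of this change); whatever is in ``job_flow_overrides`` is layered on top of the extras stored in the ``emr_default`` connection::

    from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator

    # hypothetical overrides merged on top of the EMR connection extras
    JOB_FLOW_OVERRIDES = {
        'Name': 'PiCalc',
        'ReleaseLabel': 'emr-5.29.0',
        'Instances': {'KeepJobFlowAliveWhenNoSteps': False},
    }

    create_job_flow = EmrCreateJobFlowOperator(
        task_id='create_job_flow',
        aws_conn_id='aws_default',
        emr_conn_id='emr_default',
        job_flow_overrides=JOB_FLOW_OVERRIDES,
    )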
@@ -16,42 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`."""

import warnings


class EmrTerminateJobFlowOperator(BaseOperator):
    """
    Operator to terminate EMR JobFlows.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator  # noqa

    :param job_flow_id: id of the JobFlow to terminate. (templated)
    :type job_flow_id: str
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    """
    template_fields = ['job_flow_id']
    template_ext = ()
    ui_color = '#f9c915'

    @apply_defaults
    def __init__(
            self,
            job_flow_id,
            aws_conn_id='aws_default',
            *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.job_flow_id = job_flow_id
        self.aws_conn_id = aws_conn_id

    def execute(self, context):
        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()

        self.log.info('Terminating JobFlow %s', self.job_flow_id)
        response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])

        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
            raise AirflowException('JobFlow termination failed: %s' % response)
        else:
            self.log.info('JobFlow with id %s terminated', self.job_flow_id)
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.",
    DeprecationWarning, stacklevel=2
)

@@ -16,81 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`."""

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator  # noqa

class S3CopyObjectOperator(BaseOperator):
    """
    Creates a copy of an object that is already stored in S3.

    Note: the S3 connection used here needs to have access to both
    source and destination bucket/key.

    :param source_bucket_key: The key of the source object. (templated)

        It can be either a full s3:// style url or a relative path from root level.

        When it's specified as a full s3:// url, please omit source_bucket_name.
    :type source_bucket_key: str
    :param dest_bucket_key: The key of the object to copy to. (templated)

        The convention to specify `dest_bucket_key` is the same as `source_bucket_key`.
    :type dest_bucket_key: str
    :param source_bucket_name: Name of the S3 bucket where the source object is in. (templated)

        It should be omitted when `source_bucket_key` is provided as a full s3:// url.
    :type source_bucket_name: str
    :param dest_bucket_name: Name of the S3 bucket to where the object is copied. (templated)

        It should be omitted when `dest_bucket_key` is provided as a full s3:// url.
    :type dest_bucket_name: str
    :param source_version_id: Version ID of the source object (OPTIONAL)
    :type source_version_id: str
    :param aws_conn_id: Connection id of the S3 connection to use
    :type aws_conn_id: str
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.

        You can provide the following values:

        - False: do not validate SSL certificates. SSL will still be used,
          but SSL certificates will not be
          verified.
        - path/to/cert/bundle.pem: A filename of the CA cert bundle to use.
          You can specify this argument if you want to use a different
          CA cert bundle than the one used by botocore.
    :type verify: bool or str
    """

    template_fields = ('source_bucket_key', 'dest_bucket_key',
                       'source_bucket_name', 'dest_bucket_name')

    @apply_defaults
    def __init__(
            self,
            source_bucket_key,
            dest_bucket_key,
            source_bucket_name=None,
            dest_bucket_name=None,
            source_version_id=None,
            aws_conn_id='aws_default',
            verify=None,
            *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.source_bucket_key = source_bucket_key
        self.dest_bucket_key = dest_bucket_key
        self.source_bucket_name = source_bucket_name
        self.dest_bucket_name = dest_bucket_name
        self.source_version_id = source_version_id
        self.aws_conn_id = aws_conn_id
        self.verify = verify

    def execute(self, context):
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key,
                            self.source_bucket_name, self.dest_bucket_name,
                            self.source_version_id)
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`.",
    DeprecationWarning, stacklevel=2
)

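A hedged sketch of the two addressing styles described in the docstring (bucket names and keys below are placeholders). Either pass relative keys plus explicit bucket names, or full ``s3://`` URLs and omit the bucket name arguments::

    from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator

    # relative keys plus explicit bucket names
    copy_with_buckets = S3CopyObjectOperator(
        task_id='copy_with_buckets',
        source_bucket_name='source-bucket',
        source_bucket_key='data/2020/01/file.csv',
        dest_bucket_name='dest-bucket',
        dest_bucket_key='backup/file.csv',
    )

    # full s3:// URLs, so the bucket name arguments are omitted
    copy_with_urls = S3CopyObjectOperator(
        task_id='copy_with_urls',
        source_bucket_key='s3://source-bucket/data/2020/01/file.csv',
        dest_bucket_key='s3://dest-bucket/backup/file.csv',
    )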
@@ -16,72 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`."""

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator  # noqa

class S3DeleteObjectsOperator(BaseOperator):
    """
    Enables users to delete a single object or multiple objects from
    a bucket using a single HTTP request.

    Users may specify up to 1000 keys to delete.

    :param bucket: Name of the bucket in which you are going to delete object(s). (templated)
    :type bucket: str
    :param keys: The key(s) to delete from S3 bucket. (templated)

        When ``keys`` is a string, it's supposed to be the key name of
        the single object to delete.

        When ``keys`` is a list, it's supposed to be the list of the
        keys to delete.

        You may specify up to 1000 keys.
    :type keys: str or list
    :param aws_conn_id: Connection id of the S3 connection to use
    :type aws_conn_id: str
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.

        You can provide the following values:

        - ``False``: do not validate SSL certificates. SSL will still be used,
          but SSL certificates will not be
          verified.
        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
          You can specify this argument if you want to use a different
          CA cert bundle than the one used by botocore.
    :type verify: bool or str
    """

    template_fields = ('keys', 'bucket')

    @apply_defaults
    def __init__(
            self,
            bucket,
            keys,
            aws_conn_id='aws_default',
            verify=None,
            *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bucket = bucket
        self.keys = keys
        self.aws_conn_id = aws_conn_id
        self.verify = verify

    def execute(self, context):
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

        response = s3_hook.delete_objects(bucket=self.bucket, keys=self.keys)

        deleted_keys = [x['Key'] for x in response.get("Deleted", [])]
        self.log.info("Deleted: %s", deleted_keys)

        if "Errors" in response:
            errors_keys = [x['Key'] for x in response.get("Errors", [])]
            raise AirflowException("Errors when deleting: {}".format(errors_keys))
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`.",
    DeprecationWarning, stacklevel=2
)

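To illustrate the ``keys`` parameter accepting either a string or a list, a minimal sketch with placeholder bucket and key names::

    from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator

    # a single key passed as a string
    delete_one = S3DeleteObjectsOperator(
        task_id='delete_one',
        bucket='my-bucket',
        keys='tmp/file.csv',
    )

    # up to 1000 keys passed as a list
    delete_many = S3DeleteObjectsOperator(
        task_id='delete_many',
        bucket='my-bucket',
        keys=['tmp/a.csv', 'tmp/b.csv'],
    )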
@@ -16,84 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`."""

from typing import Iterable
import warnings

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator  # noqa


class S3ListOperator(BaseOperator):
    """
    List all objects from the bucket with the given string prefix in name.

    This operator returns a python list with the name of objects which can be
    used by `xcom` in the downstream task.

    :param bucket: The S3 bucket where to find the objects. (templated)
    :type bucket: str
    :param prefix: Prefix string to filter the objects whose name begins with
        such prefix. (templated)
    :type prefix: str
    :param delimiter: the delimiter marks key hierarchy. (templated)
    :type delimiter: str
    :param aws_conn_id: The connection ID to use when connecting to S3 storage.
    :type aws_conn_id: str
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.
        You can provide the following values:

        - ``False``: do not validate SSL certificates. SSL will still be used
          (unless use_ssl is False), but SSL certificates will not be
          verified.
        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
          You can specify this argument if you want to use a different
          CA cert bundle than the one used by botocore.
    :type verify: bool or str


    **Example**:
        The following operator would list all the files
        (excluding subfolders) from the S3
        ``customers/2018/04/`` key in the ``data`` bucket. ::

            s3_file = S3ListOperator(
                task_id='list_s3_files',
                bucket='data',
                prefix='customers/2018/04/',
                delimiter='/',
                aws_conn_id='aws_customers_conn'
            )
    """
    template_fields = ('bucket', 'prefix', 'delimiter')  # type: Iterable[str]
    ui_color = '#ffd700'

    @apply_defaults
    def __init__(self,
                 bucket,
                 prefix='',
                 delimiter='',
                 aws_conn_id='aws_default',
                 verify=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.bucket = bucket
        self.prefix = prefix
        self.delimiter = delimiter
        self.aws_conn_id = aws_conn_id
        self.verify = verify

    def execute(self, context):
        hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

        self.log.info(
            'Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)',
            self.bucket, self.prefix, self.delimiter
        )

        return hook.list_keys(
            bucket_name=self.bucket,
            prefix=self.prefix,
            delimiter=self.delimiter)
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`.",
    DeprecationWarning, stacklevel=2
)

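Since the docstring notes that the returned list is exposed via XCom, a hedged sketch of a downstream task consuming it (the task ids and the Airflow 1.10-style ``provide_context`` usage are assumptions for illustration)::

    from airflow.operators.python_operator import PythonOperator

    def print_keys(**context):
        # the list returned by S3ListOperator.execute() is available via XCom
        keys = context['ti'].xcom_pull(task_ids='list_s3_files')
        print(keys)

    show_keys = PythonOperator(
        task_id='show_keys',
        python_callable=print_keys,
        provide_context=True,
    )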
@@ -19,10 +19,10 @@
import warnings
from tempfile import NamedTemporaryFile

from airflow.contrib.operators.s3_list_operator import S3ListOperator
from airflow.exceptions import AirflowException
from airflow.gcp.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator
from airflow.utils.decorators import apply_defaults

@@ -16,86 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`."""

import json
from typing import Iterable
import warnings

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator  # noqa


class SageMakerBaseOperator(BaseOperator):
    """
    This is the base operator for all SageMaker operators.

    :param config: The configuration necessary to start a training job (templated)
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    integer_fields = []  # type: Iterable[Iterable[str]]

    @apply_defaults
    def __init__(self,
                 config,
                 aws_conn_id='aws_default',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.aws_conn_id = aws_conn_id
        self.config = config
        self.hook = None

    def parse_integer(self, config, field):
        if len(field) == 1:
            if isinstance(config, list):
                for sub_config in config:
                    self.parse_integer(sub_config, field)
                return
            head = field[0]
            if head in config:
                config[head] = int(config[head])
            return

        if isinstance(config, list):
            for sub_config in config:
                self.parse_integer(sub_config, field)
            return

        head, tail = field[0], field[1:]
        if head in config:
            self.parse_integer(config[head], tail)
        return

    def parse_config_integers(self):
        # Parse the integer fields of training config to integers
        # in case the config is rendered by Jinja and all fields are str
        for field in self.integer_fields:
            self.parse_integer(self.config, field)

    def expand_role(self):
        pass

    def preprocess_config(self):
        self.log.info(
            'Preprocessing the config and doing required s3_operations'
        )
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        self.log.info(
            'After preprocessing the config is:\n {}'.format(
                json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')))
        )

    def execute(self, context):
        raise NotImplementedError('Please implement execute() in sub class!')
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`.",
    DeprecationWarning, stacklevel=2
)

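To make the ``integer_fields`` / ``parse_config_integers()`` behaviour concrete, a small hedged sketch (field names are placeholders); each entry is a path of nested keys whose Jinja-rendered string values get coerced back to ``int``::

    # Suppose a subclass declares which nested keys must be integers:
    integer_fields = [['ResourceConfig', 'InstanceCount']]

    # A Jinja-rendered config such as
    config = {'ResourceConfig': {'InstanceCount': '2'}}

    # is then normalised by parse_config_integers() into
    # {'ResourceConfig': {'InstanceCount': 2}}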
@@ -16,51 +16,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module is deprecated.
Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.
"""

from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import (  # noqa
    SageMakerEndpointConfigOperator,
)

class SageMakerEndpointConfigOperator(SageMakerBaseOperator):

    """
    Create a SageMaker endpoint config.

    This operator returns the ARN of the endpoint config created in Amazon SageMaker.

    :param config: The configuration necessary to create an endpoint config.

        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    integer_fields = [
        ['ProductionVariants', 'InitialInstanceCount']
    ]

    @apply_defaults
    def __init__(self,
                 config,
                 *args, **kwargs):
        super().__init__(config=config,
                         *args, **kwargs)

        self.config = config

    def execute(self, context):
        self.preprocess_config()

        self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
        response = self.hook.create_endpoint_config(self.config)
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(
                'Sagemaker endpoint config creation failed: %s' % response)
        else:
            return {
                'EndpointConfig': self.hook.describe_endpoint_config(
                    self.config['EndpointConfigName']
                )
            }
warnings.warn(
    "This module is deprecated. "
    "Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.",
    DeprecationWarning, stacklevel=2
)

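A hedged sketch of a ``config`` dict shaped like a boto3 ``create_endpoint_config`` request (all names and instance sizes are placeholders); note how ``InitialInstanceCount`` may arrive as a templated string and is coerced via ``integer_fields``::

    create_endpoint_config = SageMakerEndpointConfigOperator(
        task_id='create_endpoint_config',
        aws_conn_id='aws_default',
        config={
            'EndpointConfigName': 'my-endpoint-config',   # placeholder name
            'ProductionVariants': [{
                'VariantName': 'AllTraffic',
                'ModelName': 'my-model',
                'InitialInstanceCount': '1',              # string, coerced to int
                'InstanceType': 'ml.m4.xlarge',
                'InitialVariantWeight': 1.0,
            }],
        },
    )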
@@ -16,135 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`."""

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator  # noqa

class SageMakerEndpointOperator(SageMakerBaseOperator):

    """
    Create a SageMaker endpoint.

    This operator returns the ARN of the endpoint created in Amazon SageMaker.

    :param config:
        The configuration necessary to create an endpoint.

        If you need to create a SageMaker endpoint based on an existing
        SageMaker model and an existing SageMaker endpoint config::

            config = endpoint_configuration;

        If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::

            config = {
                'Model': model_configuration,
                'EndpointConfig': endpoint_config_configuration,
                'Endpoint': endpoint_configuration
            }

        For details of the configuration parameter of model_configuration see
        :py:meth:`SageMaker.Client.create_model`

        For details of the configuration parameter of endpoint_config_configuration see
        :py:meth:`SageMaker.Client.create_endpoint_config`

        For details of the configuration parameter of endpoint_configuration see
        :py:meth:`SageMaker.Client.create_endpoint`

    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
    :type wait_for_completion: bool
    :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
        waits before polling the status of the endpoint creation.
    :type check_interval: int
    :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
        finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
    :type max_ingestion_time: int
    :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
    :type operation: str
    """

    @apply_defaults
    def __init__(self,
                 config,
                 wait_for_completion=True,
                 check_interval=30,
                 max_ingestion_time=None,
                 operation='create',
                 *args, **kwargs):
        super().__init__(config=config,
                         *args, **kwargs)

        self.config = config
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.operation = operation.lower()
        if self.operation not in ['create', 'update']:
            raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
        self.create_integer_fields()

    def create_integer_fields(self):
        if 'EndpointConfig' in self.config:
            self.integer_fields = [
                ['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
            ]

    def expand_role(self):
        if 'Model' not in self.config:
            return
        hook = AwsHook(self.aws_conn_id)
        config = self.config['Model']
        if 'ExecutionRoleArn' in config:
            config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])

    def execute(self, context):
        self.preprocess_config()

        model_info = self.config.get('Model')
        endpoint_config_info = self.config.get('EndpointConfig')
        endpoint_info = self.config.get('Endpoint', self.config)

        if model_info:
            self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
            self.hook.create_model(model_info)

        if endpoint_config_info:
            self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
            self.hook.create_endpoint_config(endpoint_config_info)

        if self.operation == 'create':
            sagemaker_operation = self.hook.create_endpoint
            log_str = 'Creating'
        elif self.operation == 'update':
            sagemaker_operation = self.hook.update_endpoint
            log_str = 'Updating'
        else:
            raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')

        self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])

        response = sagemaker_operation(
            endpoint_info,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(
                'Sagemaker endpoint creation failed: %s' % response)
        else:
            return {
                'EndpointConfig': self.hook.describe_endpoint_config(
                    endpoint_info['EndpointConfigName']
                ),
                'Endpoint': self.hook.describe_endpoint(
                    endpoint_info['EndpointName']
                )
            }
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`.",
    DeprecationWarning, stacklevel=2
)

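A hedged sketch of the combined ``Model`` / ``EndpointConfig`` / ``Endpoint`` config shape that the docstring describes; every name, role, image URI and S3 path below is a made-up placeholder, not part of this change::

    endpoint_task = SageMakerEndpointOperator(
        task_id='create_endpoint',
        operation='create',
        wait_for_completion=True,
        config={
            'Model': {
                'ModelName': 'my-model',
                'ExecutionRoleArn': 'my-sagemaker-role',
                'PrimaryContainer': {'Image': 'my-image-uri',
                                     'ModelDataUrl': 's3://bucket/model.tar.gz'},
            },
            'EndpointConfig': {
                'EndpointConfigName': 'my-endpoint-config',
                'ProductionVariants': [{'VariantName': 'AllTraffic',
                                        'ModelName': 'my-model',
                                        'InitialInstanceCount': 1,
                                        'InstanceType': 'ml.m4.xlarge'}],
            },
            'Endpoint': {'EndpointName': 'my-endpoint',
                         'EndpointConfigName': 'my-endpoint-config'},
        },
    )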
@@ -16,52 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`."""

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator  # noqa

class SageMakerModelOperator(SageMakerBaseOperator):

    """
    Create a SageMaker model.

    This operator returns the ARN of the model created in Amazon SageMaker.

    :param config: The configuration necessary to create a model.

        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    @apply_defaults
    def __init__(self,
                 config,
                 *args, **kwargs):
        super().__init__(config=config,
                         *args, **kwargs)

        self.config = config

    def expand_role(self):
        if 'ExecutionRoleArn' in self.config:
            hook = AwsHook(self.aws_conn_id)
            self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])

    def execute(self, context):
        self.preprocess_config()

        self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
        response = self.hook.create_model(self.config)
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException('Sagemaker model creation failed: %s' % response)
        else:
            return {
                'Model': self.hook.describe_model(
                    self.config['ModelName']
                )
            }
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`.",
    DeprecationWarning, stacklevel=2
)

@@ -16,83 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`."""

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator  # noqa

class SageMakerTrainingOperator(SageMakerBaseOperator):
    """
    Initiate a SageMaker training job.

    This operator returns the ARN of the training job created in Amazon SageMaker.

    :param config: The configuration necessary to start a training job (templated).

        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    :param wait_for_completion: Whether the operator should wait until the training job finishes.
    :type wait_for_completion: bool
    :param print_log: if the operator should print the cloudwatch log during training
    :type print_log: bool
    :param check_interval: if wait is set to be true, this is the time interval,
        in seconds, at which the operator will check the status of the training job
    :type check_interval: int
    :param max_ingestion_time: If wait is set to True, the operation fails if the training job
        doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
        the operation does not timeout.
    :type max_ingestion_time: int
    """

    integer_fields = [
        ['ResourceConfig', 'InstanceCount'],
        ['ResourceConfig', 'VolumeSizeInGB'],
        ['StoppingCondition', 'MaxRuntimeInSeconds']
    ]

    @apply_defaults
    def __init__(self,
                 config,
                 wait_for_completion=True,
                 print_log=True,
                 check_interval=30,
                 max_ingestion_time=None,
                 *args, **kwargs):
        super().__init__(config=config,
                         *args, **kwargs)

        self.wait_for_completion = wait_for_completion
        self.print_log = print_log
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time

    def expand_role(self):
        if 'RoleArn' in self.config:
            hook = AwsHook(self.aws_conn_id)
            self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])

    def execute(self, context):
        self.preprocess_config()

        self.log.info('Creating SageMaker Training Job %s.', self.config['TrainingJobName'])

        response = self.hook.create_training_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            print_log=self.print_log,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException('Sagemaker Training Job creation failed: %s' % response)
        else:
            return {
                'Training': self.hook.describe_training_job(
                    self.config['TrainingJobName']
                )
            }
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`.",
    DeprecationWarning, stacklevel=2
)

@@ -16,109 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`."""

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator  # noqa

class SageMakerTransformOperator(SageMakerBaseOperator):
    """
    Initiate a SageMaker transform job.

    This operator returns the ARN of the model created in Amazon SageMaker.

    :param config: The configuration necessary to start a transform job (templated).

        If you need to create a SageMaker transform job based on an existing SageMaker model::

            config = transform_config

        If you need to create both SageMaker model and SageMaker Transform job::

            config = {
                'Model': model_config,
                'Transform': transform_config
            }

        For details of the configuration parameter of transform_config see
        :py:meth:`SageMaker.Client.create_transform_job`

        For details of the configuration parameter of model_config, See:
        :py:meth:`SageMaker.Client.create_model`

    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    :param wait_for_completion: Set to True to wait until the transform job finishes.
    :type wait_for_completion: bool
    :param check_interval: If wait is set to True, the time interval, in seconds,
        that this operation waits to check the status of the transform job.
    :type check_interval: int
    :param max_ingestion_time: If wait is set to True, the operation fails
        if the transform job doesn't finish within max_ingestion_time seconds. If you
        set this parameter to None, the operation does not timeout.
    :type max_ingestion_time: int
    """

    @apply_defaults
    def __init__(self,
                 config,
                 wait_for_completion=True,
                 check_interval=30,
                 max_ingestion_time=None,
                 *args, **kwargs):
        super().__init__(config=config,
                         *args, **kwargs)
        self.config = config
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.create_integer_fields()

    def create_integer_fields(self):
        self.integer_fields = [
            ['Transform', 'TransformResources', 'InstanceCount'],
            ['Transform', 'MaxConcurrentTransforms'],
            ['Transform', 'MaxPayloadInMB']
        ]
        if 'Transform' not in self.config:
            for field in self.integer_fields:
                field.pop(0)

    def expand_role(self):
        if 'Model' not in self.config:
            return
        config = self.config['Model']
        if 'ExecutionRoleArn' in config:
            hook = AwsHook(self.aws_conn_id)
            config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])

    def execute(self, context):
        self.preprocess_config()

        model_config = self.config.get('Model')
        transform_config = self.config.get('Transform', self.config)

        if model_config:
            self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
            self.hook.create_model(model_config)

        self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
        response = self.hook.create_transform_job(
            transform_config,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time)
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException('Sagemaker transform Job creation failed: %s' % response)
        else:
            return {
                'Model': self.hook.describe_model(
                    transform_config['ModelName']
                ),
                'Transform': self.hook.describe_transform_job(
                    transform_config['TransformJobName']
                )
            }
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`.",
    DeprecationWarning, stacklevel=2
)

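A hedged sketch of the combined ``Model`` + ``Transform`` config described in the docstring; job names, image URI and S3 paths are placeholders used only for illustration::

    transform_task = SageMakerTransformOperator(
        task_id='transform',
        config={
            'Model': {
                'ModelName': 'my-model',
                'ExecutionRoleArn': 'my-sagemaker-role',
                'PrimaryContainer': {'Image': 'my-image-uri',
                                     'ModelDataUrl': 's3://bucket/model.tar.gz'},
            },
            'Transform': {
                'TransformJobName': 'my-transform-job',
                'ModelName': 'my-model',
                'TransformInput': {'DataSource': {'S3DataSource': {
                    'S3DataType': 'S3Prefix', 'S3Uri': 's3://bucket/input/'}}},
                'TransformOutput': {'S3OutputPath': 's3://bucket/output/'},
                'TransformResources': {'InstanceType': 'ml.m4.xlarge',
                                       'InstanceCount': 1},
            },
        },
    )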
@@ -16,84 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`."""

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator  # noqa

class SageMakerTuningOperator(SageMakerBaseOperator):
    """
    Initiate a SageMaker hyperparameter tuning job.

    This operator returns the ARN of the tuning job created in Amazon SageMaker.

    :param config: The configuration necessary to start a tuning job (templated).

        For details of the configuration parameter see
        :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    :param wait_for_completion: Set to True to wait until the tuning job finishes.
    :type wait_for_completion: bool
    :param check_interval: If wait is set to True, the time interval, in seconds,
        that this operation waits to check the status of the tuning job.
    :type check_interval: int
    :param max_ingestion_time: If wait is set to True, the operation fails
        if the tuning job doesn't finish within max_ingestion_time seconds. If you
        set this parameter to None, the operation does not timeout.
    :type max_ingestion_time: int
    """

    integer_fields = [
        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
        ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
        ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
        ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds']
    ]

    @apply_defaults
    def __init__(self,
                 config,
                 wait_for_completion=True,
                 check_interval=30,
                 max_ingestion_time=None,
                 *args, **kwargs):
        super().__init__(config=config,
                         *args, **kwargs)
        self.config = config
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time

    def expand_role(self):
        if 'TrainingJobDefinition' in self.config:
            config = self.config['TrainingJobDefinition']
            if 'RoleArn' in config:
                hook = AwsHook(self.aws_conn_id)
                config['RoleArn'] = hook.expand_role(config['RoleArn'])

    def execute(self, context):
        self.preprocess_config()

        self.log.info(
            'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
        )

        response = self.hook.create_tuning_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException('Sagemaker Tuning Job creation failed: %s' % response)
        else:
            return {
                'Tuning': self.hook.describe_tuning_job(
                    self.config['HyperParameterTuningJobName']
                )
            }
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`.",
    DeprecationWarning, stacklevel=2
)

@@ -16,78 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`."""

from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
import warnings

# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor  # noqa

class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
    """
    Waits for a partition to show up in AWS Glue Catalog.

    :param table_name: The name of the table to wait for, supports the dot
        notation (my_database.my_table)
    :type table_name: str
    :param expression: The partition clause to wait for. This is passed as
        is to the AWS Glue Catalog API's get_partitions function,
        and supports SQL like notation as in ``ds='2015-01-01'
        AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
        See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
        #aws-glue-api-catalog-partitions-GetPartitions
    :type expression: str
    :param aws_conn_id: ID of the Airflow connection where
        credentials and extra configuration are stored
    :type aws_conn_id: str
    :param region_name: Optional aws region name (example: us-east-1). Uses region from connection
        if not specified.
    :type region_name: str
    :param database_name: The name of the catalog database where the partitions reside.
    :type database_name: str
    :param poke_interval: Time in seconds that the job should wait in
        between each try
    :type poke_interval: int
    """
    template_fields = ('database_name', 'table_name', 'expression',)
    ui_color = '#C5CAE9'

    @apply_defaults
    def __init__(self,
                 table_name, expression="ds='{{ ds }}'",
                 aws_conn_id='aws_default',
                 region_name=None,
                 database_name='default',
                 poke_interval=60 * 3,
                 *args,
                 **kwargs):
        super().__init__(
            poke_interval=poke_interval, *args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.table_name = table_name
        self.expression = expression
        self.database_name = database_name

    def poke(self, context):
        """
        Checks for existence of the partition in the AWS Glue Catalog table
        """
        if '.' in self.table_name:
            self.database_name, self.table_name = self.table_name.split('.')
        self.log.info(
            'Poking for table %s.%s, expression %s', self.database_name, self.table_name, self.expression
        )

        return self.get_hook().check_for_partition(
            self.database_name, self.table_name, self.expression)

    def get_hook(self):
        """
        Gets the AwsGlueCatalogHook
        """
        if not hasattr(self, 'hook'):
            from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
            self.hook = AwsGlueCatalogHook(
                aws_conn_id=self.aws_conn_id,
                region_name=self.region_name)

        return self.hook
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`.",
    DeprecationWarning, stacklevel=2
)

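To illustrate the ``expression`` parameter described above, a hedged usage sketch with placeholder database and table names::

    wait_for_partition = AwsGlueCatalogPartitionSensor(
        task_id='wait_for_partition',
        database_name='my_database',
        table_name='my_table',
        expression="ds='{{ ds }}' AND type='value'",
        aws_conn_id='aws_default',
        poke_interval=60 * 3,
    )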
@@ -16,45 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`."""

import warnings


class EmrBaseSensor(BaseSensorOperator):
    """
    Contains general sensor behavior for EMR.
    Subclasses should implement get_emr_response() and state_from_response() methods.
    Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE constants.
    """
    ui_color = '#66c3ff'
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor  # noqa

    @apply_defaults
    def __init__(
            self,
            aws_conn_id='aws_default',
            *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.aws_conn_id = aws_conn_id

    def poke(self, context):
        response = self.get_emr_response()

        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
            self.log.info('Bad HTTP response: %s', response)
            return False

        state = self.state_from_response(response)
        self.log.info('Job flow currently %s', state)

        if state in self.NON_TERMINAL_STATES:
            return False

        if state in self.FAILED_STATE:
            final_message = 'EMR job failed'
            failure_message = self.failure_message_from_response(response)
            if failure_message:
                final_message += ' ' + failure_message
            raise AirflowException(final_message)

        return True
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`.",
    DeprecationWarning, stacklevel=2
)

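Since the class docstring only describes the subclassing contract in prose, here is a minimal hedged sketch of what a subclass is expected to provide (the state names and sensor name are illustrative placeholders, not part of this change)::

    class MyEmrSensor(EmrBaseSensor):
        NON_TERMINAL_STATES = ['PENDING', 'RUNNING']
        FAILED_STATE = ['FAILED']

        def get_emr_response(self):
            # return a boto3 describe_* response dict here
            ...

        @staticmethod
        def state_from_response(response):
            # extract the state string that poke() compares against the constants above
            ...

        @staticmethod
        def failure_message_from_response(response):
            # optionally return extra detail appended to the failure message
            return None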
@@ -16,48 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`."""

import warnings


class EmrJobFlowSensor(EmrBaseSensor):
    """
    Asks for the state of the JobFlow until it reaches a terminal state.
    If it fails the sensor errors, failing the task.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor  # noqa

    :param job_flow_id: job_flow_id to check the state of
    :type job_flow_id: str
    """

    NON_TERMINAL_STATES = ['STARTING', 'BOOTSTRAPPING', 'RUNNING',
                           'WAITING', 'TERMINATING']
    FAILED_STATE = ['TERMINATED_WITH_ERRORS']
    template_fields = ['job_flow_id']
    template_ext = ()

    @apply_defaults
    def __init__(self,
                 job_flow_id,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.job_flow_id = job_flow_id

    def get_emr_response(self):
        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()

        self.log.info('Poking cluster %s', self.job_flow_id)
        return emr.describe_cluster(ClusterId=self.job_flow_id)

    @staticmethod
    def state_from_response(response):
        return response['Cluster']['Status']['State']

    @staticmethod
    def failure_message_from_response(response):
        state_change_reason = response['Cluster']['Status'].get('StateChangeReason')
        if state_change_reason:
            return 'for code: {} with message {}'.format(state_change_reason.get('Code', 'No code'),
                                                         state_change_reason.get('Message', 'Unknown'))
        return None
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`.",
    DeprecationWarning, stacklevel=2
)

@@ -16,52 +16,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
from airflow.utils.decorators import apply_defaults
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`."""

import warnings


class EmrStepSensor(EmrBaseSensor):
    """
    Asks for the state of the step until it reaches a terminal state.
    If it fails the sensor errors, failing the task.
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor  # noqa

    :param job_flow_id: job_flow_id which contains the step to check the state of
    :type job_flow_id: str
    :param step_id: step to check the state of
    :type step_id: str
    """

    NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE', 'CANCEL_PENDING']
    FAILED_STATE = ['CANCELLED', 'FAILED', 'INTERRUPTED']
    template_fields = ['job_flow_id', 'step_id']
    template_ext = ()

    @apply_defaults
    def __init__(self,
                 job_flow_id,
                 step_id,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.job_flow_id = job_flow_id
        self.step_id = step_id

    def get_emr_response(self):
        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()

        self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id)
        return emr.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)

    @staticmethod
    def state_from_response(response):
        return response['Step']['Status']['State']

    @staticmethod
    def failure_message_from_response(response):
        fail_details = response['Step']['Status'].get('FailureDetails')
        if fail_details:
            return 'for reason {} with message {} and log file {}'.format(fail_details.get('Reason'),
                                                                          fail_details.get('Message'),
                                                                          fail_details.get('LogFile'))
        return None
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`.",
    DeprecationWarning, stacklevel=2
)

@ -16,59 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`."""
|
||||
|
||||
import warnings
|
||||
|
||||
class SageMakerBaseSensor(BaseSensorOperator):
|
||||
"""
|
||||
Contains general sensor behavior for SageMaker.
|
||||
Subclasses should implement get_sagemaker_response()
|
||||
and state_from_response() methods.
|
||||
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
|
||||
"""
|
||||
ui_color = '#ededed'
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor # noqa
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
aws_conn_id='aws_default',
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aws_conn_id = aws_conn_id
|
||||
|
||||
def poke(self, context):
|
||||
response = self.get_sagemaker_response()
|
||||
|
||||
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
self.log.info('Bad HTTP response: %s', response)
|
||||
return False
|
||||
|
||||
state = self.state_from_response(response)
|
||||
|
||||
self.log.info('Job currently %s', state)
|
||||
|
||||
if state in self.non_terminal_states():
|
||||
return False
|
||||
|
||||
if state in self.failed_states():
|
||||
failed_reason = self.get_failed_reason_from_response(response)
|
||||
raise AirflowException('Sagemaker job failed for the following reason: %s'
|
||||
% failed_reason)
|
||||
return True
|
||||
|
||||
def non_terminal_states(self):
|
||||
raise NotImplementedError('Please implement non_terminal_states() in subclass')
|
||||
|
||||
def failed_states(self):
|
||||
raise NotImplementedError('Please implement failed_states() in subclass')
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return 'Unknown'
|
||||
|
||||
def state_from_response(self, response):
|
||||
raise NotImplementedError('Please implement state_from_response() in subclass')
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
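
For illustration only: a minimal sketch of the subclass contract that the base sensor's poke() relies on. The class name, states, and describe call below are hypothetical placeholders, not part of this change.

from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor


class MyJobSensor(SageMakerBaseSensor):
    """Hypothetical subclass showing the four methods poke() calls."""

    def non_terminal_states(self):
        return {'InProgress', 'Stopping'}

    def failed_states(self):
        return {'Failed'}

    def get_sagemaker_response(self):
        # A real subclass would call a SageMaker describe_* API here.
        return {'ResponseMetadata': {'HTTPStatusCode': 200}, 'MyJobStatus': 'Completed'}

    def state_from_response(self, response):
        return response['MyJobStatus']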
|
||||
|
|
|
@@ -16,46 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`."""
|
||||
|
||||
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
|
||||
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor # noqa
|
||||
|
||||
class SageMakerEndpointSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the endpoint state until it reaches a terminal state.
|
||||
If it fails the sensor errors, the task fails.
|
||||
|
||||
:param job_name: job_name of the endpoint instance to check the state of
|
||||
:type job_name: str
|
||||
"""
|
||||
|
||||
template_fields = ['endpoint_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
endpoint_name,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.endpoint_name = endpoint_name
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.endpoint_non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
|
||||
return sagemaker.describe_endpoint(self.endpoint_name)
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['EndpointStatus']
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@@ -16,87 +16,16 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_training`."""
|
||||
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from airflow.contrib.hooks.sagemaker_hook import LogState, SageMakerHook
|
||||
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_training import ( # noqa
|
||||
SageMakerHook, SageMakerTrainingSensor,
|
||||
)
|
||||
|
||||
|
||||
class SageMakerTrainingSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the training state until it reaches a terminal state.
|
||||
If it fails the sensor errors, failing the task.
|
||||
|
||||
:param job_name: name of the SageMaker training job to check the state of
|
||||
:type job_name: str
|
||||
:param print_log: if the operator should print the cloudwatch log
|
||||
:type print_log: bool
|
||||
"""
|
||||
|
||||
template_fields = ['job_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_name,
|
||||
print_log=True,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_name = job_name
|
||||
self.print_log = print_log
|
||||
self.positions = {}
|
||||
self.stream_names = []
|
||||
self.instance_count = None
|
||||
self.state = None
|
||||
self.last_description = None
|
||||
self.last_describe_job_call = None
|
||||
self.log_resource_inited = False
|
||||
|
||||
def init_log_resource(self, hook):
|
||||
description = hook.describe_training_job(self.job_name)
|
||||
self.instance_count = description['ResourceConfig']['InstanceCount']
|
||||
|
||||
status = description['TrainingJobStatus']
|
||||
job_already_completed = status not in self.non_terminal_states()
|
||||
self.state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
|
||||
self.last_description = description
|
||||
self.last_describe_job_call = time.time()
|
||||
self.log_resource_inited = True
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
if self.print_log:
|
||||
if not self.log_resource_inited:
|
||||
self.init_log_resource(sagemaker_hook)
|
||||
self.state, self.last_description, self.last_describe_job_call = \
|
||||
sagemaker_hook.describe_training_job_with_log(self.job_name,
|
||||
self.positions, self.stream_names,
|
||||
self.instance_count, self.state,
|
||||
self.last_description,
|
||||
self.last_describe_job_call)
|
||||
else:
|
||||
self.last_description = sagemaker_hook.describe_training_job(self.job_name)
|
||||
|
||||
status = self.state_from_response(self.last_description)
|
||||
if status not in self.non_terminal_states() and status not in self.failed_states():
|
||||
billable_time = \
|
||||
(self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \
|
||||
self.last_description['ResourceConfig']['InstanceCount']
|
||||
self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)
|
||||
|
||||
return self.last_description
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['TrainingJobStatus']
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_training`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@@ -16,47 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`."""
|
||||
|
||||
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
|
||||
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor # noqa
|
||||
|
||||
class SageMakerTransformSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the transform state until it reaches a terminal state.
|
||||
The sensor will error if the job errors, throwing a AirflowException
|
||||
containing the failure reason.
|
||||
|
||||
:param job_name: job_name of the transform job instance to check the state of
|
||||
:type job_name: str
|
||||
"""
|
||||
|
||||
template_fields = ['job_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_name,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_name = job_name
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
|
||||
return sagemaker.describe_transform_job(self.job_name)
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['TransformJobStatus']
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@@ -16,47 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`."""
|
||||
|
||||
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
|
||||
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor # noqa
|
||||
|
||||
class SageMakerTuningSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the tuning state until it reaches a terminal state.
|
||||
The sensor will error if the job errors, throwing a AirflowException
|
||||
containing the failure reason.
|
||||
|
||||
:param job_name: job_name of the tuning instance to check the state of
|
||||
:type job_name: str
|
||||
"""
|
||||
|
||||
template_fields = ['job_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_name,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_name = job_name
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
|
||||
return sagemaker.describe_tuning_job(self.job_name)
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['HyperParameterTuningJobStatus']
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@@ -16,155 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`."""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, Union
|
||||
import warnings
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator # noqa
|
||||
|
||||
|
||||
class S3FileTransformOperator(BaseOperator):
|
||||
"""
|
||||
Copies data from a source S3 location to a temporary location on the
|
||||
local filesystem. Runs a transformation on this file as specified by
|
||||
the transformation script and uploads the output to a destination S3
|
||||
location.
|
||||
|
||||
    The locations of the source and the destination files in the local
    filesystem are provided as the first and second arguments to the
    transformation script. The transformation script is expected to read the
    data from the source, transform it, and write the output to the local
    destination file. The operator then takes over control and uploads the
    local destination file to S3.

    S3 Select is also available to filter the source contents. Users can
    omit the transformation script if an S3 Select expression is specified.
|
||||
|
||||
:param source_s3_key: The key to be retrieved from S3. (templated)
|
||||
:type source_s3_key: str
|
||||
    :param dest_s3_key: The key to be written to S3. (templated)
|
||||
:type dest_s3_key: str
|
||||
:param transform_script: location of the executable transformation script
|
||||
:type transform_script: str
|
||||
:param select_expression: S3 Select expression
|
||||
:type select_expression: str
|
||||
:param source_aws_conn_id: source s3 connection
|
||||
:type source_aws_conn_id: str
|
||||
:param source_verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
|
||||
This is also applicable to ``dest_verify``.
|
||||
:type source_verify: bool or str
|
||||
:param dest_aws_conn_id: destination s3 connection
|
||||
:type dest_aws_conn_id: str
|
||||
:param dest_verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
See: ``source_verify``
|
||||
:type dest_verify: bool or str
|
||||
:param replace: Replace dest S3 key if it already exists
|
||||
:type replace: bool
|
||||
"""
|
||||
|
||||
template_fields = ('source_s3_key', 'dest_s3_key')
|
||||
template_ext = ()
|
||||
ui_color = '#f9c915'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
source_s3_key: str,
|
||||
dest_s3_key: str,
|
||||
transform_script: Optional[str] = None,
|
||||
select_expression=None,
|
||||
source_aws_conn_id: str = 'aws_default',
|
||||
source_verify: Optional[Union[bool, str]] = None,
|
||||
dest_aws_conn_id: str = 'aws_default',
|
||||
dest_verify: Optional[Union[bool, str]] = None,
|
||||
replace: bool = False,
|
||||
*args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.source_s3_key = source_s3_key
|
||||
self.source_aws_conn_id = source_aws_conn_id
|
||||
self.source_verify = source_verify
|
||||
self.dest_s3_key = dest_s3_key
|
||||
self.dest_aws_conn_id = dest_aws_conn_id
|
||||
self.dest_verify = dest_verify
|
||||
self.replace = replace
|
||||
self.transform_script = transform_script
|
||||
self.select_expression = select_expression
|
||||
self.output_encoding = sys.getdefaultencoding()
|
||||
|
||||
def execute(self, context):
|
||||
if self.transform_script is None and self.select_expression is None:
|
||||
raise AirflowException(
|
||||
"Either transform_script or select_expression must be specified")
|
||||
|
||||
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify)
|
||||
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
|
||||
|
||||
self.log.info("Downloading source S3 file %s", self.source_s3_key)
|
||||
if not source_s3.check_for_key(self.source_s3_key):
|
||||
raise AirflowException(
|
||||
"The source key {0} does not exist".format(self.source_s3_key))
|
||||
source_s3_key_object = source_s3.get_key(self.source_s3_key)
|
||||
|
||||
with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest:
|
||||
self.log.info(
|
||||
"Dumping S3 file %s contents to local file %s",
|
||||
self.source_s3_key, f_source.name
|
||||
)
|
||||
|
||||
if self.select_expression is not None:
|
||||
content = source_s3.select_key(
|
||||
key=self.source_s3_key,
|
||||
expression=self.select_expression
|
||||
)
|
||||
f_source.write(content.encode("utf-8"))
|
||||
else:
|
||||
source_s3_key_object.download_fileobj(Fileobj=f_source)
|
||||
f_source.flush()
|
||||
|
||||
if self.transform_script is not None:
|
||||
process = subprocess.Popen(
|
||||
[self.transform_script, f_source.name, f_dest.name],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
close_fds=True
|
||||
)
|
||||
|
||||
self.log.info("Output:")
|
||||
for line in iter(process.stdout.readline, b''):
|
||||
self.log.info(line.decode(self.output_encoding).rstrip())
|
||||
|
||||
process.wait()
|
||||
|
||||
if process.returncode > 0:
|
||||
raise AirflowException(
|
||||
"Transform script failed: {0}".format(process.returncode)
|
||||
)
|
||||
else:
|
||||
self.log.info(
|
||||
"Transform script successful. Output temporarily located at %s",
|
||||
f_dest.name
|
||||
)
|
||||
|
||||
self.log.info("Uploading transformed file to S3")
|
||||
f_dest.flush()
|
||||
dest_s3.load_file(
|
||||
filename=f_dest.name,
|
||||
key=self.dest_s3_key,
|
||||
replace=self.replace
|
||||
)
|
||||
self.log.info("Upload successful")
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
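
As a hedged aside on the contract described in the operator docstring above: the transform script simply receives the source and destination file paths as its first and second command-line arguments. A hypothetical script (not part of this change) could look like this:

#!/usr/bin/env python
# transform.py -- invoked by S3FileTransformOperator as: transform.py <source_path> <dest_path>
import sys

source_path, dest_path = sys.argv[1], sys.argv[2]
with open(source_path) as src, open(dest_path, 'w') as dst:
    for line in src:
        dst.write(line.upper())  # placeholder transformation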
|
||||
|
|
|
@@ -0,0 +1,80 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
|
||||
|
||||
class EmrHook(AwsHook):
|
||||
"""
|
||||
Interact with AWS EMR. emr_conn_id is only necessary for using the
|
||||
create_job_flow method.
|
||||
"""
|
||||
|
||||
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
|
||||
self.emr_conn_id = emr_conn_id
|
||||
self.region_name = region_name
|
||||
self.conn = None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_conn(self):
|
||||
if not self.conn:
|
||||
self.conn = self.get_client_type('emr', self.region_name)
|
||||
return self.conn
|
||||
|
||||
def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
|
||||
conn = self.get_conn()
|
||||
|
||||
response = conn.list_clusters(
|
||||
ClusterStates=cluster_states
|
||||
)
|
||||
|
||||
matching_clusters = list(
|
||||
filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
|
||||
)
|
||||
|
||||
if len(matching_clusters) == 1:
|
||||
cluster_id = matching_clusters[0]['Id']
|
||||
self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
|
||||
return cluster_id
|
||||
elif len(matching_clusters) > 1:
|
||||
            raise AirflowException('More than one cluster found for name %s' % emr_cluster_name)
|
||||
else:
|
||||
self.log.info('No cluster found for name %s', emr_cluster_name)
|
||||
return None
|
||||
|
||||
def create_job_flow(self, job_flow_overrides):
|
||||
"""
|
||||
        Creates a job flow using the config stored in the EMR connection.
        Keys of the connection's JSON ``extra`` field may contain arguments for the
        boto3 ``run_job_flow`` method.
        Overrides for this config may be passed as ``job_flow_overrides``.
|
||||
"""
|
||||
|
||||
if not self.emr_conn_id:
|
||||
raise AirflowException('emr_conn_id must be present to use create_job_flow')
|
||||
|
||||
emr_conn = self.get_connection(self.emr_conn_id)
|
||||
|
||||
config = emr_conn.extra_dejson.copy()
|
||||
config.update(job_flow_overrides)
|
||||
|
||||
response = self.get_conn().run_job_flow(**config)
|
||||
|
||||
return response
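
For illustration only: a minimal usage sketch of the hook above. The connection IDs, cluster name, and overrides are assumptions, and the provider module path is assumed from the new package layout.

from airflow.providers.amazon.aws.hooks.emr import EmrHook  # assumed new module path

# emr_conn_id is only needed for create_job_flow(); its connection extras hold the
# base run_job_flow config that job_flow_overrides is merged on top of.
emr_hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')

cluster_id = emr_hook.get_cluster_id_by_name('my-cluster', ['RUNNING', 'WAITING'])
response = emr_hook.create_job_flow({'Name': 'my-job-flow', 'Instances': {'InstanceCount': 1}})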
|
|
@@ -0,0 +1,152 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
This module contains AWS Glue Catalog Hook
|
||||
"""
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
|
||||
|
||||
class AwsGlueCatalogHook(AwsHook):
|
||||
"""
|
||||
Interact with AWS Glue Catalog
|
||||
|
||||
:param aws_conn_id: ID of the Airflow connection where
|
||||
credentials and extra configuration are stored
|
||||
:type aws_conn_id: str
|
||||
:param region_name: aws region name (example: us-east-1)
|
||||
:type region_name: str
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
aws_conn_id='aws_default',
|
||||
region_name=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
self.region_name = region_name
|
||||
self.conn = None
|
||||
super().__init__(aws_conn_id=aws_conn_id, *args, **kwargs)
|
||||
|
||||
def get_conn(self):
|
||||
"""
|
||||
Returns glue connection object.
|
||||
"""
|
||||
self.conn = self.get_client_type('glue', self.region_name)
|
||||
return self.conn
|
||||
|
||||
def get_partitions(self,
|
||||
database_name,
|
||||
table_name,
|
||||
expression='',
|
||||
page_size=None,
|
||||
max_items=None):
|
||||
"""
|
||||
Retrieves the partition values for a table.
|
||||
|
||||
:param database_name: The name of the catalog database where the partitions reside.
|
||||
:type database_name: str
|
||||
:param table_name: The name of the partitions' table.
|
||||
:type table_name: str
|
||||
:param expression: An expression filtering the partitions to be returned.
|
||||
Please see official AWS documentation for further information.
|
||||
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
|
||||
:type expression: str
|
||||
:param page_size: pagination size
|
||||
:type page_size: int
|
||||
:param max_items: maximum items to return
|
||||
:type max_items: int
|
||||
:return: set of partition values where each value is a tuple since
|
||||
a partition may be composed of multiple columns. For example:
|
||||
``{('2018-01-01','1'), ('2018-01-01','2')}``
|
||||
"""
|
||||
config = {
|
||||
'PageSize': page_size,
|
||||
'MaxItems': max_items,
|
||||
}
|
||||
|
||||
paginator = self.get_conn().get_paginator('get_partitions')
|
||||
response = paginator.paginate(
|
||||
DatabaseName=database_name,
|
||||
TableName=table_name,
|
||||
Expression=expression,
|
||||
PaginationConfig=config
|
||||
)
|
||||
|
||||
partitions = set()
|
||||
for page in response:
|
||||
for partition in page['Partitions']:
|
||||
partitions.add(tuple(partition['Values']))
|
||||
|
||||
return partitions
|
||||
|
||||
def check_for_partition(self, database_name, table_name, expression):
|
||||
"""
|
||||
Checks whether a partition exists
|
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to
|
||||
:type database_name: str
|
||||
:param table_name: Name of hive table @partition belongs to
|
||||
:type table_name: str
|
||||
        :param expression: Expression that matches the partitions to check for
            (e.g. ``a = 'b' AND c = 'd'``)
|
||||
:type expression: str
|
||||
:rtype: bool
|
||||
|
||||
>>> hook = AwsGlueCatalogHook()
|
||||
>>> t = 'static_babynames_partitioned'
|
||||
>>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
|
||||
True
|
||||
"""
|
||||
partitions = self.get_partitions(database_name, table_name, expression, max_items=1)
|
||||
|
||||
return bool(partitions)
|
||||
|
||||
def get_table(self, database_name, table_name):
|
||||
"""
|
||||
Get the information of the table
|
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to
|
||||
:type database_name: str
|
||||
:param table_name: Name of hive table
|
||||
:type table_name: str
|
||||
:rtype: dict
|
||||
|
||||
>>> hook = AwsGlueCatalogHook()
|
||||
>>> r = hook.get_table('db', 'table_foo')
|
||||
        >>> r['Name'] == 'table_foo'
        True
|
||||
"""
|
||||
|
||||
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name)
|
||||
|
||||
return result['Table']
|
||||
|
||||
def get_table_location(self, database_name, table_name):
|
||||
"""
|
||||
Get the physical location of the table
|
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to
|
||||
:type database_name: str
|
||||
:param table_name: Name of hive table
|
||||
:type table_name: str
|
||||
:return: str
|
||||
"""
|
||||
|
||||
table = self.get_table(database_name, table_name)
|
||||
|
||||
return table['StorageDescriptor']['Location']
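
For illustration only: a short usage sketch of the hook above, reusing the table names from its doctests. The module path is assumed from the new package layout.

from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook  # assumed path

hook = AwsGlueCatalogHook(aws_conn_id='aws_default', region_name='us-east-1')
if hook.check_for_partition('airflow', 'static_babynames_partitioned', "ds='2015-01-01'"):
    # Location of the table's data, e.g. an s3:// URI from the StorageDescriptor
    print(hook.get_table_location('airflow', 'static_babynames_partitioned'))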
|
|
@@ -0,0 +1,102 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
This module contains a hook (AwsLogsHook) with some very basic
|
||||
functionality for interacting with AWS CloudWatch.
|
||||
"""
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
|
||||
|
||||
class AwsLogsHook(AwsHook):
|
||||
"""
|
||||
Interact with AWS CloudWatch Logs
|
||||
|
||||
:param region_name: AWS Region Name (example: us-west-2)
|
||||
:type region_name: str
|
||||
"""
|
||||
|
||||
def __init__(self, region_name=None, *args, **kwargs):
|
||||
self.region_name = region_name
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_conn(self):
|
||||
"""
|
||||
Establish an AWS connection for retrieving logs.
|
||||
|
||||
:rtype: CloudWatchLogs.Client
|
||||
"""
|
||||
return self.get_client_type('logs', region_name=self.region_name)
|
||||
|
||||
def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start_from_head=True):
|
||||
"""
|
||||
A generator for log items in a single stream. This will yield all the
|
||||
items that are available at the current moment.
|
||||
|
||||
:param log_group: The name of the log group.
|
||||
:type log_group: str
|
||||
:param log_stream_name: The name of the specific stream.
|
||||
:type log_stream_name: str
|
||||
:param start_time: The time stamp value to start reading the logs from (default: 0).
|
||||
:type start_time: int
|
||||
:param skip: The number of log entries to skip at the start (default: 0).
|
||||
This is for when there are multiple entries at the same timestamp.
|
||||
:type skip: int
|
||||
:param start_from_head: whether to start from the beginning (True) of the log or
|
||||
at the end of the log (False).
|
||||
:type start_from_head: bool
|
||||
:rtype: dict
|
||||
:return: | A CloudWatch log event with the following key-value pairs:
|
||||
| 'timestamp' (int): The time in milliseconds of the event.
|
||||
| 'message' (str): The log event data.
|
||||
| 'ingestionTime' (int): The time in milliseconds the event was ingested.
|
||||
"""
|
||||
|
||||
next_token = None
|
||||
|
||||
event_count = 1
|
||||
while event_count > 0:
|
||||
if next_token is not None:
|
||||
token_arg = {'nextToken': next_token}
|
||||
else:
|
||||
token_arg = {}
|
||||
|
||||
response = self.get_conn().get_log_events(logGroupName=log_group,
|
||||
logStreamName=log_stream_name,
|
||||
startTime=start_time,
|
||||
startFromHead=start_from_head,
|
||||
**token_arg)
|
||||
|
||||
events = response['events']
|
||||
event_count = len(events)
|
||||
|
||||
if event_count > skip:
|
||||
events = events[skip:]
|
||||
skip = 0
|
||||
else:
|
||||
skip = skip - event_count
|
||||
events = []
|
||||
|
||||
yield from events
|
||||
|
||||
if 'nextForwardToken' in response:
|
||||
next_token = response['nextForwardToken']
|
||||
else:
|
||||
return
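
For illustration only: consuming the generator above. The log group and stream names are placeholders.

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

hook = AwsLogsHook(aws_conn_id='aws_default', region_name='us-west-2')
# Each yielded item is a CloudWatch event dict with 'timestamp', 'message'
# and 'ingestionTime' keys, as documented above.
for event in hook.get_log_events('/aws/sagemaker/TrainingJobs', 'my-job/algo-1-1234567890'):
    print(event['timestamp'], event['message'])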
|
|
@@ -0,0 +1,741 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import collections
|
||||
import os
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils import timezone
|
||||
|
||||
|
||||
class LogState:
|
||||
STARTING = 1
|
||||
WAIT_IN_PROGRESS = 2
|
||||
TAILING = 3
|
||||
JOB_COMPLETE = 4
|
||||
COMPLETE = 5
|
||||
|
||||
|
||||
# Position is a tuple that includes the last read timestamp and the number of items that were read
|
||||
# at that time. This is used to figure out which event to start with on the next read.
|
||||
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
|
||||
|
||||
|
||||
def argmin(arr, f):
|
||||
"""Return the index, i, in arr that minimizes f(arr[i])"""
|
||||
m = None
|
||||
i = None
|
||||
for idx, item in enumerate(arr):
|
||||
if item is not None:
|
||||
if m is None or f(item) < m:
|
||||
m = f(item)
|
||||
i = idx
|
||||
return i
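
A quick illustration of the helper above (run in this module's namespace): None entries are skipped and the index minimizing f(arr[i]) is returned, which is how multi_stream_iter picks the stream with the earliest pending event.

assert argmin([7, None, 3, 5], lambda x: x) == 2
assert argmin([{'timestamp': 9}, {'timestamp': 4}], lambda e: e['timestamp']) == 1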
|
||||
|
||||
|
||||
def secondary_training_status_changed(current_job_description, prev_job_description):
|
||||
"""
|
||||
Returns true if training job's secondary status message has changed.
|
||||
|
||||
:param current_job_description: Current job description, returned from DescribeTrainingJob call.
|
||||
:type current_job_description: dict
|
||||
:param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
|
||||
:type prev_job_description: dict
|
||||
|
||||
:return: Whether the secondary status message of a training job changed or not.
|
||||
"""
|
||||
current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
|
||||
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
|
||||
return False
|
||||
|
||||
prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \
|
||||
if prev_job_description is not None else None
|
||||
|
||||
last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \
|
||||
if prev_job_secondary_status_transitions is not None \
|
||||
and len(prev_job_secondary_status_transitions) > 0 else ''
|
||||
|
||||
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
|
||||
|
||||
return message != last_message
|
||||
|
||||
|
||||
def secondary_training_status_message(job_description, prev_description):
|
||||
"""
|
||||
    Returns a string containing the start time and the secondary training job status message.
|
||||
|
||||
:param job_description: Returned response from DescribeTrainingJob call
|
||||
:type job_description: dict
|
||||
:param prev_description: Previous job description from DescribeTrainingJob call
|
||||
:type prev_description: dict
|
||||
|
||||
:return: Job status string to be printed.
|
||||
"""
|
||||
|
||||
if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
|
||||
or len(job_description.get('SecondaryStatusTransitions')) == 0:
|
||||
return ''
|
||||
|
||||
prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
|
||||
if prev_description is not None else None
|
||||
prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
|
||||
if prev_description_secondary_transitions is not None else 0
|
||||
current_transitions = job_description['SecondaryStatusTransitions']
|
||||
|
||||
transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
|
||||
current_transitions[prev_transitions_num - len(current_transitions):]
|
||||
|
||||
status_strs = []
|
||||
for transition in transitions_to_print:
|
||||
message = transition['StatusMessage']
|
||||
time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
|
||||
status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))
|
||||
|
||||
return '\n'.join(status_strs)
|
||||
|
||||
|
||||
class SageMakerHook(AwsHook):
|
||||
"""
|
||||
Interact with Amazon SageMaker.
|
||||
"""
|
||||
non_terminal_states = {'InProgress', 'Stopping'}
|
||||
endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating',
|
||||
'RollingBack', 'Deleting'}
|
||||
failed_states = {'Failed'}
|
||||
|
||||
def __init__(self,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
|
||||
self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
def tar_and_s3_upload(self, path, key, bucket):
|
||||
"""
|
||||
Tar the local file or directory and upload to s3
|
||||
|
||||
:param path: local file or directory
|
||||
:type path: str
|
||||
:param key: s3 key
|
||||
:type key: str
|
||||
:param bucket: s3 bucket
|
||||
:type bucket: str
|
||||
:return: None
|
||||
"""
|
||||
with tempfile.TemporaryFile() as temp_file:
|
||||
if os.path.isdir(path):
|
||||
files = [os.path.join(path, name) for name in os.listdir(path)]
|
||||
else:
|
||||
files = [path]
|
||||
with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file:
|
||||
for f in files:
|
||||
tar_file.add(f, arcname=os.path.basename(f))
|
||||
temp_file.seek(0)
|
||||
self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
|
||||
|
||||
def configure_s3_resources(self, config):
|
||||
"""
|
||||
Extract the S3 operations from the configuration and execute them.
|
||||
|
||||
:param config: config of SageMaker operation
|
||||
:type config: dict
|
||||
:rtype: dict
|
||||
"""
|
||||
s3_operations = config.pop('S3Operations', None)
|
||||
|
||||
if s3_operations is not None:
|
||||
create_bucket_ops = s3_operations.get('S3CreateBucket', [])
|
||||
upload_ops = s3_operations.get('S3Upload', [])
|
||||
for op in create_bucket_ops:
|
||||
self.s3_hook.create_bucket(bucket_name=op['Bucket'])
|
||||
for op in upload_ops:
|
||||
if op['Tar']:
|
||||
self.tar_and_s3_upload(op['Path'], op['Key'],
|
||||
op['Bucket'])
|
||||
else:
|
||||
self.s3_hook.load_file(op['Path'], op['Key'],
|
||||
op['Bucket'])
|
||||
|
||||
def check_s3_url(self, s3url):
|
||||
"""
|
||||
Check if an S3 URL exists
|
||||
|
||||
:param s3url: S3 url
|
||||
:type s3url: str
|
||||
:rtype: bool
|
||||
"""
|
||||
bucket, key = S3Hook.parse_s3_url(s3url)
|
||||
if not self.s3_hook.check_for_bucket(bucket_name=bucket):
|
||||
raise AirflowException(
|
||||
"The input S3 Bucket {} does not exist ".format(bucket))
|
||||
if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
|
||||
and not self.s3_hook.check_for_prefix(
|
||||
prefix=key, bucket_name=bucket, delimiter='/'):
|
||||
# check if s3 key exists in the case user provides a single file
|
||||
# or if s3 prefix exists in the case user provides multiple files in
|
||||
# a prefix
|
||||
raise AirflowException("The input S3 Key "
|
||||
"or Prefix {} does not exist in the Bucket {}"
|
||||
.format(s3url, bucket))
|
||||
return True
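
For illustration only: what the validation above does with a typical URL. The bucket and key are placeholders, and the import path is assumed from the new package layout.

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook  # assumed path

hook = SageMakerHook(aws_conn_id='aws_default')
# S3Hook.parse_s3_url('s3://my-bucket/training/data.csv') -> ('my-bucket', 'training/data.csv');
# check_s3_url() then verifies the bucket exists and the key (or prefix) exists,
# raising AirflowException otherwise.
assert hook.check_s3_url('s3://my-bucket/training/data.csv')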
|
||||
|
||||
def check_training_config(self, training_config):
|
||||
"""
|
||||
Check if a training configuration is valid
|
||||
|
||||
:param training_config: training_config
|
||||
:type training_config: dict
|
||||
:return: None
|
||||
"""
|
||||
if "InputDataConfig" in training_config:
|
||||
for channel in training_config['InputDataConfig']:
|
||||
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
|
||||
|
||||
def check_tuning_config(self, tuning_config):
|
||||
"""
|
||||
Check if a tuning configuration is valid
|
||||
|
||||
:param tuning_config: tuning_config
|
||||
:type tuning_config: dict
|
||||
:return: None
|
||||
"""
|
||||
for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
|
||||
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
|
||||
|
||||
def get_conn(self):
|
||||
"""
|
||||
Establish an AWS connection for SageMaker
|
||||
|
||||
:rtype: :py:class:`SageMaker.Client`
|
||||
"""
|
||||
return self.get_client_type('sagemaker')
|
||||
|
||||
def get_log_conn(self):
|
||||
"""
|
||||
This method is deprecated.
|
||||
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_conn` instead.
|
||||
"""
|
||||
warnings.warn("Method `get_log_conn` has been deprecated. "
|
||||
"Please use `airflow.contrib.hooks.AwsLogsHook.get_conn` instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2)
|
||||
|
||||
return self.logs_hook.get_conn()
|
||||
|
||||
def log_stream(self, log_group, stream_name, start_time=0, skip=0):
|
||||
"""
|
||||
This method is deprecated.
|
||||
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.
|
||||
"""
|
||||
warnings.warn("Method `log_stream` has been deprecated. "
|
||||
"Please use `airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2)
|
||||
|
||||
return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)
|
||||
|
||||
def multi_stream_iter(self, log_group, streams, positions=None):
|
||||
"""
|
||||
Iterate over the available events coming from a set of log streams in a single log group
|
||||
interleaving the events from each stream so they're yielded in timestamp order.
|
||||
|
||||
:param log_group: The name of the log group.
|
||||
:type log_group: str
|
||||
:param streams: A list of the log stream names. The position of the stream in this list is
|
||||
the stream number.
|
||||
:type streams: list
|
||||
:param positions: A list of pairs of (timestamp, skip) which represents the last record
|
||||
read from each stream.
|
||||
:type positions: list
|
||||
:return: A tuple of (stream number, cloudwatch log event).
|
||||
"""
|
||||
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
|
||||
event_iters = [self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
|
||||
for s in streams]
|
||||
events = []
|
||||
for s in event_iters:
|
||||
if not s:
|
||||
events.append(None)
|
||||
continue
|
||||
try:
|
||||
events.append(next(s))
|
||||
except StopIteration:
|
||||
events.append(None)
|
||||
|
||||
while any(events):
|
||||
i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
|
||||
yield (i, events[i])
|
||||
try:
|
||||
events[i] = next(event_iters[i])
|
||||
except StopIteration:
|
||||
events[i] = None
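
For illustration only: how the interleaving generator above might be consumed. The log group and stream names are placeholders, and the import path is assumed from the new package layout.

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook  # assumed path

hook = SageMakerHook(aws_conn_id='aws_default')
streams = ['my-job/algo-1-1234567890', 'my-job/algo-2-1234567890']
# Events from all streams are yielded as (stream_index, event) pairs in timestamp order.
for stream_idx, event in hook.multi_stream_iter('/aws/sagemaker/TrainingJobs', streams):
    print(streams[stream_idx], event['message'])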
|
||||
|
||||
def create_training_job(self, config, wait_for_completion=True, print_log=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create a training job
|
||||
|
||||
:param config: the config for training
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to training job creation
|
||||
"""
|
||||
|
||||
self.check_training_config(config)
|
||||
|
||||
response = self.get_conn().create_training_job(**config)
|
||||
if print_log:
|
||||
self.check_training_status_with_log(config['TrainingJobName'],
|
||||
self.non_terminal_states,
|
||||
self.failed_states,
|
||||
wait_for_completion,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
elif wait_for_completion:
|
||||
describe_response = self.check_status(config['TrainingJobName'],
|
||||
'TrainingJobStatus',
|
||||
self.describe_training_job,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
|
||||
billable_time = \
|
||||
(describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \
|
||||
describe_response['ResourceConfig']['InstanceCount']
|
||||
            self.log.info('Billable seconds:%s', int(billable_time.total_seconds()) + 1)
|
||||
|
||||
return response
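
A hedged sketch of calling the method above. The config keys follow the boto3 create_training_job API; every concrete value (job name, image, role, buckets) is a placeholder, and the import path is assumed from the new package layout.

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook  # assumed path

hook = SageMakerHook(aws_conn_id='aws_default')
training_config = {
    'TrainingJobName': 'my-training-job',
    'AlgorithmSpecification': {'TrainingImage': '123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:latest',
                               'TrainingInputMode': 'File'},
    'RoleArn': 'arn:aws:iam::123456789012:role/my-sagemaker-role',
    'InputDataConfig': [{'ChannelName': 'train',
                         'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix',
                                                         'S3Uri': 's3://my-bucket/train/'}}}],
    'OutputDataConfig': {'S3OutputPath': 's3://my-bucket/output/'},
    'ResourceConfig': {'InstanceType': 'ml.m5.large', 'InstanceCount': 1, 'VolumeSizeInGB': 10},
    'StoppingCondition': {'MaxRuntimeInSeconds': 3600},
}
response = hook.create_training_job(training_config, wait_for_completion=True,
                                    print_log=True, check_interval=30)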
|
||||
|
||||
def create_tuning_job(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create a tuning job
|
||||
|
||||
:param config: the config for tuning
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to tuning job creation
|
||||
"""
|
||||
|
||||
self.check_tuning_config(config)
|
||||
|
||||
response = self.get_conn().create_hyper_parameter_tuning_job(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['HyperParameterTuningJobName'],
|
||||
'HyperParameterTuningJobStatus',
|
||||
self.describe_tuning_job,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
return response
|
||||
|
||||
def create_transform_job(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create a transform job
|
||||
|
||||
:param config: the config for transform job
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to transform job creation
|
||||
"""
|
||||
|
||||
self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
|
||||
|
||||
response = self.get_conn().create_transform_job(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['TransformJobName'],
|
||||
'TransformJobStatus',
|
||||
self.describe_transform_job,
|
||||
check_interval, max_ingestion_time
|
||||
)
|
||||
return response
|
||||
|
||||
def create_model(self, config):
|
||||
"""
|
||||
Create a model job
|
||||
|
||||
:param config: the config for model
|
||||
:type config: dict
|
||||
:return: A response to model creation
|
||||
"""
|
||||
|
||||
return self.get_conn().create_model(**config)
|
||||
|
||||
def create_endpoint_config(self, config):
|
||||
"""
|
||||
Create an endpoint config
|
||||
|
||||
:param config: the config for endpoint-config
|
||||
:type config: dict
|
||||
:return: A response to endpoint config creation
|
||||
"""
|
||||
|
||||
return self.get_conn().create_endpoint_config(**config)
|
||||
|
||||
def create_endpoint(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Create an endpoint
|
||||
|
||||
:param config: the config for endpoint
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to endpoint creation
|
||||
"""
|
||||
|
||||
response = self.get_conn().create_endpoint(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['EndpointName'],
|
||||
'EndpointStatus',
|
||||
self.describe_endpoint,
|
||||
check_interval, max_ingestion_time,
|
||||
non_terminal_states=self.endpoint_non_terminal_states
|
||||
)
|
||||
return response
|
||||
|
||||
def update_endpoint(self, config, wait_for_completion=True,
|
||||
check_interval=30, max_ingestion_time=None):
|
||||
"""
|
||||
Update an endpoint
|
||||
|
||||
:param config: the config for endpoint
|
||||
:type config: dict
|
||||
:param wait_for_completion: if the program should keep running until job finishes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: A response to endpoint update
|
||||
"""
|
||||
|
||||
response = self.get_conn().update_endpoint(**config)
|
||||
if wait_for_completion:
|
||||
self.check_status(config['EndpointName'],
|
||||
'EndpointStatus',
|
||||
self.describe_endpoint,
|
||||
check_interval, max_ingestion_time,
|
||||
non_terminal_states=self.endpoint_non_terminal_states
|
||||
)
|
||||
return response
|
||||
|
||||
def describe_training_job(self, name):
|
||||
"""
|
||||
Return the training job info associated with the name
|
||||
|
||||
:param name: the name of the training job
|
||||
:type name: str
|
||||
        :return: A dict containing all the training job info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_training_job(TrainingJobName=name)
|
||||
|
||||
def describe_training_job_with_log(self, job_name, positions, stream_names,
|
||||
instance_count, state, last_description,
|
||||
last_describe_job_call):
|
||||
"""
|
||||
Return the training job info associated with job_name and print CloudWatch logs
|
||||
"""
|
||||
log_group = '/aws/sagemaker/TrainingJobs'
|
||||
|
||||
if len(stream_names) < instance_count:
|
||||
# Log streams are created whenever a container starts writing to stdout/err, so this list
|
||||
# may be dynamic until we have a stream for every instance.
|
||||
logs_conn = self.logs_hook.get_conn()
|
||||
try:
|
||||
streams = logs_conn.describe_log_streams(
|
||||
logGroupName=log_group,
|
||||
logStreamNamePrefix=job_name + '/',
|
||||
orderBy='LogStreamName',
|
||||
limit=instance_count
|
||||
)
|
||||
stream_names = [s['logStreamName'] for s in streams['logStreams']]
|
||||
positions.update([(s, Position(timestamp=0, skip=0))
|
||||
for s in stream_names if s not in positions])
|
||||
except logs_conn.exceptions.ResourceNotFoundException:
|
||||
# On the very first training job run on an account, there's no log group until
|
||||
# the container starts logging, so ignore any errors thrown about that
|
||||
pass
|
||||
|
||||
if len(stream_names) > 0:
|
||||
for idx, event in self.multi_stream_iter(log_group, stream_names, positions):
|
||||
self.log.info(event['message'])
|
||||
ts, count = positions[stream_names[idx]]
|
||||
if event['timestamp'] == ts:
|
||||
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
|
||||
else:
|
||||
positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1)
|
||||
|
||||
if state == LogState.COMPLETE:
|
||||
return state, last_description, last_describe_job_call
|
||||
|
||||
if state == LogState.JOB_COMPLETE:
|
||||
state = LogState.COMPLETE
|
||||
elif time.time() - last_describe_job_call >= 30:
|
||||
description = self.describe_training_job(job_name)
|
||||
last_describe_job_call = time.time()
|
||||
|
||||
if secondary_training_status_changed(description, last_description):
|
||||
self.log.info(secondary_training_status_message(description, last_description))
|
||||
last_description = description
|
||||
|
||||
status = description['TrainingJobStatus']
|
||||
|
||||
if status not in self.non_terminal_states:
|
||||
state = LogState.JOB_COMPLETE
|
||||
return state, last_description, last_describe_job_call
|
||||
|
||||
def describe_tuning_job(self, name):
|
||||
"""
|
||||
Return the tuning job info associated with the name
|
||||
|
||||
:param name: the name of the tuning job
|
||||
:type name: str
|
||||
        :return: A dict containing all the tuning job info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
|
||||
|
||||
def describe_model(self, name):
|
||||
"""
|
||||
Return the SageMaker model info associated with the name
|
||||
|
||||
:param name: the name of the SageMaker model
|
||||
:type name: str
|
||||
        :return: A dict containing all the model info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_model(ModelName=name)
|
||||
|
||||
def describe_transform_job(self, name):
|
||||
"""
|
||||
Return the transform job info associated with the name
|
||||
|
||||
:param name: the name of the transform job
|
||||
:type name: str
|
||||
        :return: A dict containing all the transform job info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_transform_job(TransformJobName=name)
|
||||
|
||||
def describe_endpoint_config(self, name):
|
||||
"""
|
||||
Return the endpoint config info associated with the name
|
||||
|
||||
:param name: the name of the endpoint config
|
||||
:type name: str
|
||||
        :return: A dict containing all the endpoint config info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
|
||||
|
||||
def describe_endpoint(self, name):
|
||||
"""
|
||||
:param name: the name of the endpoint
|
||||
:type name: str
|
||||
        :return: A dict containing all the endpoint info
|
||||
"""
|
||||
|
||||
return self.get_conn().describe_endpoint(EndpointName=name)
|
||||
|
||||
def check_status(self, job_name, key,
|
||||
describe_function, check_interval,
|
||||
max_ingestion_time,
|
||||
non_terminal_states=None):
|
||||
"""
|
||||
Check status of a SageMaker job
|
||||
|
||||
:param job_name: name of the job to check status
|
||||
:type job_name: str
|
||||
:param key: the key of the response dict
|
||||
that points to the state
|
||||
:type key: str
|
||||
:param describe_function: the function used to retrieve the status
|
||||
:type describe_function: python callable
|
||||
|
||||
:param check_interval: the time interval in seconds which the operator
|
||||
will check the status of any SageMaker job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:param non_terminal_states: the set of nonterminal states
|
||||
:type non_terminal_states: set
|
||||
:return: response of describe call after job is done
|
||||
"""
|
||||
if not non_terminal_states:
|
||||
non_terminal_states = self.non_terminal_states
|
||||
|
||||
sec = 0
|
||||
running = True
|
||||
|
||||
while running:
|
||||
time.sleep(check_interval)
|
||||
sec = sec + check_interval
|
||||
|
||||
try:
|
||||
response = describe_function(job_name)
|
||||
status = response[key]
|
||||
                self.log.info('Job still running for %s seconds... '
                              'current status is %s', sec, status)
|
||||
except KeyError:
|
||||
raise AirflowException('Could not get status of the SageMaker job')
|
||||
except ClientError:
|
||||
raise AirflowException('AWS request failed, check logs for more info')
|
||||
|
||||
if status in non_terminal_states:
|
||||
running = True
|
||||
elif status in self.failed_states:
|
||||
raise AirflowException('SageMaker job failed because %s' % response['FailureReason'])
|
||||
else:
|
||||
running = False
|
||||
|
||||
if max_ingestion_time and sec > max_ingestion_time:
|
||||
# ensure that the job gets killed if the max ingestion time is exceeded
|
||||
                raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)
|
||||
|
||||
        self.log.info('SageMaker Job Completed')
|
||||
response = describe_function(job_name)
|
||||
return response
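
For illustration only: the polling contract of check_status as the create_* methods above use it. The job name is a placeholder, and the import path is assumed from the new package layout.

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook  # assumed path

hook = SageMakerHook(aws_conn_id='aws_default')
# Polls describe_transform_job every 30 seconds until the job leaves the non-terminal
# states; raises AirflowException on a failed state or when max_ingestion_time is exceeded.
final_description = hook.check_status('my-transform-job', 'TransformJobStatus',
                                      hook.describe_transform_job,
                                      check_interval=30, max_ingestion_time=3600)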
|
||||
|
||||
def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
|
||||
wait_for_completion, check_interval, max_ingestion_time):
|
||||
"""
|
||||
Display the logs for a given training job, optionally tailing them until the
|
||||
job is complete.
|
||||
|
||||
:param job_name: name of the training job to check status and display logs for
|
||||
:type job_name: str
|
||||
:param non_terminal_states: the set of non_terminal states
|
||||
:type non_terminal_states: set
|
||||
:param failed_states: the set of failed states
|
||||
:type failed_states: set
|
||||
:param wait_for_completion: Whether to keep looking for new log entries
|
||||
until the job completes
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: The interval in seconds between polling for new log entries and job completion
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: the maximum ingestion time in seconds. Any
|
||||
SageMaker jobs that run longer than this will fail. Setting this to
|
||||
None implies no timeout for any SageMaker job.
|
||||
:type max_ingestion_time: int
|
||||
:return: None
|
||||
"""
|
||||
|
||||
sec = 0
|
||||
description = self.describe_training_job(job_name)
|
||||
self.log.info(secondary_training_status_message(description, None))
|
||||
instance_count = description['ResourceConfig']['InstanceCount']
|
||||
status = description['TrainingJobStatus']
|
||||
|
||||
stream_names = [] # The list of log streams
|
||||
positions = {} # The current position in each stream, map of stream name -> position
|
||||
|
||||
job_already_completed = status not in non_terminal_states
|
||||
|
||||
state = LogState.TAILING if wait_for_completion and not job_already_completed else LogState.COMPLETE
|
||||
|
||||
# The loop below implements a state machine that alternates between checking the job status and
|
||||
# reading whatever is available in the logs at this point. Note, that if we were called with
|
||||
# wait_for_completion == False, we never check the job status.
|
||||
#
|
||||
# If wait_for_completion == True and job is not completed, the initial state is TAILING
|
||||
# If wait_for_completion == False, the initial state is COMPLETE
|
||||
# (doesn't matter if the job really is complete).
|
||||
#
|
||||
# The state table:
|
||||
#
|
||||
# STATE ACTIONS CONDITION NEW STATE
|
||||
# ---------------- ---------------- ----------------- ----------------
|
||||
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
|
||||
# Else TAILING
|
||||
# JOB_COMPLETE Read logs, Pause Any COMPLETE
|
||||
# COMPLETE Read logs, Exit N/A
|
||||
#
|
||||
# Notes:
|
||||
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that
|
||||
# got to Cloudwatch after the job was marked complete.
|
||||
last_describe_job_call = time.time()
|
||||
last_description = description
|
||||
|
||||
while True:
|
||||
time.sleep(check_interval)
|
||||
sec = sec + check_interval
|
||||
|
||||
state, last_description, last_describe_job_call = \
|
||||
self.describe_training_job_with_log(job_name, positions, stream_names,
|
||||
instance_count, state, last_description,
|
||||
last_describe_job_call)
|
||||
if state == LogState.COMPLETE:
|
||||
break
|
||||
|
||||
if max_ingestion_time and sec > max_ingestion_time:
|
||||
# ensure that the job gets killed if the max ingestion time is exceeded
|
||||
raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)
|
||||
|
||||
if wait_for_completion:
|
||||
status = last_description['TrainingJobStatus']
|
||||
if status in failed_states:
|
||||
reason = last_description.get('FailureReason', '(No reason provided)')
|
||||
raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
|
||||
billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
|
||||
* instance_count
|
||||
self.log.info('Billable seconds: {}'.format(int(billable_time.total_seconds()) + 1))
|
|
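A minimal sketch of driving the log-tailing helper above from a SageMakerHook; the connection id and job name are hypothetical, and the hook's own non_terminal_states/failed_states sets are reused:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id='aws_default')     # hypothetical connection id
hook.check_training_status_with_log(
    job_name='my-training-job',                     # hypothetical job name
    non_terminal_states=hook.non_terminal_states,   # states that keep the polling loop alive
    failed_states=hook.failed_states,               # states that raise AirflowException
    wait_for_completion=True,                       # keep tailing until the job finishes
    check_interval=30,                              # poll every 30 seconds
    max_ingestion_time=None,                        # no overall timeout
)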
@ -0,0 +1,239 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
|
||||
from airflow.typing_compat import Protocol
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ECSProtocol(Protocol):
|
||||
def run_task(self, **kwargs):
|
||||
...
|
||||
|
||||
def get_waiter(self, x: str):
|
||||
...
|
||||
|
||||
def describe_tasks(self, cluster, tasks):
|
||||
...
|
||||
|
||||
def stop_task(self, cluster, task, reason: str):
|
||||
...
|
||||
|
||||
|
||||
class ECSOperator(BaseOperator):
|
||||
"""
|
||||
Execute a task on AWS EC2 Container Service
|
||||
|
||||
:param task_definition: the task definition name on EC2 Container Service
|
||||
:type task_definition: str
|
||||
:param cluster: the cluster name on EC2 Container Service
|
||||
:type cluster: str
|
||||
:param overrides: the same parameter that boto3 will receive (templated):
|
||||
http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
|
||||
:type overrides: dict
|
||||
:param aws_conn_id: connection id of AWS credentials / region name. If None,
|
||||
the boto3 credential strategy will be used
|
||||
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
|
||||
:type aws_conn_id: str
|
||||
:param region_name: region name to use in AWS Hook.
|
||||
Overrides the region_name in the connection (if provided)
|
||||
:type region_name: str
|
||||
:param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
|
||||
:type launch_type: str
|
||||
:param group: the name of the task group associated with the task
|
||||
:type group: str
|
||||
:param placement_constraints: an array of placement constraint objects to use for
|
||||
the task
|
||||
:type placement_constraints: list
|
||||
:param platform_version: the platform version on which your task is running
|
||||
:type platform_version: str
|
||||
:param network_configuration: the network configuration for the task
|
||||
:type network_configuration: dict
|
||||
:param tags: a dictionary of tags in the form of {'tagKey': 'tagValue'}.
|
||||
:type tags: dict
|
||||
:param awslogs_group: the CloudWatch group where your ECS container logs are stored.
|
||||
Only required if you want logs to be shown in the Airflow UI after your job has
|
||||
finished.
|
||||
:type awslogs_group: str
|
||||
:param awslogs_region: the region in which your CloudWatch logs are stored.
|
||||
If None, this is the same as the `region_name` parameter. If that is also None,
|
||||
this is the default AWS region based on your connection settings.
|
||||
:type awslogs_region: str
|
||||
:param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs.
|
||||
This is usually based on some custom name combined with the name of the container.
|
||||
Only required if you want logs to be shown in the Airflow UI after your job has
|
||||
finished.
|
||||
:type awslogs_stream_prefix: str
|
||||
"""
|
||||
|
||||
ui_color = '#f0ede4'
|
||||
client = None # type: Optional[ECSProtocol]
|
||||
arn = None # type: Optional[str]
|
||||
template_fields = ('overrides',)
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self, task_definition, cluster, overrides,
|
||||
aws_conn_id=None, region_name=None, launch_type='EC2',
|
||||
group=None, placement_constraints=None, platform_version='LATEST',
|
||||
network_configuration=None, tags=None, awslogs_group=None,
|
||||
awslogs_region=None, awslogs_stream_prefix=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.region_name = region_name
|
||||
self.task_definition = task_definition
|
||||
self.cluster = cluster
|
||||
self.overrides = overrides
|
||||
self.launch_type = launch_type
|
||||
self.group = group
|
||||
self.placement_constraints = placement_constraints
|
||||
self.platform_version = platform_version
|
||||
self.network_configuration = network_configuration
|
||||
|
||||
self.tags = tags
|
||||
self.awslogs_group = awslogs_group
|
||||
self.awslogs_stream_prefix = awslogs_stream_prefix
|
||||
self.awslogs_region = awslogs_region
|
||||
|
||||
if self.awslogs_region is None:
|
||||
self.awslogs_region = region_name
|
||||
|
||||
self.hook = self.get_hook()
|
||||
|
||||
def execute(self, context):
|
||||
self.log.info(
|
||||
'Running ECS Task - Task definition: %s - on cluster %s',
|
||||
self.task_definition, self.cluster
|
||||
)
|
||||
self.log.info('ECSOperator overrides: %s', self.overrides)
|
||||
|
||||
self.client = self.hook.get_client_type(
|
||||
'ecs',
|
||||
region_name=self.region_name
|
||||
)
|
||||
|
||||
run_opts = {
|
||||
'cluster': self.cluster,
|
||||
'taskDefinition': self.task_definition,
|
||||
'overrides': self.overrides,
|
||||
'startedBy': self.owner,
|
||||
'launchType': self.launch_type,
|
||||
}
|
||||
|
||||
if self.launch_type == 'FARGATE':
|
||||
run_opts['platformVersion'] = self.platform_version
|
||||
if self.group is not None:
|
||||
run_opts['group'] = self.group
|
||||
if self.placement_constraints is not None:
|
||||
run_opts['placementConstraints'] = self.placement_constraints
|
||||
if self.network_configuration is not None:
|
||||
run_opts['networkConfiguration'] = self.network_configuration
|
||||
if self.tags is not None:
|
||||
run_opts['tags'] = [{'key': k, 'value': v} for (k, v) in self.tags.items()]
|
||||
|
||||
response = self.client.run_task(**run_opts)
|
||||
|
||||
failures = response['failures']
|
||||
if len(failures) > 0:
|
||||
raise AirflowException(response)
|
||||
self.log.info('ECS Task started: %s', response)
|
||||
|
||||
self.arn = response['tasks'][0]['taskArn']
|
||||
self._wait_for_task_ended()
|
||||
|
||||
self._check_success_task()
|
||||
self.log.info('ECS Task has been successfully executed: %s', response)
|
||||
|
||||
def _wait_for_task_ended(self):
|
||||
waiter = self.client.get_waiter('tasks_stopped')
|
||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
|
||||
waiter.wait(
|
||||
cluster=self.cluster,
|
||||
tasks=[self.arn]
|
||||
)
|
||||
|
||||
def _check_success_task(self):
|
||||
response = self.client.describe_tasks(
|
||||
cluster=self.cluster,
|
||||
tasks=[self.arn]
|
||||
)
|
||||
self.log.info('ECS Task stopped, check status: %s', response)
|
||||
|
||||
# Get logs from CloudWatch if the awslogs log driver was used
|
||||
if self.awslogs_group and self.awslogs_stream_prefix:
|
||||
self.log.info('ECS Task logs output:')
|
||||
task_id = self.arn.split("/")[-1]
|
||||
stream_name = "{}/{}".format(self.awslogs_stream_prefix, task_id)
|
||||
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
|
||||
dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
|
||||
self.log.info("[{}] {}".format(dt.isoformat(), event['message']))
|
||||
|
||||
if len(response.get('failures', [])) > 0:
|
||||
raise AirflowException(response)
|
||||
|
||||
for task in response['tasks']:
|
||||
# This is a `stoppedReason` that indicates a task has not
|
||||
# successfully finished, but there is no other indication of failure
|
||||
# in the response.
|
||||
# See, https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html # noqa E501
|
||||
if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.',
|
||||
task.get('stoppedReason', '')):
|
||||
raise AirflowException(
|
||||
'The task was stopped because the host instance terminated: {}'.
|
||||
format(task.get('stoppedReason', '')))
|
||||
containers = task['containers']
|
||||
for container in containers:
|
||||
if container.get('lastStatus') == 'STOPPED' and \
|
||||
container['exitCode'] != 0:
|
||||
raise AirflowException(
|
||||
'This task is not in a success state: {}'.format(task))
|
||||
elif container.get('lastStatus') == 'PENDING':
|
||||
raise AirflowException('This task is still pending {}'.format(task))
|
||||
elif 'error' in container.get('reason', '').lower():
|
||||
raise AirflowException(
|
||||
'This container encountered an error during launching: {}'.
|
||||
format(container.get('reason', '').lower()))
|
||||
|
||||
def get_hook(self):
|
||||
return AwsHook(
|
||||
aws_conn_id=self.aws_conn_id
|
||||
)
|
||||
|
||||
def get_logs_hook(self):
|
||||
return AwsLogsHook(
|
||||
aws_conn_id=self.aws_conn_id,
|
||||
region_name=self.awslogs_region
|
||||
)
|
||||
|
||||
def on_kill(self):
|
||||
response = self.client.stop_task(
|
||||
cluster=self.cluster,
|
||||
task=self.arn,
|
||||
reason='Task killed by the user')
|
||||
self.log.info(response)
|
|
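A minimal usage sketch for the operator above inside a DAG; the module path is assumed from the AIP-21 move, and the cluster, task definition, subnet and log group names are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import ECSOperator  # path assumed after the move

with DAG('example_ecs', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    run_task = ECSOperator(
        task_id='run_ecs_task',
        task_definition='my-task-def',              # hypothetical task definition
        cluster='my-cluster',                       # hypothetical cluster
        overrides={'containerOverrides': []},       # passed straight through to boto3 run_task
        launch_type='FARGATE',                      # adds platformVersion to run_opts
        network_configuration={
            'awsvpcConfiguration': {'subnets': ['subnet-12345678']}  # hypothetical subnet
        },
        awslogs_group='/ecs/my-task',               # enables CloudWatch log fetching after the task stops
        awslogs_stream_prefix='ecs/my-container',   # prefix + task id builds the stream name
        aws_conn_id='aws_default',
    )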
@ -0,0 +1,90 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.emr import EmrHook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class EmrAddStepsOperator(BaseOperator):
|
||||
"""
|
||||
An operator that adds steps to an existing EMR job_flow.
|
||||
|
||||
:param job_flow_id: id of the JobFlow to add steps to. (templated)
|
||||
:type job_flow_id: Optional[str]
|
||||
:param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing
|
||||
job_flow_id. The operator will search for the id of the JobFlow with a matching name in one of the states in
|
||||
param cluster_states. Exactly one such cluster must exist, otherwise the operator will fail. (templated)
|
||||
:type job_flow_name: Optional[str]
|
||||
:param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
|
||||
(templated)
|
||||
:type cluster_states: list
|
||||
:param aws_conn_id: aws connection to use
|
||||
:type aws_conn_id: str
|
||||
:param steps: boto3 style steps to be added to the jobflow. (templated)
|
||||
:type steps: list
|
||||
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
|
||||
:type do_xcom_push: bool
|
||||
"""
|
||||
template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
|
||||
template_ext = ()
|
||||
ui_color = '#f9c915'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
job_flow_id=None,
|
||||
job_flow_name=None,
|
||||
cluster_states=None,
|
||||
aws_conn_id='aws_default',
|
||||
steps=None,
|
||||
*args, **kwargs):
|
||||
if kwargs.get('xcom_push') is not None:
|
||||
raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
|
||||
if not ((job_flow_id is None) ^ (job_flow_name is None)):
|
||||
raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
|
||||
super().__init__(*args, **kwargs)
|
||||
steps = steps or []
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.job_flow_id = job_flow_id
|
||||
self.job_flow_name = job_flow_name
|
||||
self.cluster_states = cluster_states
|
||||
self.steps = steps
|
||||
|
||||
def execute(self, context):
|
||||
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
emr = emr_hook.get_conn()
|
||||
|
||||
job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
|
||||
self.cluster_states)
|
||||
if not job_flow_id:
|
||||
raise AirflowException(f'No cluster found for name: {self.job_flow_name}')
|
||||
|
||||
if self.do_xcom_push:
|
||||
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
|
||||
|
||||
self.log.info('Adding steps to %s', job_flow_id)
|
||||
response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=self.steps)
|
||||
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('Adding steps failed: %s' % response)
|
||||
else:
|
||||
self.log.info('Steps %s added to JobFlow', response['StepIds'])
|
||||
return response['StepIds']
|
|
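A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.emr_add_steps after the move; the step definition and upstream task id are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator  # path assumed after the move

SPARK_STEPS = [{
    'Name': 'calculate_pi',                         # hypothetical step
    'ActionOnFailure': 'CONTINUE',
    'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['spark-example', 'SparkPi', '10']},
}]

with DAG('example_emr_steps', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    add_steps = EmrAddStepsOperator(
        task_id='add_steps',
        # job_flow_id pulled from a hypothetical upstream EmrCreateJobFlowOperator task
        job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
        aws_conn_id='aws_default',
        steps=SPARK_STEPS,
    )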
@ -0,0 +1,74 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.emr import EmrHook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class EmrCreateJobFlowOperator(BaseOperator):
|
||||
"""
|
||||
Creates an EMR JobFlow, reading the config from the EMR connection.
|
||||
A dictionary of JobFlow overrides can be passed that override
|
||||
the config from the connection.
|
||||
|
||||
:param aws_conn_id: aws connection to use
|
||||
:type aws_conn_id: str
|
||||
:param emr_conn_id: emr connection to use
|
||||
:type emr_conn_id: str
|
||||
:param job_flow_overrides: boto3 style arguments to override
|
||||
emr_connection extra. (templated)
|
||||
:type job_flow_overrides: dict
|
||||
"""
|
||||
template_fields = ['job_flow_overrides']
|
||||
template_ext = ()
|
||||
ui_color = '#f9c915'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
aws_conn_id='aws_default',
|
||||
emr_conn_id='emr_default',
|
||||
job_flow_overrides=None,
|
||||
region_name=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.emr_conn_id = emr_conn_id
|
||||
if job_flow_overrides is None:
|
||||
job_flow_overrides = {}
|
||||
self.job_flow_overrides = job_flow_overrides
|
||||
self.region_name = region_name
|
||||
|
||||
def execute(self, context):
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id,
|
||||
emr_conn_id=self.emr_conn_id,
|
||||
region_name=self.region_name)
|
||||
|
||||
self.log.info(
|
||||
'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',
|
||||
self.aws_conn_id, self.emr_conn_id
|
||||
)
|
||||
response = emr.create_job_flow(self.job_flow_overrides)
|
||||
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('JobFlow creation failed: %s' % response)
|
||||
else:
|
||||
self.log.info('JobFlow with id %s created', response['JobFlowId'])
|
||||
return response['JobFlowId']
|
|
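A minimal usage sketch for the operator above; the overrides are hypothetical, and anything not overridden still comes from the 'emr_default' connection extra:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator

JOB_FLOW_OVERRIDES = {
    'Name': 'PiCalc',                 # hypothetical cluster name
    'ReleaseLabel': 'emr-5.29.0',     # hypothetical EMR release
}

with DAG('example_emr_create', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    create_job_flow = EmrCreateJobFlowOperator(
        task_id='create_job_flow',
        job_flow_overrides=JOB_FLOW_OVERRIDES,    # merged over the emr_default connection config
        aws_conn_id='aws_default',
        emr_conn_id='emr_default',
    )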
@ -0,0 +1,57 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.emr import EmrHook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class EmrTerminateJobFlowOperator(BaseOperator):
|
||||
"""
|
||||
Operator to terminate EMR JobFlows.
|
||||
|
||||
:param job_flow_id: id of the JobFlow to terminate. (templated)
|
||||
:type job_flow_id: str
|
||||
:param aws_conn_id: aws connection to use
|
||||
:type aws_conn_id: str
|
||||
"""
|
||||
template_fields = ['job_flow_id']
|
||||
template_ext = ()
|
||||
ui_color = '#f9c915'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
job_flow_id,
|
||||
aws_conn_id='aws_default',
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_flow_id = job_flow_id
|
||||
self.aws_conn_id = aws_conn_id
|
||||
|
||||
def execute(self, context):
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
|
||||
|
||||
self.log.info('Terminating JobFlow %s', self.job_flow_id)
|
||||
response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
|
||||
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('JobFlow termination failed: %s' % response)
|
||||
else:
|
||||
self.log.info('JobFlow with id %s terminated', self.job_flow_id)
|
|
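A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.emr_terminate_job_flow after the move; the upstream task id is hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator  # path assumed

with DAG('example_emr_terminate', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    terminate_job_flow = EmrTerminateJobFlowOperator(
        task_id='terminate_job_flow',
        # id pulled from a hypothetical upstream EmrCreateJobFlowOperator task
        job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
        aws_conn_id='aws_default',
    )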
@ -0,0 +1,96 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3CopyObjectOperator(BaseOperator):
|
||||
"""
|
||||
Creates a copy of an object that is already stored in S3.
|
||||
|
||||
Note: the S3 connection used here needs to have access to both
|
||||
source and destination bucket/key.
|
||||
|
||||
:param source_bucket_key: The key of the source object. (templated)
|
||||
|
||||
It can be either a full s3:// style URL or a relative path from the root level.
|
||||
|
||||
When it's specified as a full s3:// url, please omit source_bucket_name.
|
||||
:type source_bucket_key: str
|
||||
:param dest_bucket_key: The key of the object to copy to. (templated)
|
||||
|
||||
The convention to specify `dest_bucket_key` is the same as `source_bucket_key`.
|
||||
:type dest_bucket_key: str
|
||||
:param source_bucket_name: Name of the S3 bucket where the source object is in. (templated)
|
||||
|
||||
It should be omitted when `source_bucket_key` is provided as a full s3:// url.
|
||||
:type source_bucket_name: str
|
||||
:param dest_bucket_name: Name of the S3 bucket to which the object is copied. (templated)
|
||||
|
||||
It should be omitted when `dest_bucket_key` is provided as a full s3:// url.
|
||||
:type dest_bucket_name: str
|
||||
:param source_version_id: Version ID of the source object (OPTIONAL)
|
||||
:type source_version_id: str
|
||||
:param aws_conn_id: Connection id of the S3 connection to use
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
|
||||
You can provide the following values:
|
||||
|
||||
- False: do not validate SSL certificates. SSL will still be used,
|
||||
but SSL certificates will not be
|
||||
verified.
|
||||
- path/to/cert/bundle.pem: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
"""
|
||||
|
||||
template_fields = ('source_bucket_key', 'dest_bucket_key',
|
||||
'source_bucket_name', 'dest_bucket_name')
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
source_bucket_key,
|
||||
dest_bucket_key,
|
||||
source_bucket_name=None,
|
||||
dest_bucket_name=None,
|
||||
source_version_id=None,
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.source_bucket_key = source_bucket_key
|
||||
self.dest_bucket_key = dest_bucket_key
|
||||
self.source_bucket_name = source_bucket_name
|
||||
self.dest_bucket_name = dest_bucket_name
|
||||
self.source_version_id = source_version_id
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def execute(self, context):
|
||||
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key,
|
||||
self.source_bucket_name, self.dest_bucket_name,
|
||||
self.source_version_id)
|
|
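A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.s3_copy_object after the move; bucket names and keys are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator  # path assumed

with DAG('example_s3_copy', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    copy_object = S3CopyObjectOperator(
        task_id='copy_object',
        # full s3:// URLs, so source_bucket_name / dest_bucket_name are omitted
        source_bucket_key='s3://source-bucket/data/file.csv',
        dest_bucket_key='s3://dest-bucket/data/file.csv',
        aws_conn_id='aws_default',
    )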
@ -0,0 +1,87 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3DeleteObjectsOperator(BaseOperator):
|
||||
"""
|
||||
Enables users to delete a single object or multiple objects from
|
||||
a bucket using a single HTTP request.
|
||||
|
||||
Users may specify up to 1000 keys to delete.
|
||||
|
||||
:param bucket: Name of the bucket in which you are going to delete object(s). (templated)
|
||||
:type bucket: str
|
||||
:param keys: The key(s) to delete from S3 bucket. (templated)
|
||||
|
||||
When ``keys`` is a string, it's supposed to be the key name of
|
||||
the single object to delete.
|
||||
|
||||
When ``keys`` is a list, it's supposed to be the list of the
|
||||
keys to delete.
|
||||
|
||||
You may specify up to 1000 keys.
|
||||
:type keys: str or list
|
||||
:param aws_conn_id: Connection id of the S3 connection to use
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used,
|
||||
but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
"""
|
||||
|
||||
template_fields = ('keys', 'bucket')
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
bucket,
|
||||
keys,
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.bucket = bucket
|
||||
self.keys = keys
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def execute(self, context):
|
||||
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
|
||||
response = s3_hook.delete_objects(bucket=self.bucket, keys=self.keys)
|
||||
|
||||
deleted_keys = [x['Key'] for x in response.get("Deleted", [])]
|
||||
self.log.info("Deleted: %s", deleted_keys)
|
||||
|
||||
if "Errors" in response:
|
||||
errors_keys = [x['Key'] for x in response.get("Errors", [])]
|
||||
raise AirflowException("Errors when deleting: {}".format(errors_keys))
|
|
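A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.s3_delete_objects after the move; bucket and keys are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator  # path assumed

with DAG('example_s3_delete', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    delete_objects = S3DeleteObjectsOperator(
        task_id='delete_objects',
        bucket='my-bucket',                          # hypothetical bucket
        keys=['data/file1.csv', 'data/file2.csv'],   # a plain str would delete a single key
        aws_conn_id='aws_default',
    )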
@ -0,0 +1,170 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, Union
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3FileTransformOperator(BaseOperator):
|
||||
"""
|
||||
Copies data from a source S3 location to a temporary location on the
|
||||
local filesystem. Runs a transformation on this file as specified by
|
||||
the transformation script and uploads the output to a destination S3
|
||||
location.
|
||||
|
||||
The locations of the source and the destination files in the local
|
||||
filesystem are provided as the first and second arguments to the
|
||||
transformation script. The transformation script is expected to read the
|
||||
data from source, transform it and write the output to the local
|
||||
destination file. The operator then takes over control and uploads the
|
||||
local destination file to S3.
|
||||
|
||||
S3 Select is also available to filter the source contents. Users can
|
||||
omit the transformation script if an S3 Select expression is specified.
|
||||
|
||||
:param source_s3_key: The key to be retrieved from S3. (templated)
|
||||
:type source_s3_key: str
|
||||
:param dest_s3_key: The key to be written to S3. (templated)
|
||||
:type dest_s3_key: str
|
||||
:param transform_script: location of the executable transformation script
|
||||
:type transform_script: str
|
||||
:param select_expression: S3 Select expression
|
||||
:type select_expression: str
|
||||
:param source_aws_conn_id: source s3 connection
|
||||
:type source_aws_conn_id: str
|
||||
:param source_verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
|
||||
This is also applicable to ``dest_verify``.
|
||||
:type source_verify: bool or str
|
||||
:param dest_aws_conn_id: destination s3 connection
|
||||
:type dest_aws_conn_id: str
|
||||
:param dest_verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
See: ``source_verify``
|
||||
:type dest_verify: bool or str
|
||||
:param replace: Replace dest S3 key if it already exists
|
||||
:type replace: bool
|
||||
"""
|
||||
|
||||
template_fields = ('source_s3_key', 'dest_s3_key')
|
||||
template_ext = ()
|
||||
ui_color = '#f9c915'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
source_s3_key: str,
|
||||
dest_s3_key: str,
|
||||
transform_script: Optional[str] = None,
|
||||
select_expression=None,
|
||||
source_aws_conn_id: str = 'aws_default',
|
||||
source_verify: Optional[Union[bool, str]] = None,
|
||||
dest_aws_conn_id: str = 'aws_default',
|
||||
dest_verify: Optional[Union[bool, str]] = None,
|
||||
replace: bool = False,
|
||||
*args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.source_s3_key = source_s3_key
|
||||
self.source_aws_conn_id = source_aws_conn_id
|
||||
self.source_verify = source_verify
|
||||
self.dest_s3_key = dest_s3_key
|
||||
self.dest_aws_conn_id = dest_aws_conn_id
|
||||
self.dest_verify = dest_verify
|
||||
self.replace = replace
|
||||
self.transform_script = transform_script
|
||||
self.select_expression = select_expression
|
||||
self.output_encoding = sys.getdefaultencoding()
|
||||
|
||||
def execute(self, context):
|
||||
if self.transform_script is None and self.select_expression is None:
|
||||
raise AirflowException(
|
||||
"Either transform_script or select_expression must be specified")
|
||||
|
||||
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify)
|
||||
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
|
||||
|
||||
self.log.info("Downloading source S3 file %s", self.source_s3_key)
|
||||
if not source_s3.check_for_key(self.source_s3_key):
|
||||
raise AirflowException(
|
||||
"The source key {0} does not exist".format(self.source_s3_key))
|
||||
source_s3_key_object = source_s3.get_key(self.source_s3_key)
|
||||
|
||||
with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest:
|
||||
self.log.info(
|
||||
"Dumping S3 file %s contents to local file %s",
|
||||
self.source_s3_key, f_source.name
|
||||
)
|
||||
|
||||
if self.select_expression is not None:
|
||||
content = source_s3.select_key(
|
||||
key=self.source_s3_key,
|
||||
expression=self.select_expression
|
||||
)
|
||||
f_source.write(content.encode("utf-8"))
|
||||
else:
|
||||
source_s3_key_object.download_fileobj(Fileobj=f_source)
|
||||
f_source.flush()
|
||||
|
||||
if self.transform_script is not None:
|
||||
process = subprocess.Popen(
|
||||
[self.transform_script, f_source.name, f_dest.name],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
close_fds=True
|
||||
)
|
||||
|
||||
self.log.info("Output:")
|
||||
for line in iter(process.stdout.readline, b''):
|
||||
self.log.info(line.decode(self.output_encoding).rstrip())
|
||||
|
||||
process.wait()
|
||||
|
||||
if process.returncode > 0:
|
||||
raise AirflowException(
|
||||
"Transform script failed: {0}".format(process.returncode)
|
||||
)
|
||||
else:
|
||||
self.log.info(
|
||||
"Transform script successful. Output temporarily located at %s",
|
||||
f_dest.name
|
||||
)
|
||||
|
||||
self.log.info("Uploading transformed file to S3")
|
||||
f_dest.flush()
|
||||
dest_s3.load_file(
|
||||
filename=f_dest.name,
|
||||
key=self.dest_s3_key,
|
||||
replace=self.replace
|
||||
)
|
||||
self.log.info("Upload successful")
|
|
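A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.s3_file_transform after the move; keys and the script path are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator  # path assumed

with DAG('example_s3_transform', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    transform = S3FileTransformOperator(
        task_id='transform_file',
        source_s3_key='s3://my-bucket/raw/data.csv',
        dest_s3_key='s3://my-bucket/processed/data.csv',
        transform_script='/usr/local/bin/transform.py',   # gets the local source and dest paths as its two arguments
        replace=True,                                     # overwrite dest_s3_key if it already exists
    )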
@ -0,0 +1,99 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3ListOperator(BaseOperator):
|
||||
"""
|
||||
List all objects from the bucket with the given string prefix in name.
|
||||
|
||||
This operator returns a Python list with the names of the objects, which can be
|
||||
used by `xcom` in the downstream task.
|
||||
|
||||
:param bucket: The S3 bucket where to find the objects. (templated)
|
||||
:type bucket: str
|
||||
:param prefix: Prefix string to filter the objects whose names begin with
|
||||
this prefix. (templated)
|
||||
:type prefix: str
|
||||
:param delimiter: the delimiter marks key hierarchy. (templated)
|
||||
:type delimiter: str
|
||||
:param aws_conn_id: The connection ID to use when connecting to S3 storage.
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
|
||||
|
||||
**Example**:
|
||||
The following operator would list all the files
|
||||
(excluding subfolders) from the S3
|
||||
``customers/2018/04/`` key in the ``data`` bucket. ::
|
||||
|
||||
s3_file = S3ListOperator(
|
||||
task_id='list_s3_files',
|
||||
bucket='data',
|
||||
prefix='customers/2018/04/',
|
||||
delimiter='/',
|
||||
aws_conn_id='aws_customers_conn'
|
||||
)
|
||||
"""
|
||||
template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str]
|
||||
ui_color = '#ffd700'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
bucket,
|
||||
prefix='',
|
||||
delimiter='',
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.bucket = bucket
|
||||
self.prefix = prefix
|
||||
self.delimiter = delimiter
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def execute(self, context):
|
||||
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
|
||||
self.log.info(
|
||||
'Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)',
|
||||
self.bucket, self.prefix, self.delimiter
|
||||
)
|
||||
|
||||
return hook.list_keys(
|
||||
bucket_name=self.bucket,
|
||||
prefix=self.prefix,
|
||||
delimiter=self.delimiter)
|
|
@ -0,0 +1,101 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
from typing import Iterable
|
||||
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerBaseOperator(BaseOperator):
|
||||
"""
|
||||
This is the base operator for all SageMaker operators.
|
||||
|
||||
:param config: The configuration necessary to start a training job (templated)
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
"""
|
||||
|
||||
template_fields = ['config']
|
||||
template_ext = ()
|
||||
ui_color = '#ededed'
|
||||
|
||||
integer_fields = [] # type: Iterable[Iterable[str]]
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
aws_conn_id='aws_default',
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.config = config
|
||||
self.hook = None
|
||||
|
||||
def parse_integer(self, config, field):
|
||||
if len(field) == 1:
|
||||
if isinstance(config, list):
|
||||
for sub_config in config:
|
||||
self.parse_integer(sub_config, field)
|
||||
return
|
||||
head = field[0]
|
||||
if head in config:
|
||||
config[head] = int(config[head])
|
||||
return
|
||||
|
||||
if isinstance(config, list):
|
||||
for sub_config in config:
|
||||
self.parse_integer(sub_config, field)
|
||||
return
|
||||
|
||||
head, tail = field[0], field[1:]
|
||||
if head in config:
|
||||
self.parse_integer(config[head], tail)
|
||||
return
|
||||
|
||||
def parse_config_integers(self):
|
||||
# Parse the integer fields of training config to integers
|
||||
# in case the config is rendered by Jinja and all fields are str
|
||||
for field in self.integer_fields:
|
||||
self.parse_integer(self.config, field)
|
||||
|
||||
def expand_role(self):
|
||||
pass
|
||||
|
||||
def preprocess_config(self):
|
||||
self.log.info(
|
||||
'Preprocessing the config and doing required s3_operations'
|
||||
)
|
||||
self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.hook.configure_s3_resources(self.config)
|
||||
self.parse_config_integers()
|
||||
self.expand_role()
|
||||
|
||||
self.log.info(
|
||||
'After preprocessing the config is:\n {}'.format(
|
||||
json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')))
|
||||
)
|
||||
|
||||
def execute(self, context):
|
||||
raise NotImplementedError('Please implement execute() in subclass!')
|
|
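A small illustration of what parse_config_integers does once Jinja templating has turned numeric fields into strings; the values below are hypothetical and the field paths mirror the training operator's integer_fields:

# Config as rendered by Jinja: numeric leaves arrive as strings.
config = {
    'ResourceConfig': {'InstanceCount': '2', 'VolumeSizeInGB': '30'},
    'StoppingCondition': {'MaxRuntimeInSeconds': '3600'},
}
# Each entry is a nested path; parse_integer() walks it and casts the leaf with int().
integer_fields = [
    ['ResourceConfig', 'InstanceCount'],
    ['ResourceConfig', 'VolumeSizeInGB'],
    ['StoppingCondition', 'MaxRuntimeInSeconds'],
]
# After parse_config_integers() runs on an operator holding this config,
# config['ResourceConfig']['InstanceCount'] == 2 (an int), and so on for the other paths.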
@ -0,0 +1,150 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerEndpointOperator(SageMakerBaseOperator):
|
||||
|
||||
"""
|
||||
Create a SageMaker endpoint.
|
||||
|
||||
This operator returns the ARN of the endpoint created in Amazon SageMaker.
|
||||
|
||||
:param config:
|
||||
The configuration necessary to create an endpoint.
|
||||
|
||||
If you need to create a SageMaker endpoint based on an existing
|
||||
SageMaker model and an existing SageMaker endpoint config::
|
||||
|
||||
config = endpoint_configuration
|
||||
|
||||
If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
|
||||
|
||||
config = {
|
||||
'Model': model_configuration,
|
||||
'EndpointConfig': endpoint_config_configuration,
|
||||
'Endpoint': endpoint_configuration
|
||||
}
|
||||
|
||||
For details of the configuration parameter of model_configuration see
|
||||
:py:meth:`SageMaker.Client.create_model`
|
||||
|
||||
For details of the configuration parameter of endpoint_config_configuration see
|
||||
:py:meth:`SageMaker.Client.create_endpoint_config`
|
||||
|
||||
For details of the configuration parameter of endpoint_configuration see
|
||||
:py:meth:`SageMaker.Client.create_endpoint`
|
||||
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
:param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
|
||||
waits before polling the status of the endpoint creation.
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
|
||||
finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
|
||||
:type max_ingestion_time: int
|
||||
:param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
|
||||
:type operation: str
|
||||
"""
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
wait_for_completion=True,
|
||||
check_interval=30,
|
||||
max_ingestion_time=None,
|
||||
operation='create',
|
||||
*args, **kwargs):
|
||||
super().__init__(config=config,
|
||||
*args, **kwargs)
|
||||
|
||||
self.config = config
|
||||
self.wait_for_completion = wait_for_completion
|
||||
self.check_interval = check_interval
|
||||
self.max_ingestion_time = max_ingestion_time
|
||||
self.operation = operation.lower()
|
||||
if self.operation not in ['create', 'update']:
|
||||
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
|
||||
self.create_integer_fields()
|
||||
|
||||
def create_integer_fields(self):
|
||||
if 'EndpointConfig' in self.config:
|
||||
self.integer_fields = [
|
||||
['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
|
||||
]
|
||||
|
||||
def expand_role(self):
|
||||
if 'Model' not in self.config:
|
||||
return
|
||||
hook = AwsHook(self.aws_conn_id)
|
||||
config = self.config['Model']
|
||||
if 'ExecutionRoleArn' in config:
|
||||
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
|
||||
|
||||
def execute(self, context):
|
||||
self.preprocess_config()
|
||||
|
||||
model_info = self.config.get('Model')
|
||||
endpoint_config_info = self.config.get('EndpointConfig')
|
||||
endpoint_info = self.config.get('Endpoint', self.config)
|
||||
|
||||
if model_info:
|
||||
self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
|
||||
self.hook.create_model(model_info)
|
||||
|
||||
if endpoint_config_info:
|
||||
self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
|
||||
self.hook.create_endpoint_config(endpoint_config_info)
|
||||
|
||||
if self.operation == 'create':
|
||||
sagemaker_operation = self.hook.create_endpoint
|
||||
log_str = 'Creating'
|
||||
elif self.operation == 'update':
|
||||
sagemaker_operation = self.hook.update_endpoint
|
||||
log_str = 'Updating'
|
||||
else:
|
||||
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
|
||||
|
||||
self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
|
||||
|
||||
response = sagemaker_operation(
|
||||
endpoint_info,
|
||||
wait_for_completion=self.wait_for_completion,
|
||||
check_interval=self.check_interval,
|
||||
max_ingestion_time=self.max_ingestion_time
|
||||
)
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException(
|
||||
'Sagemaker endpoint creation failed: %s' % response)
|
||||
else:
|
||||
return {
|
||||
'EndpointConfig': self.hook.describe_endpoint_config(
|
||||
endpoint_info['EndpointConfigName']
|
||||
),
|
||||
'Endpoint': self.hook.describe_endpoint(
|
||||
endpoint_info['EndpointName']
|
||||
)
|
||||
}
|
|
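A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.sagemaker_endpoint after the move; the three config dicts are the placeholders named in the docstring above:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator  # path assumed

model_configuration = {}            # as accepted by SageMaker.Client.create_model
endpoint_config_configuration = {}  # as accepted by SageMaker.Client.create_endpoint_config
endpoint_configuration = {}         # as accepted by SageMaker.Client.create_endpoint

with DAG('example_sagemaker_endpoint', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    deploy = SageMakerEndpointOperator(
        task_id='deploy_endpoint',
        config={
            'Model': model_configuration,
            'EndpointConfig': endpoint_config_configuration,
            'Endpoint': endpoint_configuration,
        },
        operation='create',          # 'update' would call update_endpoint instead
        wait_for_completion=True,
        check_interval=30,
    )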
@ -0,0 +1,66 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
|
||||
|
||||
"""
|
||||
Create a SageMaker endpoint config.
|
||||
|
||||
This operator returns the ARN of the endpoint config created in Amazon SageMaker.
|
||||
|
||||
:param config: The configuration necessary to create an endpoint config.
|
||||
|
||||
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
"""
|
||||
|
||||
integer_fields = [
|
||||
['ProductionVariants', 'InitialInstanceCount']
|
||||
]
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
*args, **kwargs):
|
||||
super().__init__(config=config,
|
||||
*args, **kwargs)
|
||||
|
||||
self.config = config
|
||||
|
||||
def execute(self, context):
|
||||
self.preprocess_config()
|
||||
|
||||
self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
|
||||
response = self.hook.create_endpoint_config(self.config)
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException(
|
||||
'Sagemaker endpoint config creation failed: %s' % response)
|
||||
else:
|
||||
return {
|
||||
'EndpointConfig': self.hook.describe_endpoint_config(
|
||||
self.config['EndpointConfigName']
|
||||
)
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerModelOperator(SageMakerBaseOperator):
|
||||
|
||||
"""
|
||||
Create a SageMaker model.
|
||||
|
||||
This operator returns the ARN of the model created in Amazon SageMaker.
|
||||
|
||||
:param config: The configuration necessary to create a model.
|
||||
|
||||
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
"""
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
*args, **kwargs):
|
||||
super().__init__(config=config,
|
||||
*args, **kwargs)
|
||||
|
||||
self.config = config
|
||||
|
||||
def expand_role(self):
|
||||
if 'ExecutionRoleArn' in self.config:
|
||||
hook = AwsHook(self.aws_conn_id)
|
||||
self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
|
||||
|
||||
def execute(self, context):
|
||||
self.preprocess_config()
|
||||
|
||||
self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
|
||||
response = self.hook.create_model(self.config)
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('Sagemaker model creation failed: %s' % response)
|
||||
else:
|
||||
return {
|
||||
'Model': self.hook.describe_model(
|
||||
self.config['ModelName']
|
||||
)
|
||||
}
|
|
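A minimal usage sketch, assuming the module lands at airflow.providers.amazon.aws.operators.sagemaker_model after the move; the model name, role and container values are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator  # path assumed

with DAG('example_sagemaker_model', start_date=datetime(2020, 1, 1), schedule_interval=None) as dag:
    create_model = SageMakerModelOperator(
        task_id='create_model',
        config={
            'ModelName': 'my-model',
            'ExecutionRoleArn': 'my-sagemaker-role',   # expand_role() resolves a role name to its full ARN
            'PrimaryContainer': {
                'Image': '123456789012.dkr.ecr.eu-west-1.amazonaws.com/my-image:latest',
                'ModelDataUrl': 's3://my-bucket/model.tar.gz',
            },
        },
        aws_conn_id='aws_default',
    )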
@ -0,0 +1,98 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerTrainingOperator(SageMakerBaseOperator):
|
||||
"""
|
||||
Initiate a SageMaker training job.
|
||||
|
||||
This operator returns the ARN of the training job created in Amazon SageMaker.
|
||||
|
||||
:param config: The configuration necessary to start a training job (templated).
|
||||
|
||||
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
:param wait_for_completion: Whether the operator should wait until the
|
||||
training job finishes.
|
||||
:type wait_for_completion: bool
|
||||
:param print_log: whether the operator should print the CloudWatch log during training
|
||||
:type print_log: bool
|
||||
:param check_interval: if wait is set to True, this is the time interval,
|
||||
in seconds, at which the operator checks the status of the training job
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: If wait is set to True, the operation fails if the training job
|
||||
doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
|
||||
the operation does not time out.
|
||||
:type max_ingestion_time: int
|
||||
"""
|
||||
|
||||
integer_fields = [
|
||||
['ResourceConfig', 'InstanceCount'],
|
||||
['ResourceConfig', 'VolumeSizeInGB'],
|
||||
['StoppingCondition', 'MaxRuntimeInSeconds']
|
||||
]
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
wait_for_completion=True,
|
||||
print_log=True,
|
||||
check_interval=30,
|
||||
max_ingestion_time=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(config=config,
|
||||
*args, **kwargs)
|
||||
|
||||
self.wait_for_completion = wait_for_completion
|
||||
self.print_log = print_log
|
||||
self.check_interval = check_interval
|
||||
self.max_ingestion_time = max_ingestion_time
|
||||
|
||||
def expand_role(self):
|
||||
if 'RoleArn' in self.config:
|
||||
hook = AwsHook(self.aws_conn_id)
|
||||
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
|
||||
|
||||
def execute(self, context):
|
||||
self.preprocess_config()
|
||||
|
||||
self.log.info('Creating SageMaker Training Job %s.', self.config['TrainingJobName'])
|
||||
|
||||
response = self.hook.create_training_job(
|
||||
self.config,
|
||||
wait_for_completion=self.wait_for_completion,
|
||||
print_log=self.print_log,
|
||||
check_interval=self.check_interval,
|
||||
max_ingestion_time=self.max_ingestion_time
|
||||
)
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('Sagemaker Training Job creation failed: %s' % response)
|
||||
else:
|
||||
return {
|
||||
'Training': self.hook.describe_training_job(
|
||||
self.config['TrainingJobName']
|
||||
)
|
||||
}
|
|
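Example usage (a minimal sketch, not part of this commit): the connection ID and config values below are placeholders, and the request dict is trimmed to the fields the operator itself coerces; a real call needs the full boto3 create_training_job payload.

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator
from airflow.utils.dates import days_ago

# Placeholder request mirroring SageMaker.Client.create_training_job; only the
# fields listed in integer_fields above are spelled out.
TRAINING_CONFIG = {
    'TrainingJobName': 'example-training-{{ ds_nodash }}',
    'RoleArn': 'arn:aws:iam::123456789012:role/example-sagemaker-role',
    'ResourceConfig': {'InstanceCount': 1, 'InstanceType': 'ml.m5.large', 'VolumeSizeInGB': 30},
    'StoppingCondition': {'MaxRuntimeInSeconds': 3600},
    # AlgorithmSpecification, InputDataConfig and OutputDataConfig omitted in this sketch.
}

with DAG('example_sagemaker_training', start_date=days_ago(1), schedule_interval=None) as dag:
    train = SageMakerTrainingOperator(
        task_id='train_model',
        config=TRAINING_CONFIG,
        wait_for_completion=True,   # block until the job reaches a terminal state
        check_interval=60,          # poll every 60 seconds
        max_ingestion_time=None,    # no overall timeout
        print_log=True,
    )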
@ -0,0 +1,124 @@
|
|||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerTransformOperator(SageMakerBaseOperator):
|
||||
"""
|
||||
Initiate a SageMaker transform job.
|
||||
|
||||
This operator returns the ARN of the model created in Amazon SageMaker.
|
||||
|
||||
:param config: The configuration necessary to start a transform job (templated).
|
||||
|
||||
If you need to create a SageMaker transform job based on an existing SageMaker model::
|
||||
|
||||
config = transform_config
|
||||
|
||||
If you need to create both SageMaker model and SageMaker Transform job::
|
||||
|
||||
config = {
|
||||
'Model': model_config,
|
||||
'Transform': transform_config
|
||||
}
|
||||
|
||||
For details of the configuration parameter of transform_config see
|
||||
:py:meth:`SageMaker.Client.create_transform_job`
|
||||
|
||||
For details of the configuration parameter of model_config, see
|
||||
:py:meth:`SageMaker.Client.create_model`
|
||||
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
:param wait_for_completion: Set to True to wait until the transform job finishes.
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: If wait is set to True, the time interval, in seconds,
|
||||
that this operation waits to check the status of the transform job.
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: If wait is set to True, the operation fails
|
||||
if the transform job doesn't finish within max_ingestion_time seconds. If you
|
||||
set this parameter to None, the operation does not time out.
|
||||
:type max_ingestion_time: int
|
||||
"""
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
wait_for_completion=True,
|
||||
check_interval=30,
|
||||
max_ingestion_time=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(config=config,
|
||||
*args, **kwargs)
|
||||
self.config = config
|
||||
self.wait_for_completion = wait_for_completion
|
||||
self.check_interval = check_interval
|
||||
self.max_ingestion_time = max_ingestion_time
|
||||
self.create_integer_fields()
|
||||
|
||||
def create_integer_fields(self):
|
||||
self.integer_fields = [
|
||||
['Transform', 'TransformResources', 'InstanceCount'],
|
||||
['Transform', 'MaxConcurrentTransforms'],
|
||||
['Transform', 'MaxPayloadInMB']
|
||||
]
|
||||
if 'Transform' not in self.config:
|
||||
for field in self.integer_fields:
|
||||
field.pop(0)
|
||||
|
||||
def expand_role(self):
|
||||
if 'Model' not in self.config:
|
||||
return
|
||||
config = self.config['Model']
|
||||
if 'ExecutionRoleArn' in config:
|
||||
hook = AwsHook(self.aws_conn_id)
|
||||
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
|
||||
|
||||
def execute(self, context):
|
||||
self.preprocess_config()
|
||||
|
||||
model_config = self.config.get('Model')
|
||||
transform_config = self.config.get('Transform', self.config)
|
||||
|
||||
if model_config:
|
||||
self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
|
||||
self.hook.create_model(model_config)
|
||||
|
||||
self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
|
||||
response = self.hook.create_transform_job(
|
||||
transform_config,
|
||||
wait_for_completion=self.wait_for_completion,
|
||||
check_interval=self.check_interval,
|
||||
max_ingestion_time=self.max_ingestion_time)
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('Sagemaker transform Job creation failed: %s' % response)
|
||||
else:
|
||||
return {
|
||||
'Model': self.hook.describe_model(
|
||||
transform_config['ModelName']
|
||||
),
|
||||
'Transform': self.hook.describe_transform_job(
|
||||
transform_config['TransformJobName']
|
||||
)
|
||||
}
|
|
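Example usage (a minimal sketch, not part of this commit): it shows the combined Model plus Transform config shape described in the docstring above; all names, ARNs and instance types are placeholders and the boto3 payloads are deliberately incomplete.

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator
from airflow.utils.dates import days_ago

# Placeholder configs: model_config mirrors create_model (PrimaryContainer omitted),
# transform_config mirrors create_transform_job (TransformInput/Output omitted).
model_config = {
    'ModelName': 'example-model',
    'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/example-sagemaker-role',
}
transform_config = {
    'TransformJobName': 'example-transform-{{ ds_nodash }}',
    'ModelName': 'example-model',
    'TransformResources': {'InstanceCount': 1, 'InstanceType': 'ml.m5.large'},
}

with DAG('example_sagemaker_transform', start_date=days_ago(1), schedule_interval=None) as dag:
    # Creates the model and runs the transform job in a single task.
    transform = SageMakerTransformOperator(
        task_id='transform',
        config={'Model': model_config, 'Transform': transform_config},
        wait_for_completion=True,
        check_interval=30,
    )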
@ -0,0 +1,99 @@
|
|||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerTuningOperator(SageMakerBaseOperator):
|
||||
"""
|
||||
Initiate a SageMaker hyperparameter tuning job.
|
||||
|
||||
This operator returns the ARN of the tuning job created in Amazon SageMaker.
|
||||
|
||||
:param config: The configuration necessary to start a tuning job (templated).
|
||||
|
||||
For details of the configuration parameter see
|
||||
:py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
|
||||
:type config: dict
|
||||
:param aws_conn_id: The AWS connection ID to use.
|
||||
:type aws_conn_id: str
|
||||
:param wait_for_completion: Set to True to wait until the tuning job finishes.
|
||||
:type wait_for_completion: bool
|
||||
:param check_interval: If wait is set to True, the time interval, in seconds,
|
||||
that this operation waits to check the status of the tuning job.
|
||||
:type check_interval: int
|
||||
:param max_ingestion_time: If wait is set to True, the operation fails
|
||||
if the tuning job doesn't finish within max_ingestion_time seconds. If you
|
||||
set this parameter to None, the operation does not time out.
|
||||
:type max_ingestion_time: int
|
||||
"""
|
||||
|
||||
integer_fields = [
|
||||
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
|
||||
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
|
||||
['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
|
||||
['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
|
||||
['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds']
|
||||
]
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
config,
|
||||
wait_for_completion=True,
|
||||
check_interval=30,
|
||||
max_ingestion_time=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(config=config,
|
||||
*args, **kwargs)
|
||||
self.config = config
|
||||
self.wait_for_completion = wait_for_completion
|
||||
self.check_interval = check_interval
|
||||
self.max_ingestion_time = max_ingestion_time
|
||||
|
||||
def expand_role(self):
|
||||
if 'TrainingJobDefinition' in self.config:
|
||||
config = self.config['TrainingJobDefinition']
|
||||
if 'RoleArn' in config:
|
||||
hook = AwsHook(self.aws_conn_id)
|
||||
config['RoleArn'] = hook.expand_role(config['RoleArn'])
|
||||
|
||||
def execute(self, context):
|
||||
self.preprocess_config()
|
||||
|
||||
self.log.info(
|
||||
'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
|
||||
)
|
||||
|
||||
response = self.hook.create_tuning_job(
|
||||
self.config,
|
||||
wait_for_completion=self.wait_for_completion,
|
||||
check_interval=self.check_interval,
|
||||
max_ingestion_time=self.max_ingestion_time
|
||||
)
|
||||
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
|
||||
raise AirflowException('Sagemaker Tuning Job creation failed: %s' % response)
|
||||
else:
|
||||
return {
|
||||
'Tuning': self.hook.describe_tuning_job(
|
||||
self.config['HyperParameterTuningJobName']
|
||||
)
|
||||
}
|
|
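Example usage (a minimal sketch, not part of this commit): the request dict is a placeholder trimmed to the nested integer fields the operator casts above; a real call needs the full create_hyper_parameter_tuning_job payload.

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator
from airflow.utils.dates import days_ago

# Placeholder request mirroring SageMaker.Client.create_hyper_parameter_tuning_job.
TUNING_CONFIG = {
    'HyperParameterTuningJobName': 'example-tuning-{{ ds_nodash }}',
    'HyperParameterTuningJobConfig': {
        'ResourceLimits': {'MaxNumberOfTrainingJobs': 10, 'MaxParallelTrainingJobs': 2},
        # Strategy, objective metric and parameter ranges omitted in this sketch.
    },
    'TrainingJobDefinition': {'RoleArn': 'arn:aws:iam::123456789012:role/example-sagemaker-role'},
}

with DAG('example_sagemaker_tuning', start_date=days_ago(1), schedule_interval=None) as dag:
    tune = SageMakerTuningOperator(
        task_id='tune_model',
        config=TUNING_CONFIG,
        wait_for_completion=True,
        check_interval=60,
    )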
@ -0,0 +1,60 @@
|
|||
from airflow.exceptions import AirflowException
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class EmrBaseSensor(BaseSensorOperator):
|
||||
"""
|
||||
Contains general sensor behavior for EMR.
|
||||
Subclasses should implement get_emr_response() and state_from_response() methods.
|
||||
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE constants.
|
||||
"""
|
||||
ui_color = '#66c3ff'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
aws_conn_id='aws_default',
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aws_conn_id = aws_conn_id
|
||||
|
||||
def poke(self, context):
|
||||
response = self.get_emr_response()
|
||||
|
||||
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
self.log.info('Bad HTTP response: %s', response)
|
||||
return False
|
||||
|
||||
state = self.state_from_response(response)
|
||||
self.log.info('Job flow currently %s', state)
|
||||
|
||||
if state in self.NON_TERMINAL_STATES:
|
||||
return False
|
||||
|
||||
if state in self.FAILED_STATE:
|
||||
final_message = 'EMR job failed'
|
||||
failure_message = self.failure_message_from_response(response)
|
||||
if failure_message:
|
||||
final_message += ' ' + failure_message
|
||||
raise AirflowException(final_message)
|
||||
|
||||
return True
|
|
@ -0,0 +1,63 @@
|
|||
from airflow.providers.amazon.aws.hooks.emr import EmrHook
|
||||
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class EmrJobFlowSensor(EmrBaseSensor):
|
||||
"""
|
||||
Asks for the state of the JobFlow until it reaches a terminal state.
|
||||
If the job flow reaches a failed state, the sensor errors and the task fails.
|
||||
|
||||
:param job_flow_id: job_flow_id to check the state of
|
||||
:type job_flow_id: str
|
||||
"""
|
||||
|
||||
NON_TERMINAL_STATES = ['STARTING', 'BOOTSTRAPPING', 'RUNNING',
|
||||
'WAITING', 'TERMINATING']
|
||||
FAILED_STATE = ['TERMINATED_WITH_ERRORS']
|
||||
template_fields = ['job_flow_id']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_flow_id,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_flow_id = job_flow_id
|
||||
|
||||
def get_emr_response(self):
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
|
||||
|
||||
self.log.info('Poking cluster %s', self.job_flow_id)
|
||||
return emr.describe_cluster(ClusterId=self.job_flow_id)
|
||||
|
||||
@staticmethod
|
||||
def state_from_response(response):
|
||||
return response['Cluster']['Status']['State']
|
||||
|
||||
@staticmethod
|
||||
def failure_message_from_response(response):
|
||||
state_change_reason = response['Cluster']['Status'].get('StateChangeReason')
|
||||
if state_change_reason:
|
||||
return 'for code: {} with message {}'.format(state_change_reason.get('Code', 'No code'),
|
||||
state_change_reason.get('Message', 'Unknown'))
|
||||
return None
|
|
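Example usage (a minimal sketch, not part of this commit): job_flow_id is templated, so it is typically pulled from XCom pushed by an upstream create-job-flow task; the task id 'create_job_flow' and connection ID are placeholders.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor
from airflow.utils.dates import days_ago

with DAG('example_emr_job_flow_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    # Waits until the cluster created upstream reaches a terminal state.
    wait_for_cluster = EmrJobFlowSensor(
        task_id='wait_for_cluster',
        job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
        aws_conn_id='aws_default',
    )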
@ -0,0 +1,67 @@
|
|||
from airflow.providers.amazon.aws.hooks.emr import EmrHook
|
||||
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class EmrStepSensor(EmrBaseSensor):
|
||||
"""
|
||||
Asks for the state of the step until it reaches a terminal state.
|
||||
If the step reaches a failed state, the sensor errors and the task fails.
|
||||
|
||||
:param job_flow_id: job_flow_id which contains the step to check the state of
|
||||
:type job_flow_id: str
|
||||
:param step_id: step to check the state of
|
||||
:type step_id: str
|
||||
"""
|
||||
|
||||
NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE', 'CANCEL_PENDING']
|
||||
FAILED_STATE = ['CANCELLED', 'FAILED', 'INTERRUPTED']
|
||||
template_fields = ['job_flow_id', 'step_id']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_flow_id,
|
||||
step_id,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_flow_id = job_flow_id
|
||||
self.step_id = step_id
|
||||
|
||||
def get_emr_response(self):
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
|
||||
|
||||
self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id)
|
||||
return emr.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
|
||||
|
||||
@staticmethod
|
||||
def state_from_response(response):
|
||||
return response['Step']['Status']['State']
|
||||
|
||||
@staticmethod
|
||||
def failure_message_from_response(response):
|
||||
fail_details = response['Step']['Status'].get('FailureDetails')
|
||||
if fail_details:
|
||||
return 'for reason {} with message {} and log file {}'.format(fail_details.get('Reason'),
|
||||
fail_details.get('Message'),
|
||||
fail_details.get('LogFile'))
|
||||
return None
|
|
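Example usage (a minimal sketch, not part of this commit): both fields are templated and would normally come from upstream tasks via XCom; the literal cluster and step IDs below are placeholders.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
from airflow.utils.dates import days_ago

with DAG('example_emr_step_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    # Polls the step every 60 seconds until it completes, fails or is cancelled.
    watch_step = EmrStepSensor(
        task_id='watch_step',
        job_flow_id='j-EXAMPLE1234567',
        step_id='s-EXAMPLE1234567',
        aws_conn_id='aws_default',
        poke_interval=60,
    )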
@ -0,0 +1,93 @@
|
|||
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
|
||||
"""
|
||||
Waits for a partition to show up in AWS Glue Catalog.
|
||||
|
||||
:param table_name: The name of the table to wait for, supports the dot
|
||||
notation (my_database.my_table)
|
||||
:type table_name: str
|
||||
:param expression: The partition clause to wait for. This is passed as
|
||||
is to the AWS Glue Catalog API's get_partitions function,
|
||||
and supports SQL like notation as in ``ds='2015-01-01'
|
||||
AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
|
||||
See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
|
||||
#aws-glue-api-catalog-partitions-GetPartitions
|
||||
:type expression: str
|
||||
:param aws_conn_id: ID of the Airflow connection where
|
||||
credentials and extra configuration are stored
|
||||
:type aws_conn_id: str
|
||||
:param region_name: Optional aws region name (example: us-east-1). Uses region from connection
|
||||
if not specified.
|
||||
:type region_name: str
|
||||
:param database_name: The name of the catalog database where the partitions reside.
|
||||
:type database_name: str
|
||||
:param poke_interval: Time in seconds that the job should wait in
|
||||
between each try
|
||||
:type poke_interval: int
|
||||
"""
|
||||
template_fields = ('database_name', 'table_name', 'expression',)
|
||||
ui_color = '#C5CAE9'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
table_name, expression="ds='{{ ds }}'",
|
||||
aws_conn_id='aws_default',
|
||||
region_name=None,
|
||||
database_name='default',
|
||||
poke_interval=60 * 3,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
poke_interval=poke_interval, *args, **kwargs)
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.region_name = region_name
|
||||
self.table_name = table_name
|
||||
self.expression = expression
|
||||
self.database_name = database_name
|
||||
|
||||
def poke(self, context):
|
||||
"""
|
||||
Checks for existence of the partition in the AWS Glue Catalog table
|
||||
"""
|
||||
if '.' in self.table_name:
|
||||
self.database_name, self.table_name = self.table_name.split('.')
|
||||
self.log.info(
|
||||
'Poking for table %s. %s, expression %s', self.database_name, self.table_name, self.expression
|
||||
)
|
||||
|
||||
return self.get_hook().check_for_partition(
|
||||
self.database_name, self.table_name, self.expression)
|
||||
|
||||
def get_hook(self):
|
||||
"""
|
||||
Gets the AwsGlueCatalogHook
|
||||
"""
|
||||
if not hasattr(self, 'hook'):
|
||||
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
|
||||
self.hook = AwsGlueCatalogHook(
|
||||
aws_conn_id=self.aws_conn_id,
|
||||
region_name=self.region_name)
|
||||
|
||||
return self.hook
|
|
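Example usage (a minimal sketch, not part of this commit): the database and table are placeholders; the dot notation sets the database, and the default expression "ds='{{ ds }}'" from the constructor is used.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor
from airflow.utils.dates import days_ago

with DAG('example_glue_partition_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    # Succeeds once the ds partition shows up in the Glue Catalog table.
    wait_for_partition = AwsGlueCatalogPartitionSensor(
        task_id='wait_for_partition',
        table_name='example_db.example_table',
        aws_conn_id='aws_default',
        poke_interval=60 * 3,
    )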
@ -0,0 +1,97 @@
|
|||
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3KeySensor(BaseSensorOperator):
|
||||
"""
|
||||
Waits for a key (a file-like instance on S3) to be present in a S3 bucket.
|
||||
S3 being a key/value store, it does not support folders. The path is just a key
to a resource.
|
||||
|
||||
:param bucket_key: The key being waited on. Supports full s3:// style url
|
||||
or relative path from root level. When it's specified as a full s3://
|
||||
url, please leave bucket_name as `None`.
|
||||
:type bucket_key: str
|
||||
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
|
||||
is not provided as a full s3:// url.
|
||||
:type bucket_name: str
|
||||
:param wildcard_match: whether the bucket_key should be interpreted as a
|
||||
Unix wildcard pattern
|
||||
:type wildcard_match: bool
|
||||
:param aws_conn_id: a reference to the s3 connection
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
"""
|
||||
template_fields = ('bucket_key', 'bucket_name')
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
bucket_key,
|
||||
bucket_name=None,
|
||||
wildcard_match=False,
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Parse
|
||||
if bucket_name is None:
|
||||
parsed_url = urlparse(bucket_key)
|
||||
if parsed_url.netloc == '':
|
||||
raise AirflowException('Please provide a bucket_name')
|
||||
else:
|
||||
bucket_name = parsed_url.netloc
|
||||
bucket_key = parsed_url.path.lstrip('/')
|
||||
else:
|
||||
parsed_url = urlparse(bucket_key)
|
||||
if parsed_url.scheme != '' or parsed_url.netloc != '':
|
||||
raise AirflowException('If bucket_name is provided, bucket_key' +
|
||||
' should be relative path from root' +
|
||||
' level, rather than a full s3:// url')
|
||||
self.bucket_name = bucket_name
|
||||
self.bucket_key = bucket_key
|
||||
self.wildcard_match = wildcard_match
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def poke(self, context):
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
|
||||
if self.wildcard_match:
|
||||
return hook.check_for_wildcard_key(self.bucket_key,
|
||||
self.bucket_name)
|
||||
return hook.check_for_key(self.bucket_key, self.bucket_name)
|
|
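Example usage (a minimal sketch, not part of this commit): it shows both accepted key forms from the docstring; the bucket and key names are placeholders.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor
from airflow.utils.dates import days_ago

with DAG('example_s3_key_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    # Full s3:// URL form: bucket_name must be left as None.
    wait_for_file = S3KeySensor(
        task_id='wait_for_file',
        bucket_key='s3://example-bucket/data/{{ ds }}/report.csv',
        aws_conn_id='aws_default',
    )

    # Relative-key form with a Unix wildcard pattern.
    wait_for_any_part = S3KeySensor(
        task_id='wait_for_any_part',
        bucket_name='example-bucket',
        bucket_key='data/{{ ds }}/part-*',
        wildcard_match=True,
    )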
@ -0,0 +1,81 @@
|
|||
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3PrefixSensor(BaseSensorOperator):
|
||||
"""
|
||||
Waits for a prefix to exist. A prefix is the first part of a key,
|
||||
thus enabling checking of constructs similar to glob airfl* or
|
||||
SQL LIKE 'airfl%'. You can specify a delimiter to
indicate the hierarchy of keys, meaning that the match will stop at that
|
||||
delimiter. Current code accepts sane delimiters, i.e. characters that
|
||||
are NOT special characters in the Python regex engine.
|
||||
|
||||
:param bucket_name: Name of the S3 bucket
|
||||
:type bucket_name: str
|
||||
:param prefix: The prefix being waited on. Relative path from bucket root level.
|
||||
:type prefix: str
|
||||
:param delimiter: The delimiter intended to show hierarchy.
|
||||
Defaults to '/'.
|
||||
:type delimiter: str
|
||||
:param aws_conn_id: a reference to the s3 connection
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
"""
|
||||
template_fields = ('prefix', 'bucket_name')
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
bucket_name,
|
||||
prefix,
|
||||
delimiter='/',
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Parse
|
||||
self.bucket_name = bucket_name
|
||||
self.prefix = prefix
|
||||
self.delimiter = delimiter
|
||||
self.full_url = "s3://" + bucket_name + '/' + prefix
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def poke(self, context):
|
||||
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
return hook.check_for_prefix(
|
||||
prefix=self.prefix,
|
||||
delimiter=self.delimiter,
|
||||
bucket_name=self.bucket_name)
|
|
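Example usage (a minimal sketch, not part of this commit): the bucket and prefix are placeholders; the sensor succeeds once any key under the prefix exists.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor
from airflow.utils.dates import days_ago

with DAG('example_s3_prefix_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    wait_for_prefix = S3PrefixSensor(
        task_id='wait_for_prefix',
        bucket_name='example-bucket',
        prefix='exports/{{ ds }}/',
        delimiter='/',
        aws_conn_id='aws_default',
    )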
@ -0,0 +1,74 @@
|
|||
from airflow.exceptions import AirflowException
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerBaseSensor(BaseSensorOperator):
|
||||
"""
|
||||
Contains general sensor behavior for SageMaker.
|
||||
Subclasses should implement get_sagemaker_response()
|
||||
and state_from_response() methods.
|
||||
Subclasses should also implement the non_terminal_states() and failed_states() methods.
|
||||
"""
|
||||
ui_color = '#ededed'
|
||||
|
||||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
aws_conn_id='aws_default',
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aws_conn_id = aws_conn_id
|
||||
|
||||
def poke(self, context):
|
||||
response = self.get_sagemaker_response()
|
||||
|
||||
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
self.log.info('Bad HTTP response: %s', response)
|
||||
return False
|
||||
|
||||
state = self.state_from_response(response)
|
||||
|
||||
self.log.info('Job currently %s', state)
|
||||
|
||||
if state in self.non_terminal_states():
|
||||
return False
|
||||
|
||||
if state in self.failed_states():
|
||||
failed_reason = self.get_failed_reason_from_response(response)
|
||||
raise AirflowException('Sagemaker job failed for the following reason: %s'
|
||||
% failed_reason)
|
||||
return True
|
||||
|
||||
def non_terminal_states(self):
|
||||
raise NotImplementedError('Please implement non_terminal_states() in subclass')
|
||||
|
||||
def failed_states(self):
|
||||
raise NotImplementedError('Please implement failed_states() in subclass')
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return 'Unknown'
|
||||
|
||||
def state_from_response(self, response):
|
||||
raise NotImplementedError('Please implement state_from_response() in subclass')
|
|
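A hypothetical subclass sketch (not part of this commit) showing the contract the base class expects; it polls a notebook instance and calls boto3 directly for brevity, whereas the sensors in this commit go through SageMakerHook. The class name, states and describe call are illustrative only.

import boto3

from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor


class ExampleNotebookInstanceSensor(SageMakerBaseSensor):
    """Hypothetical subclass: waits for a notebook instance to leave its transient states."""

    def __init__(self, instance_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.instance_name = instance_name

    def non_terminal_states(self):
        return {'Pending', 'Updating', 'Stopping'}

    def failed_states(self):
        return {'Failed'}

    def get_sagemaker_response(self):
        # Any describe_* response carrying ResponseMetadata works with poke() above.
        return boto3.client('sagemaker').describe_notebook_instance(
            NotebookInstanceName=self.instance_name)

    def state_from_response(self, response):
        return response['NotebookInstanceStatus']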
@ -0,0 +1,61 @@
|
|||
|
||||
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerEndpointSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the endpoint until it reaches a terminal state.
If the endpoint creation fails, the sensor errors and the task fails.
|
||||
|
||||
:param endpoint_name: Name of the endpoint instance to check the state of
:type endpoint_name: str
|
||||
"""
|
||||
|
||||
template_fields = ['endpoint_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
endpoint_name,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.endpoint_name = endpoint_name
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.endpoint_non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
|
||||
return sagemaker.describe_endpoint(self.endpoint_name)
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['EndpointStatus']
|
|
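Example usage (a minimal sketch, not part of this commit): endpoint_name is templated; the literal name and connection ID are placeholders.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor
from airflow.utils.dates import days_ago

with DAG('example_sagemaker_endpoint_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    # Waits until the endpoint leaves its non-terminal states.
    wait_for_endpoint = SageMakerEndpointSensor(
        task_id='wait_for_endpoint',
        endpoint_name='example-endpoint',
        aws_conn_id='aws_default',
        poke_interval=60,
    )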
@ -0,0 +1,102 @@
|
|||
|
||||
import time
|
||||
|
||||
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerTrainingSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the training job until it reaches a terminal state.
If the training job fails, the sensor errors and the task fails.
|
||||
|
||||
:param job_name: name of the SageMaker training job to check the state of
|
||||
:type job_name: str
|
||||
:param print_log: whether the sensor should print the CloudWatch log
|
||||
:type print_log: bool
|
||||
"""
|
||||
|
||||
template_fields = ['job_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_name,
|
||||
print_log=True,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_name = job_name
|
||||
self.print_log = print_log
|
||||
self.positions = {}
|
||||
self.stream_names = []
|
||||
self.instance_count = None
|
||||
self.state = None
|
||||
self.last_description = None
|
||||
self.last_describe_job_call = None
|
||||
self.log_resource_inited = False
|
||||
|
||||
def init_log_resource(self, hook):
|
||||
description = hook.describe_training_job(self.job_name)
|
||||
self.instance_count = description['ResourceConfig']['InstanceCount']
|
||||
|
||||
status = description['TrainingJobStatus']
|
||||
job_already_completed = status not in self.non_terminal_states()
|
||||
self.state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
|
||||
self.last_description = description
|
||||
self.last_describe_job_call = time.time()
|
||||
self.log_resource_inited = True
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
if self.print_log:
|
||||
if not self.log_resource_inited:
|
||||
self.init_log_resource(sagemaker_hook)
|
||||
self.state, self.last_description, self.last_describe_job_call = \
|
||||
sagemaker_hook.describe_training_job_with_log(self.job_name,
|
||||
self.positions, self.stream_names,
|
||||
self.instance_count, self.state,
|
||||
self.last_description,
|
||||
self.last_describe_job_call)
|
||||
else:
|
||||
self.last_description = sagemaker_hook.describe_training_job(self.job_name)
|
||||
|
||||
status = self.state_from_response(self.last_description)
|
||||
if status not in self.non_terminal_states() and status not in self.failed_states():
|
||||
billable_time = \
|
||||
(self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \
|
||||
self.last_description['ResourceConfig']['InstanceCount']
|
||||
self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)
|
||||
|
||||
return self.last_description
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['TrainingJobStatus']
|
|
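Example usage (a minimal sketch, not part of this commit): the job name is a placeholder and would usually match the TrainingJobName submitted by an upstream SageMakerTrainingOperator running with wait_for_completion=False.

from airflow import DAG
from airflow.providers.amazon.aws.sensors.sagemaker_training import SageMakerTrainingSensor
from airflow.utils.dates import days_ago

with DAG('example_sagemaker_training_sensor', start_date=days_ago(1), schedule_interval=None) as dag:
    # Tails the CloudWatch log while polling the training job status.
    wait_for_training = SageMakerTrainingSensor(
        task_id='wait_for_training',
        job_name='example-training-{{ ds_nodash }}',
        print_log=True,
        aws_conn_id='aws_default',
    )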
@ -0,0 +1,62 @@
|
|||
|
||||
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerTransformSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the transform job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
|
||||
containing the failure reason.
|
||||
|
||||
:param job_name: job_name of the transform job instance to check the state of
|
||||
:type job_name: str
|
||||
"""
|
||||
|
||||
template_fields = ['job_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_name,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_name = job_name
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
|
||||
return sagemaker.describe_transform_job(self.job_name)
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['TransformJobStatus']
|
|
@ -0,0 +1,62 @@
|
|||
|
||||
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
|
||||
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SageMakerTuningSensor(SageMakerBaseSensor):
|
||||
"""
|
||||
Asks for the state of the tuning job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
|
||||
containing the failure reason.
|
||||
|
||||
:param job_name: job_name of the tuning instance to check the state of
|
||||
:type job_name: str
|
||||
"""
|
||||
|
||||
template_fields = ['job_name']
|
||||
template_ext = ()
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
job_name,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.job_name = job_name
|
||||
|
||||
def non_terminal_states(self):
|
||||
return SageMakerHook.non_terminal_states
|
||||
|
||||
def failed_states(self):
|
||||
return SageMakerHook.failed_states
|
||||
|
||||
def get_sagemaker_response(self):
|
||||
sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
|
||||
return sagemaker.describe_tuning_job(self.job_name)
|
||||
|
||||
def get_failed_reason_from_response(self, response):
|
||||
return response['FailureReason']
|
||||
|
||||
def state_from_response(self, response):
|
||||
return response['HyperParameterTuningJobStatus']
|
|
@ -16,82 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`."""
|
||||
|
||||
import warnings
|
||||
|
||||
from urllib.parse import urlparse
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor # noqa
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class S3KeySensor(BaseSensorOperator):
|
||||
"""
|
||||
Waits for a key (a file-like instance on S3) to be present in a S3 bucket.
|
||||
S3 being a key/value store, it does not support folders. The path is just a key
to a resource.
|
||||
|
||||
:param bucket_key: The key being waited on. Supports full s3:// style url
|
||||
or relative path from root level. When it's specified as a full s3://
|
||||
url, please leave bucket_name as `None`.
|
||||
:type bucket_key: str
|
||||
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
|
||||
is not provided as a full s3:// url.
|
||||
:type bucket_name: str
|
||||
:param wildcard_match: whether the bucket_key should be interpreted as a
|
||||
Unix wildcard pattern
|
||||
:type wildcard_match: bool
|
||||
:param aws_conn_id: a reference to the s3 connection
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
"""
|
||||
template_fields = ('bucket_key', 'bucket_name')
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
bucket_key,
|
||||
bucket_name=None,
|
||||
wildcard_match=False,
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Parse
|
||||
if bucket_name is None:
|
||||
parsed_url = urlparse(bucket_key)
|
||||
if parsed_url.netloc == '':
|
||||
raise AirflowException('Please provide a bucket_name')
|
||||
else:
|
||||
bucket_name = parsed_url.netloc
|
||||
bucket_key = parsed_url.path.lstrip('/')
|
||||
else:
|
||||
parsed_url = urlparse(bucket_key)
|
||||
if parsed_url.scheme != '' or parsed_url.netloc != '':
|
||||
raise AirflowException('If bucket_name is provided, bucket_key' +
|
||||
' should be relative path from root' +
|
||||
' level, rather than a full s3:// url')
|
||||
self.bucket_name = bucket_name
|
||||
self.bucket_key = bucket_key
|
||||
self.wildcard_match = wildcard_match
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def poke(self, context):
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
|
||||
if self.wildcard_match:
|
||||
return hook.check_for_wildcard_key(self.bucket_key,
|
||||
self.bucket_name)
|
||||
return hook.check_for_key(self.bucket_key, self.bucket_name)
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@ -16,66 +16,14 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`."""
|
||||
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor # noqa
|
||||
|
||||
class S3PrefixSensor(BaseSensorOperator):
|
||||
"""
|
||||
Waits for a prefix to exist. A prefix is the first part of a key,
|
||||
thus enabling checking of constructs similar to glob airfl* or
|
||||
SQL LIKE 'airfl%'. There is the possibility to precise a delimiter to
|
||||
indicate the hierarchy or keys, meaning that the match will stop at that
|
||||
delimiter. Current code accepts sane delimiters, i.e. characters that
|
||||
are NOT special characters in the Python regex engine.
|
||||
|
||||
:param bucket_name: Name of the S3 bucket
|
||||
:type bucket_name: str
|
||||
:param prefix: The prefix being waited on. Relative path from bucket root level.
|
||||
:type prefix: str
|
||||
:param delimiter: The delimiter intended to show hierarchy.
|
||||
Defaults to '/'.
|
||||
:type delimiter: str
|
||||
:param aws_conn_id: a reference to the s3 connection
|
||||
:type aws_conn_id: str
|
||||
:param verify: Whether or not to verify SSL certificates for S3 connection.
|
||||
By default SSL certificates are verified.
|
||||
You can provide the following values:
|
||||
|
||||
- ``False``: do not validate SSL certificates. SSL will still be used
|
||||
(unless use_ssl is False), but SSL certificates will not be
|
||||
verified.
|
||||
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
|
||||
You can specify this argument if you want to use a different
|
||||
CA cert bundle than the one used by botocore.
|
||||
:type verify: bool or str
|
||||
"""
|
||||
template_fields = ('prefix', 'bucket_name')
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self,
|
||||
bucket_name,
|
||||
prefix,
|
||||
delimiter='/',
|
||||
aws_conn_id='aws_default',
|
||||
verify=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Parse
|
||||
self.bucket_name = bucket_name
|
||||
self.prefix = prefix
|
||||
self.delimiter = delimiter
|
||||
self.full_url = "s3://" + bucket_name + '/' + prefix
|
||||
self.aws_conn_id = aws_conn_id
|
||||
self.verify = verify
|
||||
|
||||
def poke(self, context):
|
||||
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
|
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
||||
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
|
||||
return hook.check_for_prefix(
|
||||
prefix=self.prefix,
|
||||
delimiter=self.delimiter,
|
||||
bucket_name=self.bucket_name)
|
||||
warnings.warn(
|
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
|
|
@ -277,7 +277,7 @@ Airflow provides operators for many common tasks, including:
|
|||
|
||||
In addition to these basic building blocks, there are many more specific
|
||||
operators: :class:`~airflow.operators.docker_operator.DockerOperator`,
|
||||
:class:`~airflow.providers.apache.hive.operators.hive.HiveOperator`, :class:`~airflow.operators.s3_file_transform_operator.S3FileTransformOperator`,
|
||||
:class:`~airflow.providers.apache.hive.operators.hive.HiveOperator`, :class:`~airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator`,
|
||||
:class:`~airflow.operators.presto_to_mysql.PrestoToMySqlTransfer`,
|
||||
:class:`~airflow.operators.slack_operator.SlackAPIOperator`... you get the idea!
|
||||
|
||||
|
|
|
@ -315,9 +315,9 @@ These integrations allow you to perform various operations within the Amazon Web
|
|||
|
||||
* - `AWS Glue Catalog <https://aws.amazon.com/glue/>`__
|
||||
-
|
||||
- :mod:`airflow.contrib.hooks.aws_glue_catalog_hook`
|
||||
- :mod:`airflow.providers.amazon.aws.hooks.glue_catalog`
|
||||
-
|
||||
- :mod:`airflow.contrib.sensors.aws_glue_catalog_partition_sensor`
|
||||
- :mod:`airflow.providers.amazon.aws.sensors.glue_catalog_partition`
|
||||
|
||||
* - `AWS Lambda <https://aws.amazon.com/lambda/>`__
|
||||
-
|
||||
|
@ -333,7 +333,7 @@ These integrations allow you to perform various operations within the Amazon Web
|
|||
|
||||
* - `Amazon CloudWatch Logs <https://aws.amazon.com/cloudwatch/>`__
|
||||
-
|
||||
- :mod:`airflow.contrib.hooks.aws_logs_hook`
|
||||
- :mod:`airflow.providers.amazon.aws.hooks.logs`
|
||||
-
|
||||
-
|
||||
|
||||
|
@ -346,18 +346,18 @@ These integrations allow you to perform various operations within the Amazon Web
|
|||
* - `Amazon EC2 <https://aws.amazon.com/ec2/>`__
|
||||
-
|
||||
-
|
||||
- :mod:`airflow.contrib.operators.ecs_operator`
|
||||
- :mod:`airflow.providers.amazon.aws.operators.ecs`
|
||||
-
|
||||
|
||||
* - `Amazon EMR <https://aws.amazon.com/emr/>`__
|
||||
-
|
||||
- :mod:`airflow.contrib.hooks.emr_hook`
|
||||
- :mod:`airflow.contrib.operators.emr_add_steps_operator`,
|
||||
:mod:`airflow.contrib.operators.emr_create_job_flow_operator`,
|
||||
:mod:`airflow.contrib.operators.emr_terminate_job_flow_operator`
|
||||
- :mod:`airflow.contrib.sensors.emr_base_sensor`,
|
||||
:mod:`airflow.contrib.sensors.emr_job_flow_sensor`,
|
||||
:mod:`airflow.contrib.sensors.emr_step_sensor`
|
||||
- :mod:`airflow.providers.amazon.aws.hooks.emr`
|
||||
- :mod:`airflow.providers.amazon.aws.operators.emr_add_steps`,
|
||||
:mod:`airflow.providers.amazon.aws.operators.emr_create_job_flow`,
|
||||
:mod:`airflow.providers.amazon.aws.operators.emr_terminate_job_flow`
|
||||
- :mod:`airflow.providers.amazon.aws.sensors.emr_base`,
|
||||
:mod:`airflow.providers.amazon.aws.sensors.emr_job_flow`,
|
||||
:mod:`airflow.providers.amazon.aws.sensors.emr_step`
|
||||
|
||||
* - `Amazon Kinesis Data Firehose <https://aws.amazon.com/kinesis/data-firehose/>`__
|
||||
-
|
||||
|
@ -373,19 +373,19 @@ These integrations allow you to perform various operations within the Amazon Web
* - `Amazon SageMaker <https://aws.amazon.com/sagemaker/>`__
-
- :mod:`airflow.contrib.hooks.sagemaker_hook`
- :mod:`airflow.contrib.operators.sagemaker_base_operator`,
:mod:`airflow.contrib.operators.sagemaker_endpoint_config_operator`,
:mod:`airflow.contrib.operators.sagemaker_endpoint_operator`,
:mod:`airflow.contrib.operators.sagemaker_model_operator`,
:mod:`airflow.contrib.operators.sagemaker_training_operator`,
:mod:`airflow.contrib.operators.sagemaker_transform_operator`,
:mod:`airflow.contrib.operators.sagemaker_tuning_operator`
- :mod:`airflow.contrib.sensors.sagemaker_base_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_endpoint_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_training_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_transform_sensor`,
:mod:`airflow.contrib.sensors.sagemaker_tuning_sensor`
- :mod:`airflow.providers.amazon.aws.hooks.sagemaker`
- :mod:`airflow.providers.amazon.aws.operators.sagemaker_base`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_endpoint`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_model`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_training`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_transform`,
:mod:`airflow.providers.amazon.aws.operators.sagemaker_tuning`
- :mod:`airflow.providers.amazon.aws.sensors.sagemaker_base`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_endpoint`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_training`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_transform`,
:mod:`airflow.providers.amazon.aws.sensors.sagemaker_tuning`
* - `Amazon Simple Notification Service (SNS) <https://aws.amazon.com/sns/>`__
-
@ -402,12 +402,12 @@ These integrations allow you to perform various operations within the Amazon Web
* - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`__
-
- :mod:`airflow.providers.amazon.aws.hooks.s3`
- :mod:`airflow.operators.s3_file_transform_operator`,
:mod:`airflow.contrib.operators.s3_copy_object_operator`,
:mod:`airflow.contrib.operators.s3_delete_objects_operator`,
:mod:`airflow.contrib.operators.s3_list_operator`
- :mod:`airflow.sensors.s3_key_sensor`,
:mod:`airflow.sensors.s3_prefix_sensor`
- :mod:`airflow.providers.amazon.aws.operators.s3_file_transform`,
:mod:`airflow.providers.amazon.aws.operators.s3_copy_object`,
:mod:`airflow.providers.amazon.aws.operators.s3_delete_objects`,
:mod:`airflow.providers.amazon.aws.operators.s3_list`
- :mod:`airflow.providers.amazon.aws.sensors.s3_key`,
:mod:`airflow.providers.amazon.aws.sensors.s3_prefix`
Transfer operators and hooks
''''''''''''''''''''''''''''
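Note for readers updating their own DAGs against the integration table above: only the import paths change. A minimal sketch of a DAG using the relocated EMR modules (the provider module paths come from this commit; the DAG id, task ids, and placeholder ids are illustrative assumptions, not part of the change):

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_emr_providers_imports",  # illustrative name
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    add_steps = EmrAddStepsOperator(
        task_id="add_steps",
        job_flow_id="j-XXXXXXXXXXXXX",  # placeholder cluster id
        steps=[],  # steps omitted for brevity
        aws_conn_id="aws_default",
    )
    watch_step = EmrStepSensor(
        task_id="watch_step",
        job_flow_id="j-XXXXXXXXXXXXX",  # placeholder cluster id
        step_id="s-XXXXXXXXXXXXX",  # placeholder step id
        aws_conn_id="aws_default",
    )
    add_steps >> watch_step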
@ -11,7 +11,7 @@
./airflow/contrib/hooks/datadog_hook.py
./airflow/contrib/hooks/dingding_hook.py
./airflow/contrib/hooks/discord_webhook_hook.py
./airflow/contrib/hooks/emr_hook.py
./airflow/providers/amazon/aws/hooks/emr.py
./airflow/hooks/filesystem.py
./airflow/contrib/hooks/ftp_hook.py
./airflow/contrib/hooks/jenkins_hook.py
@ -19,7 +19,7 @@
./airflow/contrib/hooks/opsgenie_alert_hook.py
./airflow/providers/apache/pinot/hooks/pinot.py
./airflow/contrib/hooks/qubole_check_hook.py
./airflow/contrib/hooks/sagemaker_hook.py
./airflow/providers/amazon/aws/hooks/sagemaker.py
./airflow/contrib/hooks/segment_hook.py
./airflow/contrib/hooks/slack_webhook_hook.py
./airflow/contrib/hooks/snowflake_hook.py
@ -36,10 +36,10 @@
./airflow/contrib/operators/dingding_operator.py
./airflow/contrib/operators/discord_webhook_operator.py
./airflow/providers/apache/druid/operators/druid.py
./airflow/contrib/operators/ecs_operator.py
./airflow/contrib/operators/emr_add_steps_operator.py
./airflow/contrib/operators/emr_create_job_flow_operator.py
./airflow/contrib/operators/emr_terminate_job_flow_operator.py
./airflow/providers/amazon/aws/operators/ecs.py
./airflow/providers/amazon/aws/operators/emr_add_steps.py
./airflow/providers/amazon/aws/operators/emr_create_job_flow.py
./airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
./airflow/contrib/operators/file_to_wasb.py
./airflow/contrib/operators/grpc_operator.py
./airflow/contrib/operators/jenkins_job_trigger_operator.py
@ -50,18 +50,18 @@
./airflow/contrib/operators/oracle_to_oracle_transfer.py
./airflow/contrib/operators/qubole_check_operator.py
./airflow/contrib/operators/redis_publish_operator.py
./airflow/contrib/operators/s3_copy_object_operator.py
./airflow/contrib/operators/s3_delete_objects_operator.py
./airflow/contrib/operators/s3_list_operator.py
./airflow/providers/amazon/aws/operators/s3_copy_object.py
./airflow/providers/amazon/aws/operators/s3_delete_objects.py
./airflow/providers/amazon/aws/operators/s3_list.py
./airflow/contrib/operators/s3_to_gcs_operator.py
./airflow/contrib/operators/s3_to_sftp_operator.py
./airflow/contrib/operators/sagemaker_base_operator.py
./airflow/contrib/operators/sagemaker_endpoint_config_operator.py
./airflow/contrib/operators/sagemaker_endpoint_operator.py
./airflow/contrib/operators/sagemaker_model_operator.py
./airflow/contrib/operators/sagemaker_training_operator.py
./airflow/contrib/operators/sagemaker_transform_operator.py
./airflow/contrib/operators/sagemaker_tuning_operator.py
./airflow/providers/amazon/aws/operators/sagemaker_base.py
./airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
./airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
./airflow/providers/amazon/aws/operators/sagemaker_model.py
./airflow/providers/amazon/aws/operators/sagemaker_training.py
./airflow/providers/amazon/aws/operators/sagemaker_transform.py
./airflow/providers/amazon/aws/operators/sagemaker_tuning.py
./airflow/contrib/operators/segment_track_event_operator.py
./airflow/contrib/operators/sftp_to_s3_operator.py
./airflow/contrib/operators/slack_webhook_operator.py
@ -75,14 +75,14 @@
./airflow/contrib/operators/vertica_to_mysql.py
./airflow/providers/microsoft/azure/operators/wasb_delete_blob.py
./airflow/contrib/operators/winrm_operator.py
./airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
./airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
./airflow/providers/microsoft/azure/sensors/azure_cosmos.py
./airflow/contrib/sensors/bash_sensor.py
./airflow/contrib/sensors/celery_queue_sensor.py
./airflow/contrib/sensors/datadog_sensor.py
./airflow/contrib/sensors/emr_base_sensor.py
./airflow/contrib/sensors/emr_job_flow_sensor.py
./airflow/contrib/sensors/emr_step_sensor.py
./airflow/providers/amazon/aws/sensors/emr_base.py
./airflow/providers/amazon/aws/sensors/emr_job_flow.py
./airflow/providers/amazon/aws/sensors/emr_step.py
./airflow/sensors/filesystem.py
./airflow/contrib/sensors/ftp_sensor.py
./airflow/providers/apache/hdfs/sensors/hdfs.py
@ -92,11 +92,11 @@
./airflow/contrib/sensors/qubole_sensor.py
./airflow/contrib/sensors/redis_key_sensor.py
./airflow/contrib/sensors/redis_pub_sub_sensor.py
./airflow/contrib/sensors/sagemaker_base_sensor.py
./airflow/contrib/sensors/sagemaker_endpoint_sensor.py
./airflow/contrib/sensors/sagemaker_training_sensor.py
./airflow/contrib/sensors/sagemaker_transform_sensor.py
./airflow/contrib/sensors/sagemaker_tuning_sensor.py
./airflow/providers/amazon/aws/sensors/sagemaker_base.py
./airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
./airflow/providers/amazon/aws/sensors/sagemaker_training.py
./airflow/providers/amazon/aws/sensors/sagemaker_transform.py
./airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
./airflow/providers/microsoft/azure/sensors/wasb.py
./airflow/sensors/weekday_sensor.py
./airflow/hooks/dbapi_hook.py
@ -160,7 +160,7 @@
./airflow/operators/presto_check_operator.py
./airflow/operators/presto_to_mysql.py
./airflow/operators/python_operator.py
./airflow/operators/s3_file_transform_operator.py
./airflow/providers/amazon/aws/operators/s3_file_transform.py
./airflow/operators/s3_to_redshift_operator.py
./airflow/operators/slack_operator.py
./airflow/operators/sqlite_operator.py
@ -175,8 +175,8 @@
./airflow/sensors/http_sensor.py
./airflow/providers/apache/hive/sensors/metastore_partition.py
./airflow/providers/apache/hive/sensors/named_hive_partition.py
./airflow/sensors/s3_key_sensor.py
./airflow/sensors/s3_prefix_sensor.py
./airflow/providers/amazon/aws/sensors/s3_key.py
./airflow/providers/amazon/aws/sensors/s3_prefix.py
./airflow/sensors/sql_sensor.py
./airflow/sensors/time_delta_sensor.py
./airflow/sensors/time_sensor.py
@ -53,7 +53,7 @@ class TestS3ToGoogleCloudStorageOperator(unittest.TestCase):
self.assertEqual(operator.dest_gcs, GCS_PATH_PREFIX)
@mock.patch('airflow.contrib.operators.s3_to_gcs_operator.S3Hook')
@mock.patch('airflow.contrib.operators.s3_list_operator.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
@mock.patch(
'airflow.contrib.operators.s3_to_gcs_operator.GCSHook')
def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
@ -88,7 +88,7 @@ class TestS3ToGoogleCloudStorageOperator(unittest.TestCase):
self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
@mock.patch('airflow.contrib.operators.s3_to_gcs_operator.S3Hook')
@mock.patch('airflow.contrib.operators.s3_list_operator.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
@mock.patch(
'airflow.contrib.operators.s3_to_gcs_operator.GCSHook')
def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
@ -22,7 +22,7 @@ import unittest
import boto3
from airflow.contrib.hooks.emr_hook import EmrHook
from airflow.providers.amazon.aws.hooks.emr import EmrHook
try:
from moto import mock_emr
@ -21,7 +21,7 @@ import unittest
import boto3
import mock
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
try:
from moto import mock_glue
@ -20,7 +20,7 @@
import unittest
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
try:
from moto import mock_logs
@ -25,12 +25,12 @@ from datetime import datetime
import mock
from tzlocal import get_localzone
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.contrib.hooks.sagemaker_hook import (
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState, SageMakerHook, secondary_training_status_changed, secondary_training_status_message,
)
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
role = 'arn:aws:iam:role/test-role'
@ -25,8 +25,8 @@ from copy import deepcopy
import mock
from parameterized import parameterized
from airflow.contrib.operators.ecs_operator import ECSOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
RESPONSE_WITHOUT_FAILURES = {
"failures": [],
@ -52,7 +52,7 @@ RESPONSE_WITHOUT_FAILURES = {
class TestECSOperator(unittest.TestCase):
@mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
@mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook')
def setUp(self, aws_hook_mock):
self.aws_hook_mock = aws_hook_mock
self.ecs_operator_args = {
@ -96,7 +96,7 @@ class TestECSOperator(unittest.TestCase):
])
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
@mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook')
def test_execute_without_failures(self, launch_type, tags, aws_hook_mock,
check_mock, wait_mock):
client_mock = aws_hook_mock.return_value.get_client_type.return_value
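A detail visible in the test diffs above: @mock.patch targets are plain strings, so they must name the module where the patched object is looked up, which after this move means the providers path rather than airflow.contrib. A minimal sketch of the pattern, assuming an Airflow install at this commit where the ecs module references AwsHook at module level (the test and class names here are illustrative, not the project's):

import unittest
from unittest import mock


class ExamplePatchTarget(unittest.TestCase):
    # Patch the name where ECSOperator resolves it after the move,
    # i.e. the providers module, not the removed contrib path.
    @mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook')
    def test_hook_is_patched(self, aws_hook_mock):
        from airflow.providers.amazon.aws.operators import ecs
        # The attribute on the module is replaced by the mock for the test's duration.
        self.assertIs(ecs.AwsHook, aws_hook_mock)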
@ -22,9 +22,9 @@ from datetime import timedelta
from unittest.mock import MagicMock, patch
from airflow import DAG
from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
@ -111,7 +111,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
with patch('boto3.session.Session', self.boto3_session_mock):
with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
with patch('airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = expected_job_flow_id
@ -132,7 +132,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
def test_init_with_nonexistent_cluster_name(self):
cluster_name = 'test_cluster'
with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
with patch('airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = None
@ -23,8 +23,8 @@ from datetime import timedelta
from unittest.mock import MagicMock, patch
from airflow import DAG
from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
@ -20,7 +20,7 @@
import unittest
from unittest.mock import MagicMock, patch
from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator
TERMINATE_SUCCESS_RETURN = {
'ResponseMetadata': {
@ -23,7 +23,7 @@ import unittest
import boto3
from moto import mock_s3
from airflow.contrib.operators.s3_copy_object_operator import S3CopyObjectOperator
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator
class TestS3CopyObjectOperator(unittest.TestCase):
@ -23,7 +23,7 @@ import unittest
import boto3
from moto import mock_s3
from airflow.contrib.operators.s3_delete_objects_operator import S3DeleteObjectsOperator
from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator
class TestS3DeleteObjectsOperator(unittest.TestCase):
@ -31,7 +31,7 @@ import boto3
from moto import mock_s3
from airflow.exceptions import AirflowException
from airflow.operators.s3_file_transform_operator import S3FileTransformOperator
from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator
class TestS3FileTransformOperator(unittest.TestCase):
@ -21,7 +21,7 @@ import unittest
import mock
from airflow.contrib.operators.s3_list_operator import S3ListOperator
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator
TASK_ID = 'test-s3-list-operator'
BUCKET = 'test-bucket'
@ -31,7 +31,7 @@ MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
class TestS3ListOperator(unittest.TestCase):
@mock.patch('airflow.contrib.operators.s3_list_operator.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
def test_execute(self, mock_hook):
mock_hook.return_value.list_keys.return_value = MOCK_FILES
@ -19,7 +19,7 @@
import unittest
from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
config = {
'key1': '1',
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_endpoint_operator import SageMakerEndpointOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator
role = 'arn:aws:iam:role/test-role'
bucket = 'test-bucket'
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_endpoint_config_operator import SageMakerEndpointConfigOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import SageMakerEndpointConfigOperator
model_name = 'test-model-name'
config_name = 'test-config-name'
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_model_operator import SageMakerModelOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator
role = 'arn:aws:iam:role/test-role'
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator
role = 'arn:aws:iam:role/test-role'
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator
role = 'arn:aws:iam:role/test-role'
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_tuning_operator import SageMakerTuningOperator
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator
role = 'arn:aws:iam:role/test-role'
@ -19,8 +19,8 @@
import unittest
from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
class TestEmrBaseSensor(unittest.TestCase):
@ -24,7 +24,7 @@ from unittest.mock import MagicMock, patch
from dateutil.tz import tzlocal
from airflow import AirflowException
from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor
DESCRIBE_CLUSTER_RUNNING_RETURN = {
'Cluster': {
@ -24,7 +24,7 @@ from unittest.mock import MagicMock, patch
from dateutil.tz import tzlocal
from airflow import AirflowException
from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
DESCRIBE_JOB_STEP_RUNNING_RETURN = {
'ResponseMetadata': {
@ -21,8 +21,8 @@ import unittest
import mock
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import AwsGlueCatalogPartitionSensor
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor
try:
from moto import mock_glue
@ -23,7 +23,7 @@ from unittest import mock
from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.sensors.s3_key_sensor import S3KeySensor
from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor
class TestS3KeySensor(unittest.TestCase):
@ -20,7 +20,7 @@
import unittest
from unittest import mock
from airflow.sensors.s3_prefix_sensor import S3PrefixSensor
from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor
class TestS3PrefixSensor(unittest.TestCase):
@ -19,8 +19,8 @@
import unittest
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
class TestSagemakerBaseSensor(unittest.TestCase):
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_endpoint_sensor import SageMakerEndpointSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor
DESCRIBE_ENDPOINT_CREATING_RESPONSE = {
'EndpointStatus': 'Creating',
@ -22,10 +22,10 @@ from datetime import datetime
import mock
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.contrib.hooks.sagemaker_hook import LogState, SageMakerHook
from airflow.contrib.sensors.sagemaker_training_sensor import SageMakerTrainingSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_training import SageMakerTrainingSensor
DESCRIBE_TRAINING_COMPELETED_RESPONSE = {
'TrainingJobStatus': 'Completed',
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_transform_sensor import SageMakerTransformSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor
DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE = {
'TransformJobStatus': 'InProgress',
@ -21,9 +21,9 @@ import unittest
import mock
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_tuning_sensor import SageMakerTuningSensor
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor
DESCRIBE_TUNING_INPROGRESS_RESPONSE = {
'HyperParameterTuningJobStatus': 'InProgress',
@ -240,7 +240,22 @@ HOOK = [
'airflow.providers.microsoft.azure.hooks.wasb.WasbHook',
'airflow.contrib.hooks.wasb_hook.WasbHook',
),
(
'airflow.providers.amazon.aws.hooks.glue_catalog.AwsGlueCatalogHook',
'airflow.contrib.hooks.aws_glue_catalog_hook.AwsGlueCatalogHook',
),
(
'airflow.providers.amazon.aws.hooks.logs.AwsLogsHook',
'airflow.contrib.hooks.aws_logs_hook.AwsLogsHook',
),
(
'airflow.providers.amazon.aws.hooks.emr.EmrHook',
'airflow.contrib.hooks.emr_hook.EmrHook',
),
(
'airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook',
'airflow.contrib.hooks.sagemaker_hook.SageMakerHook',
),
]
OPERATOR = [
@ -911,6 +926,62 @@ OPERATOR = [
'airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbDeleteBlobOperator',
'airflow.contrib.operators.wasb_delete_blob_operator.WasbDeleteBlobOperator',
),
(
'airflow.providers.amazon.aws.operators.ecs.ECSOperator',
'airflow.contrib.operators.ecs_operator.ECSOperator',
),
(
'airflow.providers.amazon.aws.operators.emr_add_steps.EmrAddStepsOperator',
'airflow.contrib.operators.emr_add_steps_operator.EmrAddStepsOperator',
),
(
'airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrCreateJobFlowOperator',
'airflow.contrib.operators.emr_create_job_flow_operator.EmrCreateJobFlowOperator',
),
(
'airflow.providers.amazon.aws.operators.emr_terminate_job_flow.EmrTerminateJobFlowOperator',
'airflow.contrib.operators.emr_terminate_job_flow_operator.EmrTerminateJobFlowOperator',
),
(
'airflow.providers.amazon.aws.operators.s3_copy_object.S3CopyObjectOperator',
'airflow.contrib.operators.s3_copy_object_operator.S3CopyObjectOperator',
),
(
'airflow.providers.amazon.aws.operators.s3_delete_objects.S3DeleteObjectsOperator',
'airflow.contrib.operators.s3_delete_objects_operator.S3DeleteObjectsOperator',
),
(
'airflow.providers.amazon.aws.operators.s3_list.S3ListOperator',
'airflow.contrib.operators.s3_list_operator.S3ListOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_base.SageMakerBaseOperator',
'airflow.contrib.operators.sagemaker_base_operator.SageMakerBaseOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_endpoint_config.SageMakerEndpointConfigOperator',
'airflow.contrib.operators.sagemaker_endpoint_config_operator.SageMakerEndpointConfigOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_endpoint.SageMakerEndpointOperator',
'airflow.contrib.operators.sagemaker_endpoint_operator.SageMakerEndpointOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_model.SageMakerModelOperator',
'airflow.contrib.operators.sagemaker_model_operator.SageMakerModelOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_training.SageMakerTrainingOperator',
'airflow.contrib.operators.sagemaker_training_operator.SageMakerTrainingOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_transform.SageMakerTransformOperator',
'airflow.contrib.operators.sagemaker_transform_operator.SageMakerTransformOperator',
),
(
'airflow.providers.amazon.aws.operators.sagemaker_tuning.SageMakerTuningOperator',
'airflow.contrib.operators.sagemaker_tuning_operator.SageMakerTuningOperator',
),
]
SENSOR = [
@ -1001,6 +1072,50 @@ SENSOR = [
'airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor',
'airflow.contrib.sensors.wasb_sensor.WasbPrefixSensor',
),
(
'airflow.providers.amazon.aws.sensors.glue_catalog_partition.AwsGlueCatalogPartitionSensor',
'airflow.contrib.sensors.aws_glue_catalog_partition_sensor.AwsGlueCatalogPartitionSensor',
),
(
'airflow.providers.amazon.aws.sensors.emr_base.EmrBaseSensor',
'airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor',
),
(
'airflow.providers.amazon.aws.sensors.emr_job_flow.EmrJobFlowSensor',
'airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor',
),
(
'airflow.providers.amazon.aws.sensors.emr_step.EmrStepSensor',
'airflow.contrib.sensors.emr_step_sensor.EmrStepSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_base.SageMakerBaseSensor',
'airflow.contrib.sensors.sagemaker_base_sensor.SageMakerBaseSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_endpoint.SageMakerEndpointSensor',
'airflow.contrib.sensors.sagemaker_endpoint_sensor.SageMakerEndpointSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_transform.SageMakerTransformSensor',
'airflow.contrib.sensors.sagemaker_transform_sensor.SageMakerTransformSensor',
),
(
'airflow.providers.amazon.aws.sensors.sagemaker_tuning.SageMakerTuningSensor',
'airflow.contrib.sensors.sagemaker_tuning_sensor.SageMakerTuningSensor',
),
(
'airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator',
'airflow.operators.s3_file_transform_operator.S3FileTransformOperator',
),
(
'airflow.providers.amazon.aws.sensors.s3_key.S3KeySensor',
'airflow.sensors.s3_key_sensor.S3KeySensor',
),
(
'airflow.providers.amazon.aws.sensors.s3_prefix.S3PrefixSensor',
'airflow.sensors.s3_prefix_sensor.S3PrefixSensor',
),
]
PROTOCOLS = [
@ -1008,6 +1123,10 @@ PROTOCOLS = [
"airflow.providers.amazon.aws.hooks.batch_client.AwsBatchProtocol",
"airflow.contrib.operators.awsbatch_operator.BatchProtocol",
),
(
'airflow.providers.amazon.aws.operators.ecs.ECSProtocol',
'airflow.contrib.operators.ecs_operator.ECSProtocol',
),
]
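The HOOK, OPERATOR, SENSOR, and PROTOCOLS entries above are (new path, old path) pairs. A minimal sketch, not the project's actual test, of how one such pair could be checked by hand, assuming the deprecated contrib module still re-exports or subclasses the relocated class:

import importlib


def resolve(path):
    # Split 'package.module.Name' into a module import plus an attribute lookup.
    module_name, _, attr = path.rpartition('.')
    return getattr(importlib.import_module(module_name), attr)


def check_alias(new_path, old_path):
    new_cls = resolve(new_path)
    old_cls = resolve(old_path)
    # The deprecated location should expose the same class or a thin subclass of it.
    assert old_cls is new_cls or (isinstance(old_cls, type) and issubclass(old_cls, new_cls))


check_alias(
    'airflow.providers.amazon.aws.hooks.emr.EmrHook',
    'airflow.contrib.hooks.emr_hook.EmrHook',
)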