Check for same task instead of Equality to detect Duplicate Tasks (#8828)

This commit is contained in:
Kaxil Naik 2020-05-16 11:21:12 +01:00 committed by GitHub
Parent f4edd90a94
Commit 15273f0ea0
No key found matching this signature
GPG Key ID: 4AEE18F83AFDEB23
17 changed files: 78 additions and 82 deletions

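The duplicate-task check now compares identity (`is not`) instead of equality (`!=`). `BaseOperator` defines attribute-based equality, so two distinct operators with the same task_id could compare equal and slip past the old check, while re-assigning the very same task object to its DAG is the only case that should be tolerated. The `BaseOperator.dag` setter now delegates to `DAG.add_task`, which raises `DuplicateTaskIdFound` only when a different task object reuses a registered task_id. The test and example-DAG changes below de-duplicate task_ids that the stricter check would now reject.

A minimal sketch of the new behaviour (assuming Airflow at this commit; import paths may differ across versions):

    from airflow import DAG
    from airflow.exceptions import DuplicateTaskIdFound
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils import timezone

    with DAG("demo", start_date=timezone.datetime(2020, 1, 1)) as dag:
        op1 = DummyOperator(task_id="t1")
        op1.dag = dag  # re-assigning the same object: no error

        try:
            DummyOperator(task_id="t1")  # distinct object, same task_id
        except DuplicateTaskIdFound:
            print("duplicate task_id rejected")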
View File

@@ -81,7 +81,7 @@ with models.DAG(
     )
     create_tag_template_field_result2 = BashOperator(
-        task_id="create_tag_template_field_result", bash_command="echo create_tag_template_field_result"
+        task_id="create_tag_template_field_result2", bash_command="echo create_tag_template_field_result"
     )
     # Delete

View File

@@ -36,7 +36,7 @@ from dateutil.relativedelta import relativedelta
 from sqlalchemy.orm import Session

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+from airflow.exceptions import AirflowException
 from airflow.lineage import apply_lineage, prepare_lineage
 from airflow.models.base import Operator
 from airflow.models.pool import Pool
@@ -600,9 +600,8 @@ class BaseOperator(Operator, LoggingMixin):
                 "The DAG assigned to {} can not be changed.".format(self))
         elif self.task_id not in dag.task_dict:
             dag.add_task(self)
-        elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] != self:
-            raise DuplicateTaskIdFound(
-                "Task id '{}' has already been added to the DAG".format(self.task_id))
+        elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
+            dag.add_task(self)

         self._dag = dag  # pylint: disable=attribute-defined-outside-init

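Why `is not` rather than `!=` matters in the check above: `BaseOperator.__eq__` compares operators attribute by attribute, so two distinct tasks can be equal without being the same object. A toy illustration with a hypothetical `Task` class (not Airflow code):

    class Task:
        # Hypothetical stand-in for an operator with attribute-based equality.
        def __init__(self, task_id):
            self.task_id = task_id

        def __eq__(self, other):
            return isinstance(other, Task) and self.task_id == other.task_id

    a, b = Task("t1"), Task("t1")
    assert a == b       # equal by attributes: the old `!=` check saw no duplicate
    assert a is not b   # distinct objects: the identity check flags the duplicate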
View File

@@ -1337,7 +1337,7 @@ class DAG(BaseDag, LoggingMixin):
         elif task.end_date and self.end_date:
             task.end_date = min(task.end_date, self.end_date)

-        if task.task_id in self.task_dict and self.task_dict[task.task_id] != task:
+        if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
             raise DuplicateTaskIdFound(
                 "Task id '{}' has already been added to the DAG".format(task.task_id))
         else:

View File

@@ -181,7 +181,7 @@ with models.DAG("example_gcp_datacatalog", default_args=default_args, schedule_i
     # [START howto_operator_gcp_datacatalog_create_tag_template_field_result2]
     create_tag_template_field_result2 = BashOperator(
-        task_id="create_tag_template_field_result",
+        task_id="create_tag_template_field_result2",
        bash_command="echo \"{{ task_instance.xcom_pull('create_tag_template_field') }}\"",
     )
     # [END howto_operator_gcp_datacatalog_create_tag_template_field_result2]

View File

@@ -100,7 +100,7 @@ with models.DAG(
     # [START howto_operator_gcs_to_gcs_delimiter]
     copy_files_with_delimiter = GCSToGCSOperator(
-        task_id="copy_files_with_wildcard",
+        task_id="copy_files_with_delimiter",
         source_bucket=BUCKET_1_SRC,
         source_object="data/",
         destination_bucket=BUCKET_1_DST,

View File

@@ -979,14 +979,6 @@ class TestDag(unittest.TestCase):
         self.assertEqual(dag.task_dict, {op1.task_id: op1})

-        # Also verify that DAGs with duplicate task_ids don't raise errors
-        with DAG("test_dag_1", start_date=DEFAULT_DATE) as dag1:
-            op3 = DummyOperator(task_id="t3")
-            op4 = BashOperator(task_id="t4", bash_command="sleep 1")
-            op3 >> op4
-
-        self.assertEqual(dag1.task_dict, {op3.task_id: op3, op4.task_id: op4})
-
     def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
         """Verify tasks with Duplicate task_id raises error"""
         with self.assertRaisesRegex(
@@ -994,19 +986,11 @@
         ):
             dag = DAG("test_dag", start_date=DEFAULT_DATE)
             op1 = DummyOperator(task_id="t1", dag=dag)
-            op2 = BashOperator(task_id="t1", bash_command="sleep 1", dag=dag)
+            op2 = DummyOperator(task_id="t1", dag=dag)
             op1 >> op2

         self.assertEqual(dag.task_dict, {op1.task_id: op1})

-        # Also verify that DAGs with duplicate task_ids don't raise errors
-        dag1 = DAG("test_dag_1", start_date=DEFAULT_DATE)
-        op3 = DummyOperator(task_id="t3", dag=dag1)
-        op4 = DummyOperator(task_id="t4", dag=dag1)
-        op3 >> op4
-
-        self.assertEqual(dag1.task_dict, {op3.task_id: op3, op4.task_id: op4})
-
+    def test_duplicate_task_ids_for_same_task_is_allowed(self):
+        """Verify that same tasks with Duplicate task_id do not raise error"""
+        with DAG("test_dag", start_date=DEFAULT_DATE) as dag:

View File

@@ -373,20 +373,20 @@ class TestTaskInstance(unittest.TestCase):
         """
         test that updating the executor_config propogates to the TaskInstance DB
         """
-        dag = models.DAG(dag_id='test_run_pooling_task')
-        task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
-                             executor_config={'foo': 'bar'},
-                             start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+        with models.DAG(dag_id='test_run_pooling_task') as dag:
+            task = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
+                                 executor_config={'foo': 'bar'},
+                                 start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

         ti = TI(
             task=task, execution_date=timezone.utcnow())
         ti.run(session=session)
         tis = dag.get_task_instances()
         self.assertEqual({'foo': 'bar'}, tis[0].executor_config)

-        task2 = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
-                              executor_config={'bar': 'baz'},
-                              start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+        with models.DAG(dag_id='test_run_pooling_task') as dag:
+            task2 = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
+                                  executor_config={'bar': 'baz'},
+                                  start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

         ti = TI(
             task=task2, execution_date=timezone.utcnow())

View File

@@ -133,7 +133,7 @@ class TestS3ToSFTPOperator(unittest.TestCase):
     def delete_remote_resource(self):
         # check the remote file content
         remove_file_task = SSHOperator(
-            task_id="test_check_file",
+            task_id="test_rm_file",
             ssh_hook=self.hook,
             command="rm {0}".format(self.sftp_path),
             do_xcom_push=True,

View File

@@ -131,16 +131,18 @@ class TestCreateEvaluateOps(unittest.TestCase):
         self.assertEqual('err=0.9', result)

     def test_failures(self):
-        dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-                'project_id': 'test-project',
-                'region': 'us-east1',
-            },
-            schedule_interval='@daily')
+        def create_test_dag(dag_id):
+            dag = DAG(
+                dag_id,
+                default_args={
+                    'owner': 'airflow',
+                    'start_date': DEFAULT_DATE,
+                    'end_date': DEFAULT_DATE,
+                    'project_id': 'test-project',
+                    'region': 'us-east1',
+                },
+                schedule_interval='@daily')
+            return dag

         input_with_model = self.INPUT_MISSING_ORIGIN.copy()
         other_params_but_models = {
@@ -151,26 +153,30 @@
             'prediction_path': input_with_model['outputPath'],
             'metric_fn_and_keys': (self.metric_fn, ['err']),
             'validate_fn': (lambda x: 'err=%.1f' % x['err']),
-            'dag': dag,
         }

         with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
-            mlengine_operator_utils.create_evaluate_ops(**other_params_but_models)
+            mlengine_operator_utils.create_evaluate_ops(
+                dag=create_test_dag('test_dag_1'), **other_params_but_models)

         with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
-            mlengine_operator_utils.create_evaluate_ops(model_uri='abc', model_name='cde',
-                                                        **other_params_but_models)
+            mlengine_operator_utils.create_evaluate_ops(
+                dag=create_test_dag('test_dag_2'), model_uri='abc', model_name='cde',
+                **other_params_but_models)

         with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
-            mlengine_operator_utils.create_evaluate_ops(model_uri='abc', version_name='vvv',
-                                                        **other_params_but_models)
+            mlengine_operator_utils.create_evaluate_ops(
+                dag=create_test_dag('test_dag_3'), model_uri='abc', version_name='vvv',
+                **other_params_but_models)

         with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'):
             params = other_params_but_models.copy()
             params['metric_fn_and_keys'] = (None, ['abc'])
-            mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)
+            mlengine_operator_utils.create_evaluate_ops(
+                dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params)

         with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'):
             params = other_params_but_models.copy()
             params['validate_fn'] = None
-            mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)
+            mlengine_operator_utils.create_evaluate_ops(
+                dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params)
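Note on the factory above: `create_evaluate_ops` adds several tasks to the DAG it receives, so under the identity-based duplicate check a single DAG shared by all five assertions would raise `DuplicateTaskIdFound` on the repeated task_ids; each call therefore gets a fresh DAG.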

View File

@@ -205,7 +205,7 @@ class TestGCSUploadSessionCompleteSensor(TestCase):
         self.dag = dag

         self.sensor = GCSUploadSessionCompleteSensor(
-            task_id='sensor',
+            task_id='sensor_1',
             bucket='test-bucket',
             prefix='test-prefix/path',
             inactivity_period=12,
@@ -227,7 +227,7 @@
     @mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
     def test_files_deleted_between_pokes_allow_delete(self):
         self.sensor = GCSUploadSessionCompleteSensor(
-            task_id='sensor',
+            task_id='sensor_2',
             bucket='test-bucket',
             prefix='test-prefix/path',
             inactivity_period=12,

View File

@@ -45,7 +45,7 @@ class TestFileToWasbOperator(unittest.TestCase):
     def test_init(self):
         operator = FileToWasbOperator(
-            task_id='wasb_operator',
+            task_id='wasb_operator_1',
             dag=self.dag,
             **self._config
         )
@@ -58,7 +58,7 @@ class TestFileToWasbOperator(unittest.TestCase):
         self.assertEqual(operator.retries, self._config['retries'])

         operator = FileToWasbOperator(
-            task_id='wasb_operator',
+            task_id='wasb_operator_2',
             dag=self.dag,
             load_options={'timeout': 2},
             **self._config

View File

@@ -42,7 +42,7 @@ class TestWasbDeleteBlobOperator(unittest.TestCase):
     def test_init(self):
         operator = WasbDeleteBlobOperator(
-            task_id='wasb_operator',
+            task_id='wasb_operator_1',
             dag=self.dag,
             **self._config
         )
@@ -53,7 +53,7 @@ class TestWasbDeleteBlobOperator(unittest.TestCase):
         self.assertEqual(operator.ignore_if_missing, False)

         operator = WasbDeleteBlobOperator(
-            task_id='wasb_operator',
+            task_id='wasb_operator_2',
             dag=self.dag,
             is_prefix=True,
             ignore_if_missing=True,

View File

@@ -43,7 +43,7 @@ class TestWasbBlobSensor(unittest.TestCase):
     def test_init(self):
         sensor = WasbBlobSensor(
-            task_id='wasb_sensor',
+            task_id='wasb_sensor_1',
             dag=self.dag,
             **self._config
         )
@@ -54,7 +54,7 @@ class TestWasbBlobSensor(unittest.TestCase):
         self.assertEqual(sensor.timeout, self._config['timeout'])

         sensor = WasbBlobSensor(
-            task_id='wasb_sensor',
+            task_id='wasb_sensor_2',
             dag=self.dag,
             check_options={'timeout': 2},
             **self._config
@@ -94,7 +94,7 @@ class TestWasbPrefixSensor(unittest.TestCase):
     def test_init(self):
         sensor = WasbPrefixSensor(
-            task_id='wasb_sensor',
+            task_id='wasb_sensor_1',
             dag=self.dag,
             **self._config
         )
@@ -105,7 +105,7 @@ class TestWasbPrefixSensor(unittest.TestCase):
         self.assertEqual(sensor.timeout, self._config['timeout'])

         sensor = WasbPrefixSensor(
-            task_id='wasb_sensor',
+            task_id='wasb_sensor_2',
             dag=self.dag,
             check_options={'timeout': 2},
             **self._config

View File

@@ -75,7 +75,7 @@ class TestSFTPOperator(unittest.TestCase):
         # put test file to remote
         put_test_task = SFTPOperator(
-            task_id="test_sftp",
+            task_id="put_test_task",
             ssh_hook=self.hook,
             local_filepath=self.test_local_filepath,
             remote_filepath=self.test_remote_filepath,
@@ -89,7 +89,7 @@ class TestSFTPOperator(unittest.TestCase):
         # check the remote file content
         check_file_task = SSHOperator(
-            task_id="test_check_file",
+            task_id="check_file_task",
             ssh_hook=self.hook,
             command="cat {0}".format(self.test_remote_filepath),
             do_xcom_push=True,
@@ -99,7 +99,7 @@ class TestSFTPOperator(unittest.TestCase):
         ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
         ti3.run()
         self.assertEqual(
-            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+            ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(),
             test_local_file_content)

     @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
@@ -178,7 +178,7 @@ class TestSFTPOperator(unittest.TestCase):
         # put test file to remote
         put_test_task = SFTPOperator(
-            task_id="test_sftp",
+            task_id="put_test_task",
             ssh_hook=self.hook,
             local_filepath=self.test_local_filepath,
             remote_filepath=self.test_remote_filepath,
@@ -191,7 +191,7 @@ class TestSFTPOperator(unittest.TestCase):
         # check the remote file content
         check_file_task = SSHOperator(
-            task_id="test_check_file",
+            task_id="check_file_task",
             ssh_hook=self.hook,
             command="cat {0}".format(self.test_remote_filepath),
             do_xcom_push=True,
@@ -201,7 +201,7 @@ class TestSFTPOperator(unittest.TestCase):
         ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
         ti3.run()
         self.assertEqual(
-            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+            ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(),
             b64encode(test_local_file_content).decode('utf-8'))

     @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
@@ -362,7 +362,7 @@ class TestSFTPOperator(unittest.TestCase):
         with self.assertRaisesRegex(AirflowException,
                                     "Cannot operate without ssh_hook or ssh_conn_id."):
             task_0 = SFTPOperator(
-                task_id="test_sftp",
+                task_id="test_sftp_0",
                 local_filepath=self.test_local_filepath,
                 remote_filepath=self.test_remote_filepath,
                 operation=SFTPOperation.PUT,
@@ -372,7 +372,7 @@
         # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
         task_1 = SFTPOperator(
-            task_id="test_sftp",
+            task_id="test_sftp_1",
             ssh_hook="string_rather_than_SSHHook",  # invalid ssh_hook
             ssh_conn_id=TEST_CONN_ID,
             local_filepath=self.test_local_filepath,
@@ -387,7 +387,7 @@
         self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)

         task_2 = SFTPOperator(
-            task_id="test_sftp",
+            task_id="test_sftp_2",
             ssh_conn_id=TEST_CONN_ID,  # no ssh_hook provided
             local_filepath=self.test_local_filepath,
             remote_filepath=self.test_remote_filepath,
@@ -402,7 +402,7 @@
         # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
         task_3 = SFTPOperator(
-            task_id="test_sftp",
+            task_id="test_sftp_3",
             ssh_hook=self.hook,
             ssh_conn_id=TEST_CONN_ID,
             local_filepath=self.test_local_filepath,

View File

@@ -70,16 +70,17 @@ class TestBaseSensor(unittest.TestCase):
             state=State.RUNNING
         )

-    def _make_sensor(self, return_value, **kwargs):
+    def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
         poke_interval = 'poke_interval'
         timeout = 'timeout'
         if poke_interval not in kwargs:
             kwargs[poke_interval] = 0
         if timeout not in kwargs:
             kwargs[timeout] = 0

         sensor = DummySensor(
-            task_id=SENSOR_OP,
+            task_id=task_id,
             return_value=return_value,
             dag=self.dag,
             **kwargs
@@ -471,17 +472,20 @@
         positive_poke_interval = 10
         with self.assertRaises(AirflowException):
             self._make_sensor(
+                task_id='test_sensor_task_1',
                 return_value=None,
                 poke_interval=negative_poke_interval,
                 timeout=25)

         with self.assertRaises(AirflowException):
             self._make_sensor(
+                task_id='test_sensor_task_2',
                 return_value=None,
                 poke_interval=non_number_poke_interval,
                 timeout=25)

         self._make_sensor(
+            task_id='test_sensor_task_3',
             return_value=None,
             poke_interval=positive_poke_interval,
             timeout=25)
@@ -492,17 +496,20 @@
         positive_timeout = 25
         with self.assertRaises(AirflowException):
             self._make_sensor(
+                task_id='test_sensor_task_1',
                 return_value=None,
                 poke_interval=10,
                 timeout=negative_timeout)

         with self.assertRaises(AirflowException):
             self._make_sensor(
+                task_id='test_sensor_task_2',
                 return_value=None,
                 poke_interval=10,
                 timeout=non_number_timeout)

         self._make_sensor(
+            task_id='test_sensor_task_3',
             return_value=None,
             poke_interval=10,
             timeout=positive_timeout)

View File

@@ -278,7 +278,7 @@ exit 0
         self.test_time_sensor()
         # check that the execution_fn works
         op1 = ExternalTaskSensor(
-            task_id='test_external_task_sensor_check_delta',
+            task_id='test_external_task_sensor_check_delta_1',
             external_dag_id=TEST_DAG_ID,
             external_task_id=TEST_TASK_ID,
             execution_date_fn=lambda dt: dt + timedelta(0),
@@ -292,7 +292,7 @@
         )
         # double check that the execution is being called by failing the test
         op2 = ExternalTaskSensor(
-            task_id='test_external_task_sensor_check_delta',
+            task_id='test_external_task_sensor_check_delta_2',
             external_dag_id=TEST_DAG_ID,
             external_task_id=TEST_TASK_ID,
             execution_date_fn=lambda dt: dt + timedelta(days=1),
@@ -325,7 +325,7 @@ exit 0
     def test_catch_invalid_allowed_states(self):
         with self.assertRaises(ValueError):
             ExternalTaskSensor(
-                task_id='test_external_task_sensor_check',
+                task_id='test_external_task_sensor_check_1',
                 external_dag_id=TEST_DAG_ID,
                 external_task_id=TEST_TASK_ID,
                 allowed_states=['invalid_state'],
@@ -334,7 +334,7 @@ exit 0
         with self.assertRaises(ValueError):
             ExternalTaskSensor(
-                task_id='test_external_task_sensor_check',
+                task_id='test_external_task_sensor_check_2',
                 external_dag_id=TEST_DAG_ID,
                 external_task_id=None,
                 allowed_states=['invalid_state'],

View File

@@ -55,7 +55,7 @@ class TestSqlSensor(TestHiveEnvironment):
     @pytest.mark.backend("mysql")
     def test_sql_sensor_mysql(self):
         op1 = SqlSensor(
-            task_id='sql_sensor_check',
+            task_id='sql_sensor_check_1',
             conn_id='mysql_default',
             sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
             dag=self.dag
@@ -63,7 +63,7 @@ class TestSqlSensor(TestHiveEnvironment):
         op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

         op2 = SqlSensor(
-            task_id='sql_sensor_check',
+            task_id='sql_sensor_check_2',
             conn_id='mysql_default',
             sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
             parameters=["table_name"],
@@ -74,7 +74,7 @@ class TestSqlSensor(TestHiveEnvironment):
     @pytest.mark.backend("postgres")
     def test_sql_sensor_postgres(self):
         op1 = SqlSensor(
-            task_id='sql_sensor_check',
+            task_id='sql_sensor_check_1',
             conn_id='postgres_default',
             sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
             dag=self.dag
@@ -82,7 +82,7 @@ class TestSqlSensor(TestHiveEnvironment):
         op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

         op2 = SqlSensor(
-            task_id='sql_sensor_check',
+            task_id='sql_sensor_check_2',
             conn_id='postgres_default',
             sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
             parameters=["table_name"],