[AIRFLOW-5807] Move SFTP from contrib to providers. (#6464)
* [AIRFLOW-5807] Move SFTP from contrib to core
This commit is contained in:
Parent
25830a01a9
Commit
69629a5a94
|
@ -16,274 +16,17 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import stat
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import warnings
|
||||
|
||||
import pysftp
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook # noqa
|
||||
|
||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||
|
||||
|
||||
class SFTPHook(SSHHook):
    """
    This hook is inherited from SSH hook. Please refer to SSH hook for the input
    arguments.

    Interact with SFTP. Aims to be interchangeable with FTPHook.

    :Pitfalls::

        - In contrast with FTPHook describe_directory only returns size, type and
          modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
          unique.
        - retrieve_file and store_file only take a local full path and not a
          buffer.
        - If no mode is passed to create_directory it will be created with 777
          permissions.

    Errors that may occur throughout but should be handled downstream.

    :param ftp_conn_id: connection id to look up; its ``extra`` JSON may carry
        ``private_key_pass``, ``no_host_key_check`` and the deprecated
        ``ignore_hostkey_verification`` / ``private_key`` options.
    :type ftp_conn_id: str
    """

    def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
        # SFTP connections are stored as SSH connections, so forward the id.
        kwargs['ssh_conn_id'] = ftp_conn_id
        super().__init__(*args, **kwargs)

        self.conn = None  # created lazily by get_conn()
        self.private_key_pass = None

        # Fail for unverified hosts, unless this is explicitly allowed
        self.no_host_key_check = False

        if self.ssh_conn_id is not None:
            conn = self.get_connection(self.ssh_conn_id)
            if conn.extra is not None:
                extra_options = conn.extra_dejson
                if 'private_key_pass' in extra_options:
                    self.private_key_pass = extra_options.get('private_key_pass', None)

                # For backward compatibility
                # TODO: remove in Airflow 2.1
                # (fix: dropped the redundant function-level ``import warnings``;
                # the module already imports it at top level)
                if 'ignore_hostkey_verification' in extra_options:
                    warnings.warn(
                        'Extra option `ignore_hostkey_verification` is deprecated.'
                        'Please use `no_host_key_check` instead.'
                        'This option will be removed in Airflow 2.1',
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    self.no_host_key_check = str(
                        extra_options['ignore_hostkey_verification']
                    ).lower() == 'true'

                if 'no_host_key_check' in extra_options:
                    self.no_host_key_check = str(
                        extra_options['no_host_key_check']).lower() == 'true'

                if 'private_key' in extra_options:
                    warnings.warn(
                        'Extra option `private_key` is deprecated.'
                        'Please use `key_file` instead.'
                        'This option will be removed in Airflow 2.1',
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    self.key_file = extra_options.get('private_key')

    def get_conn(self) -> pysftp.Connection:
        """
        Returns an SFTP connection object.

        The connection is created on first call and cached on ``self.conn``.
        """
        if self.conn is None:
            cnopts = pysftp.CnOpts()
            if self.no_host_key_check:
                # Disables host key verification entirely.
                cnopts.hostkeys = None
            cnopts.compression = self.compress
            conn_params = {
                'host': self.remote_host,
                'port': self.port,
                'username': self.username,
                'cnopts': cnopts
            }
            # Only pass credentials that were actually configured.
            if self.password and self.password.strip():
                conn_params['password'] = self.password
            if self.key_file:
                conn_params['private_key'] = self.key_file
            if self.private_key_pass:
                conn_params['private_key_pass'] = self.private_key_pass

            self.conn = pysftp.Connection(**conn_params)
        return self.conn

    def close_conn(self) -> None:
        """
        Closes the connection. An error will occur if the
        connection wasn't ever opened.
        """
        conn = self.conn
        conn.close()  # type: ignore
        self.conn = None

    def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
        """
        Returns a dictionary of {filename: {attributes}} for all files
        on the remote system (where the MLSD command is supported).

        :param path: full path to the remote directory
        :type path: str
        """
        conn = self.get_conn()
        flist = conn.listdir_attr(path)
        files = {}
        for f in flist:
            modify = datetime.datetime.fromtimestamp(
                f.st_mtime).strftime('%Y%m%d%H%M%S')
            files[f.filename] = {
                'size': f.st_size,
                'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
                'modify': modify}
        return files

    def list_directory(self, path: str) -> List[str]:
        """
        Returns a list of files on the remote system.

        :param path: full path to the remote directory to list
        :type path: str
        """
        conn = self.get_conn()
        files = conn.listdir(path)
        return files

    def create_directory(self, path: str, mode: int = 777) -> None:
        """
        Creates a directory on the remote system.

        :param path: full path to the remote directory to create
        :type path: str
        :param mode: int representation of octal mode for directory
            (NOTE(review): the default is the decimal int 777, not octal 0o777 —
            kept as-is for backward compatibility)
        """
        conn = self.get_conn()
        conn.makedirs(path, mode)

    def delete_directory(self, path: str) -> None:
        """
        Deletes a directory on the remote system.

        :param path: full path to the remote directory to delete
        :type path: str
        """
        conn = self.get_conn()
        conn.rmdir(path)

    def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
        """
        Transfers the remote file to a local location.
        If local_full_path is a string path, the file will be put
        at that location

        :param remote_full_path: full path to the remote file
        :type remote_full_path: str
        :param local_full_path: full path to the local file
        :type local_full_path: str
        """
        conn = self.get_conn()
        self.log.info('Retrieving file from FTP: %s', remote_full_path)
        conn.get(remote_full_path, local_full_path)
        self.log.info('Finished retrieving file from FTP: %s', remote_full_path)

    def store_file(self, remote_full_path: str, local_full_path: str) -> None:
        """
        Transfers a local file to the remote location.
        If local_full_path_or_buffer is a string path, the file will be read
        from that location

        :param remote_full_path: full path to the remote file
        :type remote_full_path: str
        :param local_full_path: full path to the local file
        :type local_full_path: str
        """
        conn = self.get_conn()
        conn.put(local_full_path, remote_full_path)

    def delete_file(self, path: str) -> None:
        """
        Removes a file on the FTP Server

        :param path: full path to the remote file
        :type path: str
        """
        conn = self.get_conn()
        conn.remove(path)

    def get_mod_time(self, path: str) -> str:
        """
        Returns the remote file's modification time formatted as
        ``%Y%m%d%H%M%S``.

        :param path: full path to the remote file
        :type path: str
        """
        conn = self.get_conn()
        ftp_mdtm = conn.stat(path).st_mtime
        return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')

    def path_exists(self, path: str) -> bool:
        """
        Returns True if a remote entity exists

        :param path: full path to the remote file or directory
        :type path: str
        """
        conn = self.get_conn()
        return conn.exists(path)

    @staticmethod
    def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
        """
        Return True if given path starts with prefix (if set) and ends with delimiter (if set).

        :param path: path to be checked
        :type path: str
        :param prefix: if set path will be checked is starting with prefix
        :type prefix: str
        :param delimiter: if set path will be checked is ending with suffix
        :type delimiter: str
        :return: bool
        """
        if prefix is not None and not path.startswith(prefix):
            return False
        if delimiter is not None and not path.endswith(delimiter):
            return False
        return True

    def get_tree_map(
        self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Return tuple with recursive lists of files, directories and unknown paths from given path.
        It is possible to filter results by giving prefix and/or delimiter parameters.

        :param path: path from which tree will be built
        :type path: str
        :param prefix: if set paths will be added if start with prefix
        :type prefix: str
        :param delimiter: if set paths will be added if end with delimiter
        :type delimiter: str
        :return: tuple with list of files, dirs and unknown items
        :rtype: Tuple[List[str], List[str], List[str]]
        """
        conn = self.get_conn()
        files, dirs, unknowns = [], [], []  # type: List[str], List[str], List[str]

        def append_matching_path_callback(list_):
            # Callback factory: appends matching paths to the given bucket.
            return (
                lambda item: list_.append(item)
                if self._is_path_match(item, prefix, delimiter)
                else None
            )

        conn.walktree(
            remotepath=path,
            fcallback=append_matching_path_callback(files),
            dcallback=append_matching_path_callback(dirs),
            ucallback=append_matching_path_callback(unknowns),
            recurse=True,
        )

        return files, dirs, unknowns
|
||||
# Emit the deprecation notice when this shim module is imported.
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.",
    category=DeprecationWarning,
    stacklevel=2,
)
|
||||
|
|
|
@ -16,165 +16,17 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import os
|
||||
"""
|
||||
This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.
|
||||
"""
|
||||
|
||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
import warnings
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.sftp.operators.sftp_operator import SFTPOperator # noqa
|
||||
|
||||
class SFTPOperation:
    """Namespace for the supported SFTP transfer directions."""

    PUT = 'put'
    GET = 'get'
|
||||
|
||||
|
||||
class SFTPOperator(BaseOperator):
    """
    SFTPOperator for transferring files from a remote host to local or vice versa.
    This operator uses ssh_hook to open an sftp transport channel that serves as
    the basis for the file transfer.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param local_filepath: local file path to get or put. (templated)
    :type local_filepath: str
    :param remote_filepath: remote file path to get or put. (templated)
    :type remote_filepath: str
    :param operation: specify operation 'get' or 'put', defaults to put
    :type operation: str
    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
    :type confirm: bool
    :param create_intermediate_dirs: create missing intermediate directories when
        copying from remote to local and vice-versa. Default is False.

        Example: The following task would copy ``file.txt`` to the remote host
        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``, ``tmp1`` and ``tmp2`` if
        they don't exist. If the parameter is not passed it would error as the
        directory does not exist. ::

            put_file = SFTPOperator(
                task_id="test_sftp",
                ssh_conn_id="ssh_default",
                local_filepath="/tmp/file.txt",
                remote_filepath="/tmp/tmp1/tmp2/file.txt",
                operation="put",
                create_intermediate_dirs=True,
                dag=dag
            )

    :type create_intermediate_dirs: bool
    """
    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')

    @apply_defaults
    def __init__(self,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 local_filepath=None,
                 remote_filepath=None,
                 operation=SFTPOperation.PUT,
                 confirm=True,
                 create_intermediate_dirs=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        # Reject anything that is neither a GET nor a PUT up front.
        if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT):
            raise TypeError("unsupported operation value {0}, expected {1} or {2}"
                            .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))

    def execute(self, context):
        file_msg = None
        try:
            # Resolve the hook: an explicit ssh_hook wins over ssh_conn_id.
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info("remote_host is provided explicitly. "
                              "It will replace the remote_host which was defined "
                              "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as client:
                sftp = client.open_sftp()
                if self.operation.lower() == SFTPOperation.GET:
                    # Remote -> local transfer.
                    target_dir = os.path.dirname(self.local_filepath)
                    if self.create_intermediate_dirs:
                        # Create Intermediate Directories if it doesn't exist
                        try:
                            os.makedirs(target_dir)
                        except OSError:
                            if not os.path.isdir(target_dir):
                                raise
                    file_msg = "from {0} to {1}".format(
                        self.remote_filepath, self.local_filepath)
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp.get(self.remote_filepath, self.local_filepath)
                else:
                    # Local -> remote transfer.
                    dest_dir = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp,
                            remote_directory=dest_dir,
                        )
                    file_msg = "from {0} to {1}".format(
                        self.local_filepath, self.remote_filepath)
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp.put(self.local_filepath,
                             self.remote_filepath,
                             confirm=self.confirm)

        except Exception as err:
            raise AirflowException("Error while transferring {0}, error: {1}"
                                   .format(file_msg, str(err)))

        return self.local_filepath
|
||||
|
||||
|
||||
def _make_intermediate_dirs(sftp_client, remote_directory):
|
||||
"""
|
||||
Create all the intermediate directories in a remote host
|
||||
|
||||
:param sftp_client: A Paramiko SFTP client.
|
||||
:param remote_directory: Absolute Path of the directory containing the file
|
||||
:return:
|
||||
"""
|
||||
if remote_directory == '/':
|
||||
sftp_client.chdir('/')
|
||||
return
|
||||
if remote_directory == '':
|
||||
return
|
||||
try:
|
||||
sftp_client.chdir(remote_directory)
|
||||
except OSError:
|
||||
dirname, basename = os.path.split(remote_directory.rstrip('/'))
|
||||
_make_intermediate_dirs(sftp_client, dirname)
|
||||
sftp_client.mkdir(basename)
|
||||
sftp_client.chdir(basename)
|
||||
return
|
||||
# Emit the deprecation notice when this shim module is imported.
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.",
    category=DeprecationWarning,
    stacklevel=2,
)
|
||||
|
|
|
@ -16,40 +16,17 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.
|
||||
"""
|
||||
|
||||
from paramiko import SFTP_NO_SUCH_FILE
|
||||
import warnings
|
||||
|
||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
# pylint: disable=unused-import
|
||||
from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor # noqa
|
||||
|
||||
|
||||
class SFTPSensor(BaseSensorOperator):
    """
    Waits for a file or directory to be present on SFTP.

    :param path: Remote file or directory path
    :type path: str
    :param sftp_conn_id: The connection to run the sensor against
    :type sftp_conn_id: str
    """
    template_fields = ('path',)

    @apply_defaults
    def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.path = path
        self.hook = None  # SFTPHook, created anew on each poke
        self.sftp_conn_id = sftp_conn_id

    def poke(self, context):
        """
        Return True once the remote path exists.

        Probes existence via ``get_mod_time``; a SFTP "no such file" error
        means "not there yet" (poke again), any other OSError is re-raised.
        """
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info('Poking for %s', self.path)
        try:
            self.hook.get_mod_time(self.path)
        except OSError as e:
            if e.errno != SFTP_NO_SUCH_FILE:
                raise e
            # Fix: close the connection before reporting "not yet present".
            # Previously it was left open on every unsuccessful poke, leaking
            # one SFTP connection per poke interval.
            self.hook.close_conn()
            return False
        self.hook.close_conn()
        return True
|
||||
# Emit the deprecation notice when this shim module is imported.
warnings.warn(
    "This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.",
    category=DeprecationWarning,
    stacklevel=2,
)
|
||||
|
|
|
@ -24,9 +24,9 @@ from tempfile import NamedTemporaryFile
|
|||
from typing import Optional
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
||||
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
WILDCARD = "*"
|
||||
|
|
|
@ -23,9 +23,9 @@ from tempfile import NamedTemporaryFile
|
|||
from typing import Optional, Union
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
||||
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
WILDCARD = "*"
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,297 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
This module contains SFTP hook.
|
||||
"""
|
||||
import datetime
|
||||
import stat
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pysftp
|
||||
|
||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||
|
||||
|
||||
class SFTPHook(SSHHook):
    """
    Interact with SFTP. Aims to be interchangeable with FTPHook.

    This hook is inherited from SSH hook. Please refer to SSH hook for the
    input arguments.

    :Pitfalls::

        - In contrast with FTPHook describe_directory only returns size, type and
          modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
          unique.
        - retrieve_file and store_file only take a local full path and not a
          buffer.
        - If no mode is passed to create_directory it will be created with 777
          permissions.

    Errors that may occur throughout but should be handled downstream.
    """

    def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
        kwargs['ssh_conn_id'] = ftp_conn_id
        super().__init__(*args, **kwargs)

        # Connection object is created lazily in get_conn().
        self.conn = None
        self.private_key_pass = None

        # Fail for unverified hosts, unless this is explicitly allowed
        self.no_host_key_check = False

        if self.ssh_conn_id is None:
            return
        airflow_conn = self.get_connection(self.ssh_conn_id)
        if airflow_conn.extra is None:
            return
        extras = airflow_conn.extra_dejson
        if 'private_key_pass' in extras:
            self.private_key_pass = extras.get('private_key_pass', None)

        # For backward compatibility
        # TODO: remove in Airflow 2.1
        import warnings
        if 'ignore_hostkey_verification' in extras:
            warnings.warn(
                'Extra option `ignore_hostkey_verification` is deprecated.'
                'Please use `no_host_key_check` instead.'
                'This option will be removed in Airflow 2.1',
                DeprecationWarning,
                stacklevel=2,
            )
            self.no_host_key_check = (
                str(extras['ignore_hostkey_verification']).lower() == 'true'
            )

        if 'no_host_key_check' in extras:
            self.no_host_key_check = (
                str(extras['no_host_key_check']).lower() == 'true'
            )

        if 'private_key' in extras:
            warnings.warn(
                'Extra option `private_key` is deprecated.'
                'Please use `key_file` instead.'
                'This option will be removed in Airflow 2.1',
                DeprecationWarning,
                stacklevel=2,
            )
            self.key_file = extras.get('private_key')

    def get_conn(self) -> pysftp.Connection:
        """
        Returns an SFTP connection object, creating and caching it on first use.
        """
        if self.conn is not None:
            return self.conn

        options = pysftp.CnOpts()
        if self.no_host_key_check:
            options.hostkeys = None
        options.compression = self.compress

        params = {
            'host': self.remote_host,
            'port': self.port,
            'username': self.username,
            'cnopts': options,
        }
        # Only pass credentials that were actually configured.
        if self.password and self.password.strip():
            params['password'] = self.password
        if self.key_file:
            params['private_key'] = self.key_file
        if self.private_key_pass:
            params['private_key_pass'] = self.private_key_pass

        self.conn = pysftp.Connection(**params)
        return self.conn

    def close_conn(self) -> None:
        """
        Closes the connection. An error will occur if the
        connection was never opened.
        """
        open_conn = self.conn
        open_conn.close()  # type: ignore
        self.conn = None

    def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
        """
        Returns a dictionary of {filename: {attributes}} for all files
        on the remote system (where the MLSD command is supported).

        :param path: full path to the remote directory
        :type path: str
        """
        connection = self.get_conn()
        return {
            entry.filename: {
                'size': entry.st_size,
                'type': 'dir' if stat.S_ISDIR(entry.st_mode) else 'file',
                'modify': datetime.datetime.fromtimestamp(
                    entry.st_mtime).strftime('%Y%m%d%H%M%S'),
            }
            for entry in connection.listdir_attr(path)
        }

    def list_directory(self, path: str) -> List[str]:
        """
        Returns a list of files on the remote system.

        :param path: full path to the remote directory to list
        :type path: str
        """
        return self.get_conn().listdir(path)

    def create_directory(self, path: str, mode: int = 777) -> None:
        """
        Creates a directory on the remote system.

        :param path: full path to the remote directory to create
        :type path: str
        :param mode: int representation of octal mode for directory
        """
        self.get_conn().makedirs(path, mode)

    def delete_directory(self, path: str) -> None:
        """
        Deletes a directory on the remote system.

        :param path: full path to the remote directory to delete
        :type path: str
        """
        self.get_conn().rmdir(path)

    def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
        """
        Transfers the remote file to a local location.
        If local_full_path is a string path, the file will be put
        at that location

        :param remote_full_path: full path to the remote file
        :type remote_full_path: str
        :param local_full_path: full path to the local file
        :type local_full_path: str
        """
        connection = self.get_conn()
        self.log.info('Retrieving file from FTP: %s', remote_full_path)
        connection.get(remote_full_path, local_full_path)
        self.log.info('Finished retrieving file from FTP: %s', remote_full_path)

    def store_file(self, remote_full_path: str, local_full_path: str) -> None:
        """
        Transfers a local file to the remote location.
        If local_full_path_or_buffer is a string path, the file will be read
        from that location

        :param remote_full_path: full path to the remote file
        :type remote_full_path: str
        :param local_full_path: full path to the local file
        :type local_full_path: str
        """
        self.get_conn().put(local_full_path, remote_full_path)

    def delete_file(self, path: str) -> None:
        """
        Removes a file on the FTP Server

        :param path: full path to the remote file
        :type path: str
        """
        self.get_conn().remove(path)

    def get_mod_time(self, path: str) -> str:
        """
        Returns modification time.

        :param path: full path to the remote file
        :type path: str
        """
        remote_mtime = self.get_conn().stat(path).st_mtime
        return datetime.datetime.fromtimestamp(remote_mtime).strftime('%Y%m%d%H%M%S')

    def path_exists(self, path: str) -> bool:
        """
        Returns True if a remote entity exists

        :param path: full path to the remote file or directory
        :type path: str
        """
        return self.get_conn().exists(path)

    @staticmethod
    def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
        """
        Return True if given path starts with prefix (if set) and ends with delimiter (if set).

        :param path: path to be checked
        :type path: str
        :param prefix: if set path will be checked is starting with prefix
        :type prefix: str
        :param delimiter: if set path will be checked is ending with suffix
        :type delimiter: str
        :return: bool
        """
        starts_ok = prefix is None or path.startswith(prefix)
        ends_ok = delimiter is None or path.endswith(delimiter)
        return starts_ok and ends_ok

    def get_tree_map(
        self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Return tuple with recursive lists of files, directories and unknown paths from given path.
        It is possible to filter results by giving prefix and/or delimiter parameters.

        :param path: path from which tree will be built
        :type path: str
        :param prefix: if set paths will be added if start with prefix
        :type prefix: str
        :param delimiter: if set paths will be added if end with delimiter
        :type delimiter: str
        :return: tuple with list of files, dirs and unknown items
        :rtype: Tuple[List[str], List[str], List[str]]
        """
        connection = self.get_conn()
        files = []  # type: List[str]
        dirs = []  # type: List[str]
        unknowns = []  # type: List[str]

        def matcher_for(bucket):
            # Append the visited path to *bucket* only when it satisfies the
            # prefix/delimiter filter.
            def callback(item):
                if self._is_path_match(item, prefix, delimiter):
                    bucket.append(item)
            return callback

        connection.walktree(
            remotepath=path,
            fcallback=matcher_for(files),
            dcallback=matcher_for(dirs),
            ucallback=matcher_for(unknowns),
            recurse=True,
        )

        return files, dirs, unknowns
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,184 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
This module contains SFTP operator.
|
||||
"""
|
||||
import os
|
||||
|
||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
# pylint: disable=missing-docstring
|
||||
class SFTPOperation:
    """Names of the two transfer directions accepted by :class:`SFTPOperator`."""

    PUT = 'put'
    GET = 'get'
|
||||
|
||||
|
||||
class SFTPOperator(BaseOperator):
    """
    SFTPOperator for transferring files from remote host to local or vice a versa.
    This operator uses ssh_hook to open sftp transport channel that serve as basis
    for file transfer.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param local_filepath: local file path to get or put. (templated)
    :type local_filepath: str
    :param remote_filepath: remote file path to get or put. (templated)
    :type remote_filepath: str
    :param operation: specify operation 'get' or 'put', defaults to put
    :type operation: str
    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
    :type confirm: bool
    :param create_intermediate_dirs: create missing intermediate directories when
        copying from remote to local and vice-versa. Default is False.

        Example: The following task would copy ``file.txt`` to the remote host
        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
        don't exist. If the parameter is not passed it would error as the directory
        does not exist. ::

            put_file = SFTPOperator(
                task_id="test_sftp",
                ssh_conn_id="ssh_default",
                local_filepath="/tmp/file.txt",
                remote_filepath="/tmp/tmp1/tmp2/file.txt",
                operation="put",
                create_intermediate_dirs=True,
                dag=dag
            )

    :type create_intermediate_dirs: bool
    """
    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')

    @apply_defaults
    def __init__(self,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 local_filepath=None,
                 remote_filepath=None,
                 operation=SFTPOperation.PUT,
                 confirm=True,
                 create_intermediate_dirs=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        # Validate eagerly so a bad `operation` fails at DAG-parse time
        # rather than when the task first runs.
        if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT):
            raise TypeError("unsupported operation value {0}, expected {1} or {2}"
                            .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))

    def execute(self, context):
        """Transfer the file in the configured direction and return the local path."""
        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    # Fall back to building a hook from the connection id.
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info("remote_host is provided explicitly. "
                              "It will replace the remote_host which was defined "
                              "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as ssh_client:
                sftp_client = ssh_client.open_sftp()
                if self.operation.lower() == SFTPOperation.GET:
                    local_folder = os.path.dirname(self.local_filepath)
                    if self.create_intermediate_dirs:
                        # Create missing local parent directories; an existing
                        # directory is fine (exist_ok).
                        os.makedirs(local_folder, exist_ok=True)
                    file_msg = "from {0} to {1}".format(self.remote_filepath,
                                                        self.local_filepath)
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp_client.get(self.remote_filepath, self.local_filepath)
                else:
                    remote_folder = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp_client,
                            remote_directory=remote_folder,
                        )
                    file_msg = "from {0} to {1}".format(self.local_filepath,
                                                        self.remote_filepath)
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp_client.put(self.local_filepath,
                                    self.remote_filepath,
                                    confirm=self.confirm)

        except Exception as e:
            # Chain the original exception so the underlying cause stays
            # visible in the traceback instead of being flattened into text.
            raise AirflowException("Error while transferring {0}, error: {1}"
                                   .format(file_msg, str(e))) from e

        return self.local_filepath
|
||||
|
||||
|
||||
def _make_intermediate_dirs(sftp_client, remote_directory):
|
||||
"""
|
||||
Create all the intermediate directories in a remote host
|
||||
|
||||
:param sftp_client: A Paramiko SFTP client.
|
||||
:param remote_directory: Absolute Path of the directory containing the file
|
||||
:return:
|
||||
"""
|
||||
if remote_directory == '/':
|
||||
sftp_client.chdir('/')
|
||||
return
|
||||
if remote_directory == '':
|
||||
return
|
||||
try:
|
||||
sftp_client.chdir(remote_directory)
|
||||
except OSError:
|
||||
dirname, basename = os.path.split(remote_directory.rstrip('/'))
|
||||
_make_intermediate_dirs(sftp_client, dirname)
|
||||
sftp_client.mkdir(basename)
|
||||
sftp_client.chdir(basename)
|
||||
return
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,57 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
This module contains SFTP sensor.
|
||||
"""
|
||||
from paramiko import SFTP_NO_SUCH_FILE
|
||||
|
||||
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SFTPSensor(BaseSensorOperator):
    """
    Waits for a file or directory to be present on SFTP.

    :param path: Remote file or directory path
    :type path: str
    :param sftp_conn_id: The connection to run the sensor against
    :type sftp_conn_id: str
    """
    template_fields = ('path',)

    @apply_defaults
    def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.path = path
        self.hook = None
        self.sftp_conn_id = sftp_conn_id

    def poke(self, context):
        """Return True once the remote path exists, False to keep poking."""
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info('Poking for %s', self.path)
        try:
            # Existence is probed via the modification time; the value itself
            # is discarded.
            self.hook.get_mod_time(self.path)
        except OSError as e:
            # "No such file" just means the target is not there yet.
            if e.errno != SFTP_NO_SUCH_FILE:
                raise e
            return False
        finally:
            # Always release the SFTP connection, whether the file was found,
            # absent, or an unexpected error is propagating.
            self.hook.close_conn()
        return True
|
|
@ -76,6 +76,8 @@ All operators are in the following packages:
|
|||
|
||||
airflow/providers/amazon/aws/sensors/index
|
||||
|
||||
airflow/providers/apache/cassandra/sensors/index
|
||||
|
||||
airflow/providers/google/cloud/operators/index
|
||||
|
||||
airflow/providers/google/cloud/sensors/index
|
||||
|
@ -84,7 +86,9 @@ All operators are in the following packages:
|
|||
|
||||
airflow/providers/google/marketing_platform/sensors/index
|
||||
|
||||
airflow/providers/apache/cassandra/sensors/index
|
||||
airflow/providers/sftp/operators/index
|
||||
|
||||
airflow/providers/sftp/sensors/index
|
||||
|
||||
Hooks
|
||||
-----
|
||||
|
@ -117,6 +121,8 @@ All hooks are in the following packages:
|
|||
|
||||
airflow/providers/apache/cassandra/hooks/index
|
||||
|
||||
airflow/providers/sftp/hooks/index
|
||||
|
||||
Executors
|
||||
---------
|
||||
Executors are the mechanism by which task instances get run. All executors are
|
||||
|
|
|
@ -233,6 +233,7 @@ exclude_patterns = [
|
|||
'_api/airflow/providers/amazon/aws/example_dags',
|
||||
'_api/airflow/providers/apache/index.rst',
|
||||
'_api/airflow/providers/apache/cassandra/index.rst',
|
||||
'_api/airflow/providers/sftp/index.rst',
|
||||
'_api/enums/index.rst',
|
||||
'_api/json_schema/index.rst',
|
||||
'_api/base_serialization/index.rst',
|
||||
|
|
|
@ -1246,9 +1246,9 @@ communication protocols or interface.
|
|||
|
||||
* - `SSH File Transfer Protocol (SFTP) <https://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/>`__
|
||||
-
|
||||
- :mod:`airflow.contrib.hooks.sftp_hook`
|
||||
- :mod:`airflow.contrib.operators.sftp_operator`
|
||||
- :mod:`airflow.contrib.sensors.sftp_sensor`
|
||||
- :mod:`airflow.providers.sftp.hooks.sftp_hook`
|
||||
- :mod:`airflow.providers.sftp.operators.sftp_operator`
|
||||
- :mod:`airflow.providers.sftp.sensors.sftp_sensor`
|
||||
|
||||
* - `Secure Shell (SSH) <https://tools.ietf.org/html/rfc4251>`__
|
||||
-
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
./airflow/contrib/hooks/qubole_check_hook.py
|
||||
./airflow/contrib/hooks/sagemaker_hook.py
|
||||
./airflow/contrib/hooks/segment_hook.py
|
||||
./airflow/contrib/hooks/sftp_hook.py
|
||||
./airflow/contrib/hooks/slack_webhook_hook.py
|
||||
./airflow/contrib/hooks/snowflake_hook.py
|
||||
./airflow/contrib/hooks/spark_jdbc_hook.py
|
||||
|
@ -65,7 +64,6 @@
|
|||
./airflow/contrib/operators/sagemaker_transform_operator.py
|
||||
./airflow/contrib/operators/sagemaker_tuning_operator.py
|
||||
./airflow/contrib/operators/segment_track_event_operator.py
|
||||
./airflow/contrib/operators/sftp_operator.py
|
||||
./airflow/contrib/operators/sftp_to_s3_operator.py
|
||||
./airflow/contrib/operators/slack_webhook_operator.py
|
||||
./airflow/contrib/operators/snowflake_operator.py
|
||||
|
@ -100,7 +98,6 @@
|
|||
./airflow/contrib/sensors/sagemaker_training_sensor.py
|
||||
./airflow/contrib/sensors/sagemaker_transform_sensor.py
|
||||
./airflow/contrib/sensors/sagemaker_tuning_sensor.py
|
||||
./airflow/contrib/sensors/sftp_sensor.py
|
||||
./airflow/contrib/sensors/wasb_sensor.py
|
||||
./airflow/contrib/sensors/weekday_sensor.py
|
||||
./airflow/hooks/dbapi_hook.py
|
||||
|
|
|
@ -17,231 +17,17 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
from unittest import mock
|
||||
import importlib
|
||||
from unittest import TestCase
|
||||
|
||||
import pysftp
|
||||
from parameterized import parameterized
|
||||
|
||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
||||
from airflow.models import Connection
|
||||
from airflow.utils.db import provide_session
|
||||
|
||||
# Scratch locations (local filesystem / SFTP server) used by the tests below.
TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'

# Login written into the `sftp_default` connection for the test run.
SFTP_CONNECTION_USER = "root"
# Deprecated and new module paths for the relocation/deprecation test.
OLD_PATH = "airflow.contrib.hooks.sftp_hook"
NEW_PATH = "airflow.providers.sftp.hooks.sftp_hook"
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
|
||||
|
||||
|
||||
class TestSFTPHook(unittest.TestCase):
    """Integration tests for SFTPHook against a live SFTP server.

    setUp points the shared ``sftp_default`` connection at
    SFTP_CONNECTION_USER and builds a scratch tree under TMP_PATH;
    tearDown removes the tree and restores the previous login.
    """

    @provide_session
    def update_connection(self, login, session=None):
        # Swap the login on the `sftp_default` connection and return the
        # previous value so tearDown can restore it.
        connection = (session.query(Connection).
                      filter(Connection.conn_id == "sftp_default")
                      .first())
        old_login = connection.login
        connection.login = login
        session.commit()
        return old_login

    def setUp(self):
        self.old_login = self.update_connection(SFTP_CONNECTION_USER)
        self.hook = SFTPHook()
        # Scratch tree: /tmp/tests_sftp_hook_dir/sub_dir plus two seed files.
        os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

        with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
            file.write('Test file')
        with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
            file.write('Test file')

    def test_get_conn(self):
        # get_conn must hand back a pysftp.Connection.
        output = self.hook.get_conn()
        self.assertEqual(type(output), pysftp.Connection)

    def test_close_conn(self):
        # close_conn drops the cached connection back to None.
        self.hook.conn = self.hook.get_conn()
        self.assertTrue(self.hook.conn is not None)
        self.hook.close_conn()
        self.assertTrue(self.hook.conn is None)

    def test_describe_directory(self):
        output = self.hook.describe_directory(TMP_PATH)
        self.assertTrue(TMP_DIR_FOR_TESTS in output)

    def test_list_directory(self):
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR])

    def test_create_and_delete_directory(self):
        new_dir_name = 'new_dir'
        self.hook.create_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertTrue(new_dir_name in output)
        self.hook.delete_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertTrue(new_dir_name not in output)

    def test_create_and_delete_directories(self):
        # create_directory must also build the nested `base_dir/sub_dir` path.
        base_dir = "base_dir"
        sub_dir = "sub_dir"
        new_dir_path = os.path.join(base_dir, sub_dir)
        self.hook.create_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertTrue(base_dir in output)
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
        self.assertTrue(sub_dir in output)
        self.hook.delete_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
        self.hook.delete_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertTrue(new_dir_path not in output)
        self.assertTrue(base_dir not in output)

    def test_store_retrieve_and_delete_file(self):
        # Round-trip: upload, list, download a copy, delete on the remote side.
        self.hook.store_file(
            remote_full_path=os.path.join(
                TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
        )
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
        retrieved_file_name = 'retrieved.txt'
        self.hook.retrieve_file(
            remote_full_path=os.path.join(
                TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
        )
        self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
        os.remove(os.path.join(TMP_PATH, retrieved_file_name))
        self.hook.delete_file(path=os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR])

    def test_get_mod_time(self):
        self.hook.store_file(
            remote_full_path=os.path.join(
                TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
        )
        output = self.hook.get_mod_time(path=os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
        # 14 characters - presumably a YYYYMMDDHHMMSS timestamp string; the
        # test only pins the length, not the format. TODO confirm against hook.
        self.assertEqual(len(output), 14)

    @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_default(self, get_connection):
        # With no extras at all, host key checking stays enabled (False flag).
        connection = Connection(login='login', host='host')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_enabled(self, get_connection):
        connection = Connection(
            login='login', host='host',
            extra='{"no_host_key_check": true}')

        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, True)

    @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled(self, get_connection):
        connection = Connection(
            login='login', host='host',
            extra='{"no_host_key_check": false}')

        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
        # Any value other than JSON true keeps checking enabled.
        connection = Connection(
            login='login', host='host',
            extra='{"no_host_key_check": "foo"}')

        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_ignore(self, get_connection):
        # Legacy extra key `ignore_hostkey_verification` is honoured too.
        connection = Connection(
            login='login', host='host',
            extra='{"ignore_hostkey_verification": true}')

        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, True)

    @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_no_ignore(self, get_connection):
        connection = Connection(
            login='login', host='host',
            extra='{"ignore_hostkey_verification": false}')

        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @parameterized.expand([
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
    ])
    def test_path_exists(self, path, exists):
        result = self.hook.path_exists(path)
        self.assertEqual(result, exists)

    @parameterized.expand([
        ("test/path/file.bin", None, None, True),
        ("test/path/file.bin", "test", None, True),
        ("test/path/file.bin", "test/", None, True),
        ("test/path/file.bin", None, "bin", True),
        ("test/path/file.bin", "test", "bin", True),
        ("test/path/file.bin", "test/", "file.bin", True),
        ("test/path/file.bin", None, "file.bin", True),
        ("test/path/file.bin", "diff", None, False),
        ("test/path/file.bin", "test//", None, False),
        ("test/path/file.bin", None, ".txt", False),
        ("test/path/file.bin", "diff", ".txt", False),
    ])
    def test_path_match(self, path, prefix, delimiter, match):
        result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
        self.assertEqual(result, match)

    def test_get_tree_map(self):
        # The scratch tree contains exactly one sub dir with one file in it.
        tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        files, dirs, unknowns = tree_map

        self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
        self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
        self.assertEqual(unknowns, [])

    def tearDown(self):
        # Remove the scratch tree and restore the original connection login.
        shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        self.update_connection(self.old_login)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this test module directly with `python`.
    unittest.main()
|
||||
class TestMovingSFTPHookToCore(TestCase):
    # Verifies that importing the old contrib path emits the deprecation
    # warning pointing at the new providers path.
    def test_move_sftp_hook_to_core(self):
        with self.assertWarns(DeprecationWarning) as warn:
            # Reload to see deprecation warning each time
            # NOTE(review): import_module does not re-execute an already
            # imported module, so the warning only fires on first import;
            # importlib.reload may be needed if the shim was imported
            # earlier in the test session - confirm.
            importlib.import_module(OLD_PATH)
        self.assertEqual(WARNING_MESSAGE, str(warn.warning))
|
||||
|
|
|
@ -17,436 +17,17 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from base64 import b64encode
|
||||
from unittest import mock
|
||||
import importlib
|
||||
from unittest import TestCase
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator
|
||||
from airflow.contrib.operators.ssh_operator import SSHOperator
|
||||
from airflow.models import DAG, TaskInstance
|
||||
from airflow.utils import timezone
|
||||
from airflow.utils.timezone import datetime
|
||||
from tests.test_utils.config import conf_vars
|
||||
|
||||
# Fixture identifiers shared by the operator tests below.
TEST_DAG_ID = 'unit_tests_sftp_op'
DEFAULT_DATE = datetime(2017, 1, 1)
TEST_CONN_ID = "conn_id_for_testing"
# Deprecated and new module paths for the relocation/deprecation test.
OLD_PATH = "airflow.contrib.operators.sftp_operator"
NEW_PATH = "airflow.providers.sftp.operators.sftp_operator"
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
|
||||
|
||||
|
||||
class TestSFTPOperator(unittest.TestCase):
|
||||
def setUp(self):
|
||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||
|
||||
hook = SSHHook(ssh_conn_id='ssh_default')
|
||||
hook.no_host_key_check = True
|
||||
args = {
|
||||
'owner': 'airflow',
|
||||
'start_date': DEFAULT_DATE,
|
||||
}
|
||||
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
|
||||
dag.schedule_interval = '@once'
|
||||
self.hook = hook
|
||||
self.dag = dag
|
||||
self.test_dir = "/tmp"
|
||||
self.test_local_dir = "/tmp/tmp2"
|
||||
self.test_remote_dir = "/tmp/tmp1"
|
||||
self.test_local_filename = 'test_local_file'
|
||||
self.test_remote_filename = 'test_remote_file'
|
||||
self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
|
||||
self.test_local_filename)
|
||||
# Local Filepath with Intermediate Directory
|
||||
self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
|
||||
self.test_local_filename)
|
||||
self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
|
||||
self.test_remote_filename)
|
||||
# Remote Filepath with Intermediate Directory
|
||||
self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
|
||||
self.test_remote_filename)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_pickle_file_transfer_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# put test file to remote
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
create_intermediate_dirs=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# check the remote file content
|
||||
check_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="cat {0}".format(self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(check_file_task)
|
||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
self.assertEqual(
|
||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
||||
test_local_file_content)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_no_intermediate_dir_error_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# Try to put test file to remote
|
||||
# This should raise an error with "No such file" as the directory
|
||||
# does not exist
|
||||
with self.assertRaises(Exception) as error:
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath_int_dir,
|
||||
operation=SFTPOperation.PUT,
|
||||
create_intermediate_dirs=False,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
self.assertIn('No such file', str(error.exception))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_with_intermediate_dir_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# put test file to remote
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath_int_dir,
|
||||
operation=SFTPOperation.PUT,
|
||||
create_intermediate_dirs=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# check the remote file content
|
||||
check_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="cat {0}".format(self.test_remote_filepath_int_dir),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(check_file_task)
|
||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
self.assertEqual(
|
||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
||||
test_local_file_content)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
|
||||
def test_json_file_transfer_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# put test file to remote
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# check the remote file content
|
||||
check_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="cat {0}".format(self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(check_file_task)
|
||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
self.assertEqual(
|
||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
||||
b64encode(test_local_file_content).decode('utf-8'))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_pickle_file_transfer_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# get remote file to local
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# test the received content
|
||||
content_received = None
|
||||
with open(self.test_local_filepath, 'r') as file:
|
||||
content_received = file.read()
|
||||
self.assertEqual(content_received.strip(), test_remote_file_content)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
|
||||
def test_json_file_transfer_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# get remote file to local
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# test the received content
|
||||
content_received = None
|
||||
with open(self.test_local_filepath, 'r') as file:
|
||||
content_received = file.read()
|
||||
self.assertEqual(content_received.strip(),
|
||||
test_remote_file_content.encode('utf-8').decode('utf-8'))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_no_intermediate_dir_error_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# Try to GET test file from remote
|
||||
# This should raise an error with "No such file" as the directory
|
||||
# does not exist
|
||||
with self.assertRaises(Exception) as error:
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath_int_dir,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
self.assertIn('No such file', str(error.exception))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_with_intermediate_dir_error_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# get remote file to local
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath_int_dir,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
create_intermediate_dirs=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# test the received content
|
||||
content_received = None
|
||||
with open(self.test_local_filepath_int_dir, 'r') as file:
|
||||
content_received = file.read()
|
||||
self.assertEqual(content_received.strip(), test_remote_file_content)
|
||||
|
||||
@mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
|
||||
def test_arg_checking(self):
|
||||
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
|
||||
with self.assertRaisesRegex(AirflowException,
|
||||
"Cannot operate without ssh_hook or ssh_conn_id."):
|
||||
task_0 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
task_0.execute(None)
|
||||
|
||||
# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
|
||||
task_1 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
|
||||
ssh_conn_id=TEST_CONN_ID,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
try:
|
||||
task_1.execute(None)
|
||||
except Exception:
|
||||
pass
|
||||
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
|
||||
|
||||
task_2 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
try:
|
||||
task_2.execute(None)
|
||||
except Exception:
|
||||
pass
|
||||
self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
|
||||
|
||||
# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
|
||||
task_3 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
ssh_conn_id=TEST_CONN_ID,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
try:
|
||||
task_3.execute(None)
|
||||
except Exception:
|
||||
pass
|
||||
self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
|
||||
|
||||
def delete_local_resource(self):
|
||||
if os.path.exists(self.test_local_filepath):
|
||||
os.remove(self.test_local_filepath)
|
||||
if os.path.exists(self.test_local_filepath_int_dir):
|
||||
os.remove(self.test_local_filepath_int_dir)
|
||||
if os.path.exists(self.test_local_dir):
|
||||
os.rmdir(self.test_local_dir)
|
||||
|
||||
def delete_remote_resource(self):
|
||||
if os.path.exists(self.test_remote_filepath):
|
||||
# check the remote file content
|
||||
remove_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="rm {0}".format(self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(remove_file_task)
|
||||
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
if os.path.exists(self.test_remote_filepath_int_dir):
|
||||
os.remove(self.test_remote_filepath_int_dir)
|
||||
if os.path.exists(self.test_remote_dir):
|
||||
os.rmdir(self.test_remote_dir)
|
||||
|
||||
def tearDown(self):
|
||||
self.delete_local_resource()
|
||||
self.delete_remote_resource()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
class TestMovingSFTPHookToCore(TestCase):
|
||||
def test_move_sftp_operator_to_core(self):
|
||||
with self.assertWarns(DeprecationWarning) as warn:
|
||||
# Reload to see deprecation warning each time
|
||||
importlib.import_module(OLD_PATH)
|
||||
self.assertEqual(WARNING_MESSAGE, str(warn.warning))
|
||||
|
|
|
@ -17,61 +17,17 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import importlib
|
||||
from unittest import TestCase
|
||||
|
||||
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
|
||||
|
||||
from airflow.contrib.sensors.sftp_sensor import SFTPSensor
|
||||
OLD_PATH = "airflow.contrib.sensors.sftp_sensor"
|
||||
NEW_PATH = "airflow.providers.sftp.sensors.sftp_sensor"
|
||||
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
|
||||
|
||||
|
||||
class TestSFTPSensor(unittest.TestCase):
|
||||
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
|
||||
def test_file_present(self, sftp_hook_mock):
|
||||
sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
|
||||
sftp_sensor = SFTPSensor(
|
||||
task_id='unit_test',
|
||||
path='/path/to/file/1970-01-01.txt')
|
||||
context = {
|
||||
'ds': '1970-01-01'
|
||||
}
|
||||
output = sftp_sensor.poke(context)
|
||||
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
|
||||
'/path/to/file/1970-01-01.txt')
|
||||
self.assertTrue(output)
|
||||
|
||||
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
|
||||
def test_file_absent(self, sftp_hook_mock):
|
||||
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
|
||||
SFTP_NO_SUCH_FILE, 'File missing')
|
||||
sftp_sensor = SFTPSensor(
|
||||
task_id='unit_test',
|
||||
path='/path/to/file/1970-01-01.txt')
|
||||
context = {
|
||||
'ds': '1970-01-01'
|
||||
}
|
||||
output = sftp_sensor.poke(context)
|
||||
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
|
||||
'/path/to/file/1970-01-01.txt')
|
||||
self.assertFalse(output)
|
||||
|
||||
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
|
||||
def test_sftp_failure(self, sftp_hook_mock):
|
||||
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
|
||||
SFTP_FAILURE, 'SFTP failure')
|
||||
sftp_sensor = SFTPSensor(
|
||||
task_id='unit_test',
|
||||
path='/path/to/file/1970-01-01.txt')
|
||||
context = {
|
||||
'ds': '1970-01-01'
|
||||
}
|
||||
with self.assertRaises(OSError):
|
||||
sftp_sensor.poke(context)
|
||||
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
|
||||
'/path/to/file/1970-01-01.txt')
|
||||
|
||||
def test_hook_not_created_during_init(self):
|
||||
sftp_sensor = SFTPSensor(
|
||||
task_id='unit_test',
|
||||
path='/path/to/file/1970-01-01.txt')
|
||||
self.assertIsNone(sftp_sensor.hook)
|
||||
class TestMovingSFTPHookToCore(TestCase):
|
||||
def test_move_sftp_sensor_to_core(self):
|
||||
with self.assertWarns(DeprecationWarning) as warn:
|
||||
# Reload to see deprecation warning each time
|
||||
importlib.import_module(OLD_PATH)
|
||||
self.assertEqual(WARNING_MESSAGE, str(warn.warning))
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,247 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import pysftp
|
||||
from parameterized import parameterized
|
||||
|
||||
from airflow.models import Connection
|
||||
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||
from airflow.utils.db import provide_session
|
||||
|
||||
TMP_PATH = '/tmp'
|
||||
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
|
||||
SUB_DIR = "sub_dir"
|
||||
TMP_FILE_FOR_TESTS = 'test_file.txt'
|
||||
|
||||
SFTP_CONNECTION_USER = "root"
|
||||
|
||||
|
||||
class TestSFTPHook(unittest.TestCase):
|
||||
|
||||
@provide_session
|
||||
def update_connection(self, login, session=None):
|
||||
connection = (session.query(Connection).
|
||||
filter(Connection.conn_id == "sftp_default")
|
||||
.first())
|
||||
old_login = connection.login
|
||||
connection.login = login
|
||||
session.commit()
|
||||
return old_login
|
||||
|
||||
def setUp(self):
|
||||
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
|
||||
self.hook = SFTPHook()
|
||||
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
|
||||
|
||||
with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
|
||||
file.write('Test file')
|
||||
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
|
||||
file.write('Test file')
|
||||
|
||||
def test_get_conn(self):
|
||||
output = self.hook.get_conn()
|
||||
self.assertEqual(type(output), pysftp.Connection)
|
||||
|
||||
def test_close_conn(self):
|
||||
self.hook.conn = self.hook.get_conn()
|
||||
self.assertTrue(self.hook.conn is not None)
|
||||
self.hook.close_conn()
|
||||
self.assertTrue(self.hook.conn is None)
|
||||
|
||||
def test_describe_directory(self):
|
||||
output = self.hook.describe_directory(TMP_PATH)
|
||||
self.assertTrue(TMP_DIR_FOR_TESTS in output)
|
||||
|
||||
def test_list_directory(self):
|
||||
output = self.hook.list_directory(
|
||||
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertEqual(output, [SUB_DIR])
|
||||
|
||||
def test_create_and_delete_directory(self):
|
||||
new_dir_name = 'new_dir'
|
||||
self.hook.create_directory(os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
|
||||
output = self.hook.describe_directory(
|
||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertTrue(new_dir_name in output)
|
||||
self.hook.delete_directory(os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
|
||||
output = self.hook.describe_directory(
|
||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertTrue(new_dir_name not in output)
|
||||
|
||||
def test_create_and_delete_directories(self):
|
||||
base_dir = "base_dir"
|
||||
sub_dir = "sub_dir"
|
||||
new_dir_path = os.path.join(base_dir, sub_dir)
|
||||
self.hook.create_directory(os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
|
||||
output = self.hook.describe_directory(
|
||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertTrue(base_dir in output)
|
||||
output = self.hook.describe_directory(
|
||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
|
||||
self.assertTrue(sub_dir in output)
|
||||
self.hook.delete_directory(os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
|
||||
self.hook.delete_directory(os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
|
||||
output = self.hook.describe_directory(
|
||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertTrue(new_dir_path not in output)
|
||||
self.assertTrue(base_dir not in output)
|
||||
|
||||
def test_store_retrieve_and_delete_file(self):
|
||||
self.hook.store_file(
|
||||
remote_full_path=os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
||||
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
|
||||
)
|
||||
output = self.hook.list_directory(
|
||||
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
|
||||
retrieved_file_name = 'retrieved.txt'
|
||||
self.hook.retrieve_file(
|
||||
remote_full_path=os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
||||
local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
|
||||
)
|
||||
self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
|
||||
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
|
||||
self.hook.delete_file(path=os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
|
||||
output = self.hook.list_directory(
|
||||
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
self.assertEqual(output, [SUB_DIR])
|
||||
|
||||
def test_get_mod_time(self):
|
||||
self.hook.store_file(
|
||||
remote_full_path=os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
||||
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
|
||||
)
|
||||
output = self.hook.get_mod_time(path=os.path.join(
|
||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
|
||||
self.assertEqual(len(output), 14)
|
||||
|
||||
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||
def test_no_host_key_check_default(self, get_connection):
|
||||
connection = Connection(login='login', host='host')
|
||||
get_connection.return_value = connection
|
||||
hook = SFTPHook()
|
||||
self.assertEqual(hook.no_host_key_check, False)
|
||||
|
||||
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||
def test_no_host_key_check_enabled(self, get_connection):
|
||||
connection = Connection(
|
||||
login='login', host='host',
|
||||
extra='{"no_host_key_check": true}')
|
||||
|
||||
get_connection.return_value = connection
|
||||
hook = SFTPHook()
|
||||
self.assertEqual(hook.no_host_key_check, True)
|
||||
|
||||
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||
def test_no_host_key_check_disabled(self, get_connection):
|
||||
connection = Connection(
|
||||
login='login', host='host',
|
||||
extra='{"no_host_key_check": false}')
|
||||
|
||||
get_connection.return_value = connection
|
||||
hook = SFTPHook()
|
||||
self.assertEqual(hook.no_host_key_check, False)
|
||||
|
||||
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||
def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
|
||||
connection = Connection(
|
||||
login='login', host='host',
|
||||
extra='{"no_host_key_check": "foo"}')
|
||||
|
||||
get_connection.return_value = connection
|
||||
hook = SFTPHook()
|
||||
self.assertEqual(hook.no_host_key_check, False)
|
||||
|
||||
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||
def test_no_host_key_check_ignore(self, get_connection):
|
||||
connection = Connection(
|
||||
login='login', host='host',
|
||||
extra='{"ignore_hostkey_verification": true}')
|
||||
|
||||
get_connection.return_value = connection
|
||||
hook = SFTPHook()
|
||||
self.assertEqual(hook.no_host_key_check, True)
|
||||
|
||||
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||
def test_no_host_key_check_no_ignore(self, get_connection):
|
||||
connection = Connection(
|
||||
login='login', host='host',
|
||||
extra='{"ignore_hostkey_verification": false}')
|
||||
|
||||
get_connection.return_value = connection
|
||||
hook = SFTPHook()
|
||||
self.assertEqual(hook.no_host_key_check, False)
|
||||
|
||||
@parameterized.expand([
|
||||
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
|
||||
(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
|
||||
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
|
||||
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
|
||||
])
|
||||
def test_path_exists(self, path, exists):
|
||||
result = self.hook.path_exists(path)
|
||||
self.assertEqual(result, exists)
|
||||
|
||||
@parameterized.expand([
|
||||
("test/path/file.bin", None, None, True),
|
||||
("test/path/file.bin", "test", None, True),
|
||||
("test/path/file.bin", "test/", None, True),
|
||||
("test/path/file.bin", None, "bin", True),
|
||||
("test/path/file.bin", "test", "bin", True),
|
||||
("test/path/file.bin", "test/", "file.bin", True),
|
||||
("test/path/file.bin", None, "file.bin", True),
|
||||
("test/path/file.bin", "diff", None, False),
|
||||
("test/path/file.bin", "test//", None, False),
|
||||
("test/path/file.bin", None, ".txt", False),
|
||||
("test/path/file.bin", "diff", ".txt", False),
|
||||
])
|
||||
def test_path_match(self, path, prefix, delimiter, match):
|
||||
result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
|
||||
self.assertEqual(result, match)
|
||||
|
||||
def test_get_tree_map(self):
|
||||
tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
files, dirs, unknowns = tree_map
|
||||
|
||||
self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
|
||||
self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
|
||||
self.assertEqual(unknowns, [])
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
|
||||
self.update_connection(self.old_login)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,452 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from base64 import b64encode
|
||||
from unittest import mock
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.operators.ssh_operator import SSHOperator
|
||||
from airflow.models import DAG, TaskInstance
|
||||
from airflow.providers.sftp.operators.sftp_operator import SFTPOperation, SFTPOperator
|
||||
from airflow.utils import timezone
|
||||
from airflow.utils.timezone import datetime
|
||||
from tests.test_utils.config import conf_vars
|
||||
|
||||
TEST_DAG_ID = 'unit_tests_sftp_op'
|
||||
DEFAULT_DATE = datetime(2017, 1, 1)
|
||||
TEST_CONN_ID = "conn_id_for_testing"
|
||||
|
||||
|
||||
class TestSFTPOperator(unittest.TestCase):
|
||||
def setUp(self):
|
||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||
|
||||
hook = SSHHook(ssh_conn_id='ssh_default')
|
||||
hook.no_host_key_check = True
|
||||
args = {
|
||||
'owner': 'airflow',
|
||||
'start_date': DEFAULT_DATE,
|
||||
}
|
||||
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
|
||||
dag.schedule_interval = '@once'
|
||||
self.hook = hook
|
||||
self.dag = dag
|
||||
self.test_dir = "/tmp"
|
||||
self.test_local_dir = "/tmp/tmp2"
|
||||
self.test_remote_dir = "/tmp/tmp1"
|
||||
self.test_local_filename = 'test_local_file'
|
||||
self.test_remote_filename = 'test_remote_file'
|
||||
self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
|
||||
self.test_local_filename)
|
||||
# Local Filepath with Intermediate Directory
|
||||
self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
|
||||
self.test_local_filename)
|
||||
self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
|
||||
self.test_remote_filename)
|
||||
# Remote Filepath with Intermediate Directory
|
||||
self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
|
||||
self.test_remote_filename)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_pickle_file_transfer_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# put test file to remote
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
create_intermediate_dirs=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# check the remote file content
|
||||
check_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="cat {0}".format(self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(check_file_task)
|
||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
self.assertEqual(
|
||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
||||
test_local_file_content)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_no_intermediate_dir_error_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# Try to put test file to remote
|
||||
# This should raise an error with "No such file" as the directory
|
||||
# does not exist
|
||||
with self.assertRaises(Exception) as error:
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath_int_dir,
|
||||
operation=SFTPOperation.PUT,
|
||||
create_intermediate_dirs=False,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
self.assertIn('No such file', str(error.exception))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_with_intermediate_dir_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# put test file to remote
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath_int_dir,
|
||||
operation=SFTPOperation.PUT,
|
||||
create_intermediate_dirs=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# check the remote file content
|
||||
check_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="cat {0}".format(self.test_remote_filepath_int_dir),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(check_file_task)
|
||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
self.assertEqual(
|
||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
||||
test_local_file_content)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
|
||||
def test_json_file_transfer_put(self):
|
||||
test_local_file_content = \
|
||||
b"This is local file content \n which is multiline " \
|
||||
b"continuing....with other character\nanother line here \n this is last line"
|
||||
# create a test file locally
|
||||
with open(self.test_local_filepath, 'wb') as file:
|
||||
file.write(test_local_file_content)
|
||||
|
||||
# put test file to remote
|
||||
put_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(put_test_task)
|
||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# check the remote file content
|
||||
check_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="cat {0}".format(self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(check_file_task)
|
||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
self.assertEqual(
|
||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
||||
b64encode(test_local_file_content).decode('utf-8'))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_pickle_file_transfer_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# get remote file to local
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# test the received content
|
||||
content_received = None
|
||||
with open(self.test_local_filepath, 'r') as file:
|
||||
content_received = file.read()
|
||||
self.assertEqual(content_received.strip(), test_remote_file_content)
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
|
||||
def test_json_file_transfer_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# get remote file to local
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# test the received content
|
||||
content_received = None
|
||||
with open(self.test_local_filepath, 'r') as file:
|
||||
content_received = file.read()
|
||||
self.assertEqual(content_received.strip(),
|
||||
test_remote_file_content.encode('utf-8').decode('utf-8'))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_no_intermediate_dir_error_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# Try to GET test file from remote
|
||||
# This should raise an error with "No such file" as the directory
|
||||
# does not exist
|
||||
with self.assertRaises(Exception) as error:
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath_int_dir,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
self.assertIn('No such file', str(error.exception))
|
||||
|
||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
||||
def test_file_transfer_with_intermediate_dir_error_get(self):
|
||||
test_remote_file_content = \
|
||||
"This is remote file content \n which is also multiline " \
|
||||
"another line here \n this is last line. EOF"
|
||||
|
||||
# create a test file remotely
|
||||
create_file_task = SSHOperator(
|
||||
task_id="test_create_file",
|
||||
ssh_hook=self.hook,
|
||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
||||
self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(create_file_task)
|
||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
||||
ti1.run()
|
||||
|
||||
# get remote file to local
|
||||
get_test_task = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
local_filepath=self.test_local_filepath_int_dir,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.GET,
|
||||
create_intermediate_dirs=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(get_test_task)
|
||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
||||
ti2.run()
|
||||
|
||||
# test the received content
|
||||
content_received = None
|
||||
with open(self.test_local_filepath_int_dir, 'r') as file:
|
||||
content_received = file.read()
|
||||
self.assertEqual(content_received.strip(), test_remote_file_content)
|
||||
|
||||
@mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
|
||||
def test_arg_checking(self):
|
||||
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
|
||||
with self.assertRaisesRegex(AirflowException,
|
||||
"Cannot operate without ssh_hook or ssh_conn_id."):
|
||||
task_0 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
task_0.execute(None)
|
||||
|
||||
# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
|
||||
task_1 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
|
||||
ssh_conn_id=TEST_CONN_ID,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
try:
|
||||
task_1.execute(None)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
|
||||
|
||||
task_2 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
try:
|
||||
task_2.execute(None)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
|
||||
|
||||
# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
|
||||
task_3 = SFTPOperator(
|
||||
task_id="test_sftp",
|
||||
ssh_hook=self.hook,
|
||||
ssh_conn_id=TEST_CONN_ID,
|
||||
local_filepath=self.test_local_filepath,
|
||||
remote_filepath=self.test_remote_filepath,
|
||||
operation=SFTPOperation.PUT,
|
||||
dag=self.dag
|
||||
)
|
||||
try:
|
||||
task_3.execute(None)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
|
||||
|
||||
def delete_local_resource(self):
|
||||
if os.path.exists(self.test_local_filepath):
|
||||
os.remove(self.test_local_filepath)
|
||||
if os.path.exists(self.test_local_filepath_int_dir):
|
||||
os.remove(self.test_local_filepath_int_dir)
|
||||
if os.path.exists(self.test_local_dir):
|
||||
os.rmdir(self.test_local_dir)
|
||||
|
||||
def delete_remote_resource(self):
|
||||
if os.path.exists(self.test_remote_filepath):
|
||||
# check the remote file content
|
||||
remove_file_task = SSHOperator(
|
||||
task_id="test_check_file",
|
||||
ssh_hook=self.hook,
|
||||
command="rm {0}".format(self.test_remote_filepath),
|
||||
do_xcom_push=True,
|
||||
dag=self.dag
|
||||
)
|
||||
self.assertIsNotNone(remove_file_task)
|
||||
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
|
||||
ti3.run()
|
||||
if os.path.exists(self.test_remote_filepath_int_dir):
|
||||
os.remove(self.test_remote_filepath_int_dir)
|
||||
if os.path.exists(self.test_remote_dir):
|
||||
os.rmdir(self.test_remote_dir)
|
||||
|
||||
def tearDown(self):
|
||||
self.delete_local_resource()
|
||||
self.delete_remote_resource()
|
||||
|
||||
|
||||
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,77 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
|
||||
|
||||
from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor
|
||||
|
||||
|
||||
class TestSFTPSensor(unittest.TestCase):
    """Unit tests for SFTPSensor with SFTPHook patched out (no server needed).

    The original duplicated the sensor and context construction in all four
    tests; it is factored into ``_sensor_and_context`` (DRY). A fresh sensor
    and a fresh context dict are built per call so tests stay independent.
    """

    @staticmethod
    def _sensor_and_context():
        """Return a fresh (sensor, poke-context) pair used by every test."""
        sensor = SFTPSensor(
            task_id='unit_test',
            path='/path/to/file/1970-01-01.txt')
        context = {
            'ds': '1970-01-01'
        }
        return sensor, context

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_file_present(self, sftp_hook_mock):
        # When get_mod_time succeeds, poke() reports the file as present.
        sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
        sftp_sensor, context = self._sensor_and_context()
        output = sftp_sensor.poke(context)
        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
            '/path/to/file/1970-01-01.txt')
        self.assertTrue(output)

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_file_absent(self, sftp_hook_mock):
        # SFTP_NO_SUCH_FILE means "not there yet": poke() returns False.
        sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
            SFTP_NO_SUCH_FILE, 'File missing')
        sftp_sensor, context = self._sensor_and_context()
        output = sftp_sensor.poke(context)
        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
            '/path/to/file/1970-01-01.txt')
        self.assertFalse(output)

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_sftp_failure(self, sftp_hook_mock):
        # Any other SFTP error must propagate instead of reading as "absent".
        sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
            SFTP_FAILURE, 'SFTP failure')
        sftp_sensor, context = self._sensor_and_context()
        with self.assertRaises(OSError):
            sftp_sensor.poke(context)
        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
            '/path/to/file/1970-01-01.txt')

    def test_hook_not_created_during_init(self):
        # Hook creation is deferred until poke() so __init__ stays cheap.
        sftp_sensor, _ = self._sensor_and_context()
        self.assertIsNone(sftp_sensor.hook)
|
Загрузка…
Ссылка в новой задаче