[AIRFLOW-5807] Move SFTP from contrib to providers. (#6464)

* [AIRFLOW-5807] Move SFTP from contrib to core
This commit is contained in:
TobKed 2019-12-09 17:42:19 +01:00 коммит произвёл Jarek Potiuk
Родитель 25830a01a9
Коммит 69629a5a94
26 изменённых файлов: 1521 добавлений и 1180 удалений

Просмотреть файл

@ -16,274 +16,17 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""
This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.
"""
import datetime import warnings
import stat
from typing import Dict, List, Optional, Tuple
import pysftp # pylint: disable=unused-import
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook # noqa
from airflow.contrib.hooks.ssh_hook import SSHHook warnings.warn(
"This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.",
DeprecationWarning,
class SFTPHook(SSHHook): stacklevel=2,
""" )
This hook is inherited from SSH hook. Please refer to SSH hook for the input
arguments.
Interact with SFTP. Aims to be interchangeable with FTPHook.
:Pitfalls::
- In contrast with FTPHook describe_directory only returns size, type and
modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
unique.
- retrieve_file and store_file only take a local full path and not a
buffer.
- If no mode is passed to create_directory it will be created with 777
permissions.
Errors that may occur throughout but should be handled downstream.
"""
def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
kwargs['ssh_conn_id'] = ftp_conn_id
super().__init__(*args, **kwargs)
self.conn = None
self.private_key_pass = None
# Fail for unverified hosts, unless this is explicitly allowed
self.no_host_key_check = False
if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
if conn.extra is not None:
extra_options = conn.extra_dejson
if 'private_key_pass' in extra_options:
self.private_key_pass = extra_options.get('private_key_pass', None)
# For backward compatibility
# TODO: remove in Airflow 2.1
import warnings
if 'ignore_hostkey_verification' in extra_options:
warnings.warn(
'Extra option `ignore_hostkey_verification` is deprecated.'
'Please use `no_host_key_check` instead.'
'This option will be removed in Airflow 2.1',
DeprecationWarning,
stacklevel=2,
)
self.no_host_key_check = str(
extra_options['ignore_hostkey_verification']
).lower() == 'true'
if 'no_host_key_check' in extra_options:
self.no_host_key_check = str(
extra_options['no_host_key_check']).lower() == 'true'
if 'private_key' in extra_options:
warnings.warn(
'Extra option `private_key` is deprecated.'
'Please use `key_file` instead.'
'This option will be removed in Airflow 2.1',
DeprecationWarning,
stacklevel=2,
)
self.key_file = extra_options.get('private_key')
def get_conn(self) -> pysftp.Connection:
"""
Returns an SFTP connection object
"""
if self.conn is None:
cnopts = pysftp.CnOpts()
if self.no_host_key_check:
cnopts.hostkeys = None
cnopts.compression = self.compress
conn_params = {
'host': self.remote_host,
'port': self.port,
'username': self.username,
'cnopts': cnopts
}
if self.password and self.password.strip():
conn_params['password'] = self.password
if self.key_file:
conn_params['private_key'] = self.key_file
if self.private_key_pass:
conn_params['private_key_pass'] = self.private_key_pass
self.conn = pysftp.Connection(**conn_params)
return self.conn
def close_conn(self) -> None:
"""
Closes the connection. An error will occur if the
connection wasnt ever opened.
"""
conn = self.conn
conn.close() # type: ignore
self.conn = None
def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
"""
Returns a dictionary of {filename: {attributes}} for all files
on the remote system (where the MLSD command is supported).
:param path: full path to the remote directory
:type path: str
"""
conn = self.get_conn()
flist = conn.listdir_attr(path)
files = {}
for f in flist:
modify = datetime.datetime.fromtimestamp(
f.st_mtime).strftime('%Y%m%d%H%M%S')
files[f.filename] = {
'size': f.st_size,
'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
'modify': modify}
return files
def list_directory(self, path: str) -> List[str]:
"""
Returns a list of files on the remote system.
:param path: full path to the remote directory to list
:type path: str
"""
conn = self.get_conn()
files = conn.listdir(path)
return files
def create_directory(self, path: str, mode: int = 777) -> None:
"""
Creates a directory on the remote system.
:param path: full path to the remote directory to create
:type path: str
:param mode: int representation of octal mode for directory
"""
conn = self.get_conn()
conn.makedirs(path, mode)
def delete_directory(self, path: str) -> None:
"""
Deletes a directory on the remote system.
:param path: full path to the remote directory to delete
:type path: str
"""
conn = self.get_conn()
conn.rmdir(path)
def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
"""
Transfers the remote file to a local location.
If local_full_path is a string path, the file will be put
at that location
:param remote_full_path: full path to the remote file
:type remote_full_path: str
:param local_full_path: full path to the local file
:type local_full_path: str
"""
conn = self.get_conn()
self.log.info('Retrieving file from FTP: %s', remote_full_path)
conn.get(remote_full_path, local_full_path)
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)
def store_file(self, remote_full_path: str, local_full_path: str) -> None:
"""
Transfers a local file to the remote location.
If local_full_path_or_buffer is a string path, the file will be read
from that location
:param remote_full_path: full path to the remote file
:type remote_full_path: str
:param local_full_path: full path to the local file
:type local_full_path: str
"""
conn = self.get_conn()
conn.put(local_full_path, remote_full_path)
def delete_file(self, path: str) -> None:
"""
Removes a file on the FTP Server
:param path: full path to the remote file
:type path: str
"""
conn = self.get_conn()
conn.remove(path)
def get_mod_time(self, path: str) -> str:
conn = self.get_conn()
ftp_mdtm = conn.stat(path).st_mtime
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
def path_exists(self, path: str) -> bool:
"""
Returns True if a remote entity exists
:param path: full path to the remote file or directory
:type path: str
"""
conn = self.get_conn()
return conn.exists(path)
@staticmethod
def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
"""
Return True if given path starts with prefix (if set) and ends with delimiter (if set).
:param path: path to be checked
:type path: str
:param prefix: if set path will be checked is starting with prefix
:type prefix: str
:param delimiter: if set path will be checked is ending with suffix
:type delimiter: str
:return: bool
"""
if prefix is not None and not path.startswith(prefix):
return False
if delimiter is not None and not path.endswith(delimiter):
return False
return True
def get_tree_map(
self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
) -> Tuple[List[str], List[str], List[str]]:
"""
Return tuple with recursive lists of files, directories and unknown paths from given path.
It is possible to filter results by giving prefix and/or delimiter parameters.
:param path: path from which tree will be built
:type path: str
:param prefix: if set paths will be added if start with prefix
:type prefix: str
:param delimiter: if set paths will be added if end with delimiter
:type delimiter: str
:return: tuple with list of files, dirs and unknown items
:rtype: Tuple[List[str], List[str], List[str]]
"""
conn = self.get_conn()
files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
def append_matching_path_callback(list_):
return (
lambda item: list_.append(item)
if self._is_path_match(item, prefix, delimiter)
else None
)
conn.walktree(
remotepath=path,
fcallback=append_matching_path_callback(files),
dcallback=append_matching_path_callback(dirs),
ucallback=append_matching_path_callback(unknowns),
recurse=True,
)
return files, dirs, unknowns

Просмотреть файл

@ -16,165 +16,17 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os """
This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.
"""
from airflow.contrib.hooks.ssh_hook import SSHHook import warnings
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
# pylint: disable=unused-import
from airflow.providers.sftp.operators.sftp_operator import SFTPOperator # noqa
class SFTPOperation: warnings.warn(
PUT = 'put' "This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.",
GET = 'get' DeprecationWarning,
stacklevel=2,
)
class SFTPOperator(BaseOperator):
    """
    SFTPOperator for transferring files from remote host to local or vice versa.
    This operator uses ssh_hook to open the sftp transport channel that serves
    as the basis for file transfer.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param local_filepath: local file path to get or put. (templated)
    :type local_filepath: str
    :param remote_filepath: remote file path to get or put. (templated)
    :type remote_filepath: str
    :param operation: specify operation 'get' or 'put', defaults to put
    :type operation: str
    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
    :type confirm: bool
    :param create_intermediate_dirs: create missing intermediate directories when
        copying from remote to local and vice-versa. Default is False.

        Example: The following task would copy ``file.txt`` to the remote host
        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
        don't exist. If the parameter is not passed it would error as the directory
        does not exist. ::

            put_file = SFTPOperator(
                task_id="test_sftp",
                ssh_conn_id="ssh_default",
                local_filepath="/tmp/file.txt",
                remote_filepath="/tmp/tmp1/tmp2/file.txt",
                operation="put",
                create_intermediate_dirs=True,
                dag=dag
            )

    :type create_intermediate_dirs: bool
    :raises TypeError: at construction time, if ``operation`` is neither
        'get' nor 'put' (compared case-insensitively)
    """
    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')

    @apply_defaults
    def __init__(self,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 local_filepath=None,
                 remote_filepath=None,
                 operation=SFTPOperation.PUT,
                 confirm=True,
                 create_intermediate_dirs=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        # Reject anything but 'get'/'put' early, at DAG-definition time.
        if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT):
            raise TypeError("unsupported operation value {0}, expected {1} or {2}"
                            .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))

    def execute(self, context):
        """
        Run the configured transfer and return ``local_filepath``.

        :param context: task execution context (not used by this method)
        :return: the value of ``local_filepath``
        :raises AirflowException: if no usable hook can be obtained, or if any
            error occurs while connecting or transferring; the original error
            is attached as the exception cause
        """
        # Stays None until the transfer direction is known; it is interpolated
        # as-is into the error message if the failure happens before that.
        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info("remote_host is provided explicitly. "
                              "It will replace the remote_host which was defined "
                              "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as ssh_client:
                sftp_client = ssh_client.open_sftp()
                if self.operation.lower() == SFTPOperation.GET:
                    local_folder = os.path.dirname(self.local_filepath)
                    # A bare file name has no directory part; os.makedirs('')
                    # would raise, so only create directories when there is
                    # something to create. exist_ok tolerates an existing
                    # directory but still fails if the path is a regular file.
                    if self.create_intermediate_dirs and local_folder:
                        os.makedirs(local_folder, exist_ok=True)
                    file_msg = "from {0} to {1}".format(self.remote_filepath,
                                                        self.local_filepath)
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp_client.get(self.remote_filepath, self.local_filepath)
                else:
                    remote_folder = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp_client,
                            remote_directory=remote_folder,
                        )
                    file_msg = "from {0} to {1}".format(self.local_filepath,
                                                        self.remote_filepath)
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp_client.put(self.local_filepath,
                                    self.remote_filepath,
                                    confirm=self.confirm)

        except Exception as e:
            # Single failure type for callers; keep the root cause chained.
            raise AirflowException("Error while transferring {0}, error: {1}"
                                   .format(file_msg, str(e))) from e

        return self.local_filepath
def _make_intermediate_dirs(sftp_client, remote_directory):
    """
    Create all the intermediate directories in a remote host

    Walks up from ``remote_directory`` until a directory that can be entered
    is found (or the path is exhausted), then creates and enters each missing
    component on the way back down. The client's working directory is changed
    as a side effect.

    :param sftp_client: A Paramiko SFTP client.
    :param remote_directory: Absolute Path of the directory containing the file
    :return:
    """
    to_create = []  # missing path components, deepest first
    current = remote_directory

    while True:
        if current == '/':
            sftp_client.chdir('/')
            break
        if current == '':
            break
        try:
            sftp_client.chdir(current)
        except OSError:
            # Cannot enter this level: remember its last component and retry
            # one level up.
            parent, leaf = os.path.split(current.rstrip('/'))
            to_create.append(leaf)
            current = parent
            continue
        break

    # Replay the missing components shallowest-first, relative to the
    # directory the client is now in.
    for leaf in reversed(to_create):
        sftp_client.mkdir(leaf)
        sftp_client.chdir(leaf)

Просмотреть файл

@ -16,40 +16,17 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""
This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.
"""
from paramiko import SFTP_NO_SUCH_FILE import warnings
from airflow.contrib.hooks.sftp_hook import SFTPHook # pylint: disable=unused-import
from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor # noqa
from airflow.utils.decorators import apply_defaults
warnings.warn(
class SFTPSensor(BaseSensorOperator): "This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.",
""" DeprecationWarning,
Waits for a file or directory to be present on SFTP. stacklevel=2,
)
:param path: Remote file or directory path
:type path: str
:param sftp_conn_id: The connection to run the sensor against
:type sftp_conn_id: str
"""
template_fields = ('path',)
@apply_defaults
def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
super().__init__(*args, **kwargs)
self.path = path
self.hook = None
self.sftp_conn_id = sftp_conn_id
def poke(self, context):
self.hook = SFTPHook(self.sftp_conn_id)
self.log.info('Poking for %s', self.path)
try:
self.hook.get_mod_time(self.path)
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
raise e
return False
self.hook.close_conn()
return True

Просмотреть файл

@ -24,9 +24,9 @@ from tempfile import NamedTemporaryFile
from typing import Optional from typing import Optional
from airflow import AirflowException from airflow import AirflowException
from airflow.contrib.hooks.sftp_hook import SFTPHook
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
from airflow.models import BaseOperator from airflow.models import BaseOperator
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
from airflow.utils.decorators import apply_defaults from airflow.utils.decorators import apply_defaults
WILDCARD = "*" WILDCARD = "*"

Просмотреть файл

@ -23,9 +23,9 @@ from tempfile import NamedTemporaryFile
from typing import Optional, Union from typing import Optional, Union
from airflow import AirflowException from airflow import AirflowException
from airflow.contrib.hooks.sftp_hook import SFTPHook
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
from airflow.models import BaseOperator from airflow.models import BaseOperator
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
from airflow.utils.decorators import apply_defaults from airflow.utils.decorators import apply_defaults
WILDCARD = "*" WILDCARD = "*"

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains SFTP hook.
"""
import datetime
import stat
from typing import Dict, List, Optional, Tuple
import pysftp
from airflow.contrib.hooks.ssh_hook import SSHHook
class SFTPHook(SSHHook):
    """
    This hook is inherited from SSH hook. Please refer to SSH hook for the input
    arguments.

    Interact with SFTP. Aims to be interchangeable with FTPHook.

    :Pitfalls::

        - In contrast with FTPHook describe_directory only returns size, type and
          modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
          unique.
        - retrieve_file and store_file only take a local full path and not a
          buffer.
        - If no mode is passed to create_directory it will be created with 777
          permissions.

    Errors that may occur throughout but should be handled downstream.

    :param ftp_conn_id: connection id; forwarded to the SSH hook as ``ssh_conn_id``
    :type ftp_conn_id: str
    """

    def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
        kwargs['ssh_conn_id'] = ftp_conn_id
        super().__init__(*args, **kwargs)

        # Lazily created pysftp connection, see get_conn().
        self.conn = None
        self.private_key_pass = None

        # Fail for unverified hosts, unless this is explicitly allowed
        self.no_host_key_check = False

        if self.ssh_conn_id is not None:
            conn = self.get_connection(self.ssh_conn_id)
            if conn.extra is not None:
                # Function-scope import: only needed for the deprecated options.
                import warnings

                extra_options = conn.extra_dejson
                self.private_key_pass = extra_options.get('private_key_pass')

                # For backward compatibility
                # TODO: remove in Airflow 2.1
                if 'ignore_hostkey_verification' in extra_options:
                    warnings.warn(
                        'Extra option `ignore_hostkey_verification` is deprecated. '
                        'Please use `no_host_key_check` instead. '
                        'This option will be removed in Airflow 2.1',
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    self.no_host_key_check = str(
                        extra_options['ignore_hostkey_verification']
                    ).lower() == 'true'

                # Evaluated after the deprecated option so that, when both are
                # present, the current option name wins.
                if 'no_host_key_check' in extra_options:
                    self.no_host_key_check = str(
                        extra_options['no_host_key_check']).lower() == 'true'

                if 'private_key' in extra_options:
                    warnings.warn(
                        'Extra option `private_key` is deprecated. '
                        'Please use `key_file` instead. '
                        'This option will be removed in Airflow 2.1',
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    self.key_file = extra_options.get('private_key')

    def get_conn(self) -> pysftp.Connection:
        """
        Returns an SFTP connection object

        The connection is created on first use and cached on the hook;
        subsequent calls return the same object until close_conn() is called.
        """
        if self.conn is None:
            cnopts = pysftp.CnOpts()
            if self.no_host_key_check:
                # No known-hosts store: host key verification is skipped.
                cnopts.hostkeys = None
            cnopts.compression = self.compress
            conn_params = {
                'host': self.remote_host,
                'port': self.port,
                'username': self.username,
                'cnopts': cnopts
            }
            # Optional credentials are only passed when actually configured.
            if self.password and self.password.strip():
                conn_params['password'] = self.password
            if self.key_file:
                conn_params['private_key'] = self.key_file
            if self.private_key_pass:
                conn_params['private_key_pass'] = self.private_key_pass

            self.conn = pysftp.Connection(**conn_params)
        return self.conn

    def close_conn(self) -> None:
        """
        Closes the connection if one is open.

        Calling this when no connection was ever opened (or after it was
        already closed) does nothing.
        """
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
        """
        Returns a dictionary of {filename: {attributes}} for all files
        in the given directory on the remote system.

        Each attribute dictionary has the keys ``size``, ``type`` (``'dir'`` or
        ``'file'``) and ``modify`` (modification time formatted as
        ``%Y%m%d%H%M%S``).

        :param path: full path to the remote directory
        :type path: str
        """
        conn = self.get_conn()
        return {
            entry.filename: {
                'size': entry.st_size,
                'type': 'dir' if stat.S_ISDIR(entry.st_mode) else 'file',
                'modify': datetime.datetime.fromtimestamp(
                    entry.st_mtime).strftime('%Y%m%d%H%M%S'),
            }
            for entry in conn.listdir_attr(path)
        }

    def list_directory(self, path: str) -> List[str]:
        """
        Returns a list of files on the remote system.

        :param path: full path to the remote directory to list
        :type path: str
        """
        return self.get_conn().listdir(path)

    def create_directory(self, path: str, mode: int = 777) -> None:
        """
        Creates a directory on the remote system.

        :param path: full path to the remote directory to create
        :type path: str
        :param mode: int representation of octal mode for directory
        :type mode: int
        """
        self.get_conn().makedirs(path, mode)

    def delete_directory(self, path: str) -> None:
        """
        Deletes a directory on the remote system.

        :param path: full path to the remote directory to delete
        :type path: str
        """
        self.get_conn().rmdir(path)

    def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
        """
        Transfers the remote file to a local location.
        If local_full_path is a string path, the file will be put
        at that location

        :param remote_full_path: full path to the remote file
        :type remote_full_path: str
        :param local_full_path: full path to the local file
        :type local_full_path: str
        """
        conn = self.get_conn()
        self.log.info('Retrieving file from FTP: %s', remote_full_path)
        conn.get(remote_full_path, local_full_path)
        self.log.info('Finished retrieving file from FTP: %s', remote_full_path)

    def store_file(self, remote_full_path: str, local_full_path: str) -> None:
        """
        Transfers a local file to the remote location.
        The file is read from the local path given in local_full_path.

        :param remote_full_path: full path to the remote file
        :type remote_full_path: str
        :param local_full_path: full path to the local file
        :type local_full_path: str
        """
        self.get_conn().put(local_full_path, remote_full_path)

    def delete_file(self, path: str) -> None:
        """
        Removes a file on the SFTP server

        :param path: full path to the remote file
        :type path: str
        """
        self.get_conn().remove(path)

    def get_mod_time(self, path: str) -> str:
        """
        Returns modification time, formatted as ``%Y%m%d%H%M%S``.

        :param path: full path to the remote file
        :type path: str
        """
        mtime = self.get_conn().stat(path).st_mtime
        return datetime.datetime.fromtimestamp(mtime).strftime('%Y%m%d%H%M%S')

    def path_exists(self, path: str) -> bool:
        """
        Returns True if a remote entity exists

        :param path: full path to the remote file or directory
        :type path: str
        """
        return self.get_conn().exists(path)

    @staticmethod
    def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
        """
        Return True if given path starts with prefix (if set) and ends with delimiter (if set).

        :param path: path to be checked
        :type path: str
        :param prefix: if set, path must start with this prefix
        :type prefix: str
        :param delimiter: if set, path must end with this suffix
        :type delimiter: str
        :return: bool
        """
        if prefix is not None and not path.startswith(prefix):
            return False
        if delimiter is not None and not path.endswith(delimiter):
            return False
        return True

    def get_tree_map(
        self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Return tuple with recursive lists of files, directories and unknown paths from given path.
        It is possible to filter results by giving prefix and/or delimiter parameters.

        :param path: path from which tree will be built
        :type path: str
        :param prefix: if set paths will be added if start with prefix
        :type prefix: str
        :param delimiter: if set paths will be added if end with delimiter
        :type delimiter: str
        :return: tuple with list of files, dirs and unknown items
        :rtype: Tuple[List[str], List[str], List[str]]
        """
        conn = self.get_conn()
        files, dirs, unknowns = [], [], []  # type: List[str], List[str], List[str]

        def append_matching_path_callback(list_):
            # Build a walktree callback that collects matching paths into list_.
            def callback(item):
                if self._is_path_match(item, prefix, delimiter):
                    list_.append(item)

            return callback

        conn.walktree(
            remotepath=path,
            fcallback=append_matching_path_callback(files),
            dcallback=append_matching_path_callback(dirs),
            ucallback=append_matching_path_callback(unknowns),
            recurse=True,
        )

        return files, dirs, unknowns

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains SFTP operator.
"""
import os
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
# pylint: disable=missing-docstring
class SFTPOperation:
    """Names of the transfer directions accepted by ``SFTPOperator``."""

    # Upload: local file -> remote host.
    PUT = 'put'
    # Download: remote host -> local file.
    GET = 'get'
class SFTPOperator(BaseOperator):
"""
SFTPOperator for transferring files from remote host to local or vice a versa.
This operator uses ssh_hook to open sftp transport channel that serve as basis
for file transfer.
:param ssh_hook: predefined ssh_hook to use for remote execution.
Either `ssh_hook` or `ssh_conn_id` needs to be provided.
:type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
:param ssh_conn_id: connection id from airflow Connections.
`ssh_conn_id` will be ignored if `ssh_hook` is provided.
:type ssh_conn_id: str
:param remote_host: remote host to connect (templated)
Nullable. If provided, it will replace the `remote_host` which was
defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
:type remote_host: str
:param local_filepath: local file path to get or put. (templated)
:type local_filepath: str
:param remote_filepath: remote file path to get or put. (templated)
:type remote_filepath: str
:param operation: specify operation 'get' or 'put', defaults to put
:type operation: str
:param confirm: specify if the SFTP operation should be confirmed, defaults to True
:type confirm: bool
:param create_intermediate_dirs: create missing intermediate directories when
copying from remote to local and vice-versa. Default is False.
Example: The following task would copy ``file.txt`` to the remote host
at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
don't exist. If the parameter is not passed it would error as the directory
does not exist. ::
put_file = SFTPOperator(
task_id="test_sftp",
ssh_conn_id="ssh_default",
local_filepath="/tmp/file.txt",
remote_filepath="/tmp/tmp1/tmp2/file.txt",
operation="put",
create_intermediate_dirs=True,
dag=dag
)
:type create_intermediate_dirs: bool
"""
template_fields = ('local_filepath', 'remote_filepath', 'remote_host')
@apply_defaults
def __init__(self,
ssh_hook=None,
ssh_conn_id=None,
remote_host=None,
local_filepath=None,
remote_filepath=None,
operation=SFTPOperation.PUT,
confirm=True,
create_intermediate_dirs=False,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.ssh_hook = ssh_hook
self.ssh_conn_id = ssh_conn_id
self.remote_host = remote_host
self.local_filepath = local_filepath
self.remote_filepath = remote_filepath
self.operation = operation
self.confirm = confirm
self.create_intermediate_dirs = create_intermediate_dirs
if not (self.operation.lower() == SFTPOperation.GET or
self.operation.lower() == SFTPOperation.PUT):
raise TypeError("unsupported operation value {0}, expected {1} or {2}"
.format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))
def execute(self, context):
    """
    Transfer one file over SFTP in the direction given by ``self.operation``.

    Hook resolution: when ``ssh_conn_id`` is set and ``ssh_hook`` is missing or
    is not an ``SSHHook`` instance, a new ``SSHHook`` is built from
    ``ssh_conn_id``; a valid ``ssh_hook`` always wins. An explicitly supplied
    ``remote_host`` overrides the hook's ``remote_host``.

    For a GET the remote file is downloaded to ``local_filepath``; any other
    (already validated) operation uploads ``local_filepath`` to
    ``remote_filepath``. With ``create_intermediate_dirs`` the missing parent
    directories on the receiving side are created first.

    :param context: task context supplied by the caller; not read here.
    :return: ``self.local_filepath``
    :raises AirflowException: if no hook can be resolved or if anything fails
        while connecting or transferring; the original error is chained as
        the cause.
    """
    file_msg = None
    try:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info("ssh_hook is not provided or invalid. "
                              "Trying ssh_conn_id to create SSHHook.")
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info("remote_host is provided explicitly. "
                          "It will replace the remote_host which was defined "
                          "in ssh_hook or predefined in connection of ssh_conn_id.")
            self.ssh_hook.remote_host = self.remote_host

        with self.ssh_hook.get_conn() as ssh_client:
            sftp_client = ssh_client.open_sftp()
            if self.operation.lower() == SFTPOperation.GET:
                local_folder = os.path.dirname(self.local_filepath)
                # An empty dirname means "current directory": there is nothing
                # to create, and os.makedirs('') would raise.
                if self.create_intermediate_dirs and local_folder:
                    # exist_ok tolerates an existing directory but still fails
                    # when the path exists and is not a directory.
                    os.makedirs(local_folder, exist_ok=True)
                file_msg = "from {0} to {1}".format(self.remote_filepath,
                                                    self.local_filepath)
                self.log.info("Starting to transfer %s", file_msg)
                sftp_client.get(self.remote_filepath, self.local_filepath)
            else:
                remote_folder = os.path.dirname(self.remote_filepath)
                if self.create_intermediate_dirs:
                    _make_intermediate_dirs(
                        sftp_client=sftp_client,
                        remote_directory=remote_folder,
                    )
                file_msg = "from {0} to {1}".format(self.local_filepath,
                                                    self.remote_filepath)
                self.log.info("Starting to transfer file %s", file_msg)
                sftp_client.put(self.local_filepath,
                                self.remote_filepath,
                                confirm=self.confirm)

    except Exception as e:
        # Chain explicitly so the root cause stays attached to the wrapper.
        raise AirflowException("Error while transferring {0}, error: {1}"
                               .format(file_msg, str(e))) from e

    return self.local_filepath
def _make_intermediate_dirs(sftp_client, remote_directory):
"""
Create all the intermediate directories in a remote host
:param sftp_client: A Paramiko SFTP client.
:param remote_directory: Absolute Path of the directory containing the file
:return:
"""
if remote_directory == '/':
sftp_client.chdir('/')
return
if remote_directory == '':
return
try:
sftp_client.chdir(remote_directory)
except OSError:
dirname, basename = os.path.split(remote_directory.rstrip('/'))
_make_intermediate_dirs(sftp_client, dirname)
sftp_client.mkdir(basename)
sftp_client.chdir(basename)
return

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains the SFTP sensor.
"""
from paramiko import SFTP_NO_SUCH_FILE
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class SFTPSensor(BaseSensorOperator):
    """
    Waits for a file or directory to be present on SFTP.

    :param path: Remote file or directory path
    :type path: str
    :param sftp_conn_id: The connection to run the sensor against
    :type sftp_conn_id: str
    """
    template_fields = ('path',)

    @apply_defaults
    def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.path = path
        # Created anew on every poke(); kept on the instance for inspection.
        self.hook = None
        self.sftp_conn_id = sftp_conn_id

    def poke(self, context):
        """
        Check once whether ``self.path`` exists on the SFTP server.

        Existence is probed by asking the hook for the path's modification
        time. The hook's connection is closed before returning, whether the
        path was found, was missing, or the probe failed.

        :param context: task context supplied by the caller; not read here.
        :return: True if the path exists, False if the server reported
            ``SFTP_NO_SUCH_FILE``.
        :raises OSError: for any other SFTP error raised by the hook.
        """
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info('Poking for %s', self.path)
        try:
            self.hook.get_mod_time(self.path)
        except OSError as e:
            if e.errno != SFTP_NO_SUCH_FILE:
                # Bare raise keeps the original traceback intact.
                raise
            return False
        finally:
            # Only close a connection that was actually opened, so a failed
            # connect is not masked by an error from close_conn().
            if self.hook.conn is not None:
                self.hook.close_conn()
        return True

Просмотреть файл

@ -76,6 +76,8 @@ All operators are in the following packages:
airflow/providers/amazon/aws/sensors/index airflow/providers/amazon/aws/sensors/index
airflow/providers/apache/cassandra/sensors/index
airflow/providers/google/cloud/operators/index airflow/providers/google/cloud/operators/index
airflow/providers/google/cloud/sensors/index airflow/providers/google/cloud/sensors/index
@ -84,7 +86,9 @@ All operators are in the following packages:
airflow/providers/google/marketing_platform/sensors/index airflow/providers/google/marketing_platform/sensors/index
airflow/providers/apache/cassandra/sensors/index airflow/providers/sftp/operators/index
airflow/providers/sftp/sensors/index
Hooks Hooks
----- -----
@ -117,6 +121,8 @@ All hooks are in the following packages:
airflow/providers/apache/cassandra/hooks/index airflow/providers/apache/cassandra/hooks/index
airflow/providers/sftp/hooks/index
Executors Executors
--------- ---------
Executors are the mechanism by which task instances get run. All executors are Executors are the mechanism by which task instances get run. All executors are

Просмотреть файл

@ -233,6 +233,7 @@ exclude_patterns = [
'_api/airflow/providers/amazon/aws/example_dags', '_api/airflow/providers/amazon/aws/example_dags',
'_api/airflow/providers/apache/index.rst', '_api/airflow/providers/apache/index.rst',
'_api/airflow/providers/apache/cassandra/index.rst', '_api/airflow/providers/apache/cassandra/index.rst',
'_api/airflow/providers/sftp/index.rst',
'_api/enums/index.rst', '_api/enums/index.rst',
'_api/json_schema/index.rst', '_api/json_schema/index.rst',
'_api/base_serialization/index.rst', '_api/base_serialization/index.rst',

Просмотреть файл

@ -1246,9 +1246,9 @@ communication protocols or interface.
* - `SSH File Transfer Protocol (SFTP) <https://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/>`__ * - `SSH File Transfer Protocol (SFTP) <https://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/>`__
- -
- :mod:`airflow.contrib.hooks.sftp_hook` - :mod:`airflow.providers.sftp.hooks.sftp_hook`
- :mod:`airflow.contrib.operators.sftp_operator` - :mod:`airflow.providers.sftp.operators.sftp_operator`
- :mod:`airflow.contrib.sensors.sftp_sensor` - :mod:`airflow.providers.sftp.sensors.sftp_sensor`
* - `Secure Shell (SSH) <https://tools.ietf.org/html/rfc4251>`__ * - `Secure Shell (SSH) <https://tools.ietf.org/html/rfc4251>`__
- -

Просмотреть файл

@ -21,7 +21,6 @@
./airflow/contrib/hooks/qubole_check_hook.py ./airflow/contrib/hooks/qubole_check_hook.py
./airflow/contrib/hooks/sagemaker_hook.py ./airflow/contrib/hooks/sagemaker_hook.py
./airflow/contrib/hooks/segment_hook.py ./airflow/contrib/hooks/segment_hook.py
./airflow/contrib/hooks/sftp_hook.py
./airflow/contrib/hooks/slack_webhook_hook.py ./airflow/contrib/hooks/slack_webhook_hook.py
./airflow/contrib/hooks/snowflake_hook.py ./airflow/contrib/hooks/snowflake_hook.py
./airflow/contrib/hooks/spark_jdbc_hook.py ./airflow/contrib/hooks/spark_jdbc_hook.py
@ -65,7 +64,6 @@
./airflow/contrib/operators/sagemaker_transform_operator.py ./airflow/contrib/operators/sagemaker_transform_operator.py
./airflow/contrib/operators/sagemaker_tuning_operator.py ./airflow/contrib/operators/sagemaker_tuning_operator.py
./airflow/contrib/operators/segment_track_event_operator.py ./airflow/contrib/operators/segment_track_event_operator.py
./airflow/contrib/operators/sftp_operator.py
./airflow/contrib/operators/sftp_to_s3_operator.py ./airflow/contrib/operators/sftp_to_s3_operator.py
./airflow/contrib/operators/slack_webhook_operator.py ./airflow/contrib/operators/slack_webhook_operator.py
./airflow/contrib/operators/snowflake_operator.py ./airflow/contrib/operators/snowflake_operator.py
@ -100,7 +98,6 @@
./airflow/contrib/sensors/sagemaker_training_sensor.py ./airflow/contrib/sensors/sagemaker_training_sensor.py
./airflow/contrib/sensors/sagemaker_transform_sensor.py ./airflow/contrib/sensors/sagemaker_transform_sensor.py
./airflow/contrib/sensors/sagemaker_tuning_sensor.py ./airflow/contrib/sensors/sagemaker_tuning_sensor.py
./airflow/contrib/sensors/sftp_sensor.py
./airflow/contrib/sensors/wasb_sensor.py ./airflow/contrib/sensors/wasb_sensor.py
./airflow/contrib/sensors/weekday_sensor.py ./airflow/contrib/sensors/weekday_sensor.py
./airflow/hooks/dbapi_hook.py ./airflow/hooks/dbapi_hook.py

Просмотреть файл

@ -17,231 +17,17 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os import importlib
import shutil from unittest import TestCase
import unittest
from unittest import mock
import pysftp OLD_PATH = "airflow.contrib.hooks.sftp_hook"
from parameterized import parameterized NEW_PATH = "airflow.providers.sftp.hooks.sftp_hook"
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
from airflow.contrib.hooks.sftp_hook import SFTPHook
from airflow.models import Connection
from airflow.utils.db import provide_session
TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'
SFTP_CONNECTION_USER = "root"
class TestSFTPHook(unittest.TestCase): class TestMovingSFTPHookToCore(TestCase):
def test_move_sftp_hook_to_core(self):
@provide_session with self.assertWarns(DeprecationWarning) as warn:
def update_connection(self, login, session=None): # Reload to see deprecation warning each time
connection = (session.query(Connection). importlib.import_module(OLD_PATH)
filter(Connection.conn_id == "sftp_default") self.assertEqual(WARNING_MESSAGE, str(warn.warning))
.first())
old_login = connection.login
connection.login = login
session.commit()
return old_login
def setUp(self):
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
self.hook = SFTPHook()
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
def test_get_conn(self):
output = self.hook.get_conn()
self.assertEqual(type(output), pysftp.Connection)
def test_close_conn(self):
self.hook.conn = self.hook.get_conn()
self.assertTrue(self.hook.conn is not None)
self.hook.close_conn()
self.assertTrue(self.hook.conn is None)
def test_describe_directory(self):
output = self.hook.describe_directory(TMP_PATH)
self.assertTrue(TMP_DIR_FOR_TESTS in output)
def test_list_directory(self):
output = self.hook.list_directory(
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertEqual(output, [SUB_DIR])
def test_create_and_delete_directory(self):
new_dir_name = 'new_dir'
self.hook.create_directory(os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
output = self.hook.describe_directory(
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertTrue(new_dir_name in output)
self.hook.delete_directory(os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
output = self.hook.describe_directory(
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertTrue(new_dir_name not in output)
def test_create_and_delete_directories(self):
base_dir = "base_dir"
sub_dir = "sub_dir"
new_dir_path = os.path.join(base_dir, sub_dir)
self.hook.create_directory(os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
output = self.hook.describe_directory(
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertTrue(base_dir in output)
output = self.hook.describe_directory(
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
self.assertTrue(sub_dir in output)
self.hook.delete_directory(os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
self.hook.delete_directory(os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
output = self.hook.describe_directory(
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertTrue(new_dir_path not in output)
self.assertTrue(base_dir not in output)
def test_store_retrieve_and_delete_file(self):
self.hook.store_file(
remote_full_path=os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
)
output = self.hook.list_directory(
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
retrieved_file_name = 'retrieved.txt'
self.hook.retrieve_file(
remote_full_path=os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
)
self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
self.hook.delete_file(path=os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
output = self.hook.list_directory(
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
self.assertEqual(output, [SUB_DIR])
def test_get_mod_time(self):
self.hook.store_file(
remote_full_path=os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
)
output = self.hook.get_mod_time(path=os.path.join(
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
self.assertEqual(len(output), 14)
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_default(self, get_connection):
connection = Connection(login='login', host='host')
get_connection.return_value = connection
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, False)
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_enabled(self, get_connection):
connection = Connection(
login='login', host='host',
extra='{"no_host_key_check": true}')
get_connection.return_value = connection
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, True)
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_disabled(self, get_connection):
connection = Connection(
login='login', host='host',
extra='{"no_host_key_check": false}')
get_connection.return_value = connection
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, False)
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
connection = Connection(
login='login', host='host',
extra='{"no_host_key_check": "foo"}')
get_connection.return_value = connection
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, False)
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_ignore(self, get_connection):
connection = Connection(
login='login', host='host',
extra='{"ignore_hostkey_verification": true}')
get_connection.return_value = connection
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, True)
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_no_ignore(self, get_connection):
connection = Connection(
login='login', host='host',
extra='{"ignore_hostkey_verification": false}')
get_connection.return_value = connection
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, False)
@parameterized.expand([
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
])
def test_path_exists(self, path, exists):
result = self.hook.path_exists(path)
self.assertEqual(result, exists)
@parameterized.expand([
("test/path/file.bin", None, None, True),
("test/path/file.bin", "test", None, True),
("test/path/file.bin", "test/", None, True),
("test/path/file.bin", None, "bin", True),
("test/path/file.bin", "test", "bin", True),
("test/path/file.bin", "test/", "file.bin", True),
("test/path/file.bin", None, "file.bin", True),
("test/path/file.bin", "diff", None, False),
("test/path/file.bin", "test//", None, False),
("test/path/file.bin", None, ".txt", False),
("test/path/file.bin", "diff", ".txt", False),
])
def test_path_match(self, path, prefix, delimiter, match):
result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
self.assertEqual(result, match)
def test_get_tree_map(self):
tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
files, dirs, unknowns = tree_map
self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
self.assertEqual(unknowns, [])
def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
self.update_connection(self.old_login)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -17,436 +17,17 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os import importlib
import unittest from unittest import TestCase
from base64 import b64encode
from unittest import mock
from airflow import AirflowException OLD_PATH = "airflow.contrib.operators.sftp_operator"
from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator NEW_PATH = "airflow.providers.sftp.operators.sftp_operator"
from airflow.contrib.operators.ssh_operator import SSHOperator WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
from airflow.models import DAG, TaskInstance
from airflow.utils import timezone
from airflow.utils.timezone import datetime
from tests.test_utils.config import conf_vars
TEST_DAG_ID = 'unit_tests_sftp_op'
DEFAULT_DATE = datetime(2017, 1, 1)
TEST_CONN_ID = "conn_id_for_testing"
class TestSFTPOperator(unittest.TestCase): class TestMovingSFTPHookToCore(TestCase):
def setUp(self): def test_move_sftp_operator_to_core(self):
from airflow.contrib.hooks.ssh_hook import SSHHook with self.assertWarns(DeprecationWarning) as warn:
# Reload to see deprecation warning each time
hook = SSHHook(ssh_conn_id='ssh_default') importlib.import_module(OLD_PATH)
hook.no_host_key_check = True self.assertEqual(WARNING_MESSAGE, str(warn.warning))
args = {
'owner': 'airflow',
'start_date': DEFAULT_DATE,
}
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
dag.schedule_interval = '@once'
self.hook = hook
self.dag = dag
self.test_dir = "/tmp"
self.test_local_dir = "/tmp/tmp2"
self.test_remote_dir = "/tmp/tmp1"
self.test_local_filename = 'test_local_file'
self.test_remote_filename = 'test_remote_file'
self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
self.test_local_filename)
# Local Filepath with Intermediate Directory
self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
self.test_local_filename)
self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
self.test_remote_filename)
# Remote Filepath with Intermediate Directory
self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
self.test_remote_filename)
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_pickle_file_transfer_put(self):
test_local_file_content = \
b"This is local file content \n which is multiline " \
b"continuing....with other character\nanother line here \n this is last line"
# create a test file locally
with open(self.test_local_filepath, 'wb') as file:
file.write(test_local_file_content)
# put test file to remote
put_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
create_intermediate_dirs=True,
dag=self.dag
)
self.assertIsNotNone(put_test_task)
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
# check the remote file content
check_file_task = SSHOperator(
task_id="test_check_file",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(check_file_task)
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
test_local_file_content)
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_no_intermediate_dir_error_put(self):
test_local_file_content = \
b"This is local file content \n which is multiline " \
b"continuing....with other character\nanother line here \n this is last line"
# create a test file locally
with open(self.test_local_filepath, 'wb') as file:
file.write(test_local_file_content)
# Try to put test file to remote
# This should raise an error with "No such file" as the directory
# does not exist
with self.assertRaises(Exception) as error:
put_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath_int_dir,
operation=SFTPOperation.PUT,
create_intermediate_dirs=False,
dag=self.dag
)
self.assertIsNotNone(put_test_task)
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
self.assertIn('No such file', str(error.exception))
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_with_intermediate_dir_put(self):
test_local_file_content = \
b"This is local file content \n which is multiline " \
b"continuing....with other character\nanother line here \n this is last line"
# create a test file locally
with open(self.test_local_filepath, 'wb') as file:
file.write(test_local_file_content)
# put test file to remote
put_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath_int_dir,
operation=SFTPOperation.PUT,
create_intermediate_dirs=True,
dag=self.dag
)
self.assertIsNotNone(put_test_task)
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
# check the remote file content
check_file_task = SSHOperator(
task_id="test_check_file",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath_int_dir),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(check_file_task)
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
test_local_file_content)
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
def test_json_file_transfer_put(self):
test_local_file_content = \
b"This is local file content \n which is multiline " \
b"continuing....with other character\nanother line here \n this is last line"
# create a test file locally
with open(self.test_local_filepath, 'wb') as file:
file.write(test_local_file_content)
# put test file to remote
put_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
self.assertIsNotNone(put_test_task)
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
# check the remote file content
check_file_task = SSHOperator(
task_id="test_check_file",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(check_file_task)
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
b64encode(test_local_file_content).decode('utf-8'))
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_pickle_file_transfer_get(self):
test_remote_file_content = \
"This is remote file content \n which is also multiline " \
"another line here \n this is last line. EOF"
# create a test file remotely
create_file_task = SSHOperator(
task_id="test_create_file",
ssh_hook=self.hook,
command="echo '{0}' > {1}".format(test_remote_file_content,
self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(create_file_task)
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
# get remote file to local
get_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
dag=self.dag
)
self.assertIsNotNone(get_test_task)
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
# test the received content
content_received = None
with open(self.test_local_filepath, 'r') as file:
content_received = file.read()
self.assertEqual(content_received.strip(), test_remote_file_content)
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
def test_json_file_transfer_get(self):
test_remote_file_content = \
"This is remote file content \n which is also multiline " \
"another line here \n this is last line. EOF"
# create a test file remotely
create_file_task = SSHOperator(
task_id="test_create_file",
ssh_hook=self.hook,
command="echo '{0}' > {1}".format(test_remote_file_content,
self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(create_file_task)
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
# get remote file to local
get_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
dag=self.dag
)
self.assertIsNotNone(get_test_task)
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
# test the received content
content_received = None
with open(self.test_local_filepath, 'r') as file:
content_received = file.read()
self.assertEqual(content_received.strip(),
test_remote_file_content.encode('utf-8').decode('utf-8'))
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_no_intermediate_dir_error_get(self):
test_remote_file_content = \
"This is remote file content \n which is also multiline " \
"another line here \n this is last line. EOF"
# create a test file remotely
create_file_task = SSHOperator(
task_id="test_create_file",
ssh_hook=self.hook,
command="echo '{0}' > {1}".format(test_remote_file_content,
self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(create_file_task)
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
# Try to GET test file from remote
# This should raise an error with "No such file" as the directory
# does not exist
with self.assertRaises(Exception) as error:
get_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath_int_dir,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
dag=self.dag
)
self.assertIsNotNone(get_test_task)
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
self.assertIn('No such file', str(error.exception))
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_with_intermediate_dir_error_get(self):
test_remote_file_content = \
"This is remote file content \n which is also multiline " \
"another line here \n this is last line. EOF"
# create a test file remotely
create_file_task = SSHOperator(
task_id="test_create_file",
ssh_hook=self.hook,
command="echo '{0}' > {1}".format(test_remote_file_content,
self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(create_file_task)
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
# get remote file to local
get_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath_int_dir,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
create_intermediate_dirs=True,
dag=self.dag
)
self.assertIsNotNone(get_test_task)
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
# test the received content
content_received = None
with open(self.test_local_filepath_int_dir, 'r') as file:
content_received = file.read()
self.assertEqual(content_received.strip(), test_remote_file_content)
@mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
def test_arg_checking(self):
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
with self.assertRaisesRegex(AirflowException,
"Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SFTPOperator(
task_id="test_sftp",
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
task_0.execute(None)
# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
task_1 = SFTPOperator(
task_id="test_sftp",
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
try:
task_1.execute(None)
except Exception:
pass
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
task_2 = SFTPOperator(
task_id="test_sftp",
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
try:
task_2.execute(None)
except Exception:
pass
self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
dag=self.dag
)
try:
task_3.execute(None)
except Exception:
pass
self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
def delete_local_resource(self):
if os.path.exists(self.test_local_filepath):
os.remove(self.test_local_filepath)
if os.path.exists(self.test_local_filepath_int_dir):
os.remove(self.test_local_filepath_int_dir)
if os.path.exists(self.test_local_dir):
os.rmdir(self.test_local_dir)
def delete_remote_resource(self):
if os.path.exists(self.test_remote_filepath):
# check the remote file content
remove_file_task = SSHOperator(
task_id="test_check_file",
ssh_hook=self.hook,
command="rm {0}".format(self.test_remote_filepath),
do_xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(remove_file_task)
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
ti3.run()
if os.path.exists(self.test_remote_filepath_int_dir):
os.remove(self.test_remote_filepath_int_dir)
if os.path.exists(self.test_remote_dir):
os.rmdir(self.test_remote_dir)
def tearDown(self):
self.delete_local_resource()
self.delete_remote_resource()
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -17,61 +17,17 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import unittest import importlib
from unittest.mock import patch from unittest import TestCase
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE OLD_PATH = "airflow.contrib.sensors.sftp_sensor"
NEW_PATH = "airflow.providers.sftp.sensors.sftp_sensor"
from airflow.contrib.sensors.sftp_sensor import SFTPSensor WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
class TestSFTPSensor(unittest.TestCase): class TestMovingSFTPHookToCore(TestCase):
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook') def test_move_sftp_sensor_to_core(self):
def test_file_present(self, sftp_hook_mock): with self.assertWarns(DeprecationWarning) as warn:
sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' # Reload to see deprecation warning each time
sftp_sensor = SFTPSensor( importlib.import_module(OLD_PATH)
task_id='unit_test', self.assertEqual(WARNING_MESSAGE, str(warn.warning))
path='/path/to/file/1970-01-01.txt')
context = {
'ds': '1970-01-01'
}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
'/path/to/file/1970-01-01.txt')
self.assertTrue(output)
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
def test_file_absent(self, sftp_hook_mock):
    """poke() is falsy when the hook raises OSError with SFTP_NO_SUCH_FILE."""
    watched_path = '/path/to/file/1970-01-01.txt'
    get_mod_time = sftp_hook_mock.return_value.get_mod_time
    get_mod_time.side_effect = OSError(SFTP_NO_SUCH_FILE, 'File missing')

    sensor = SFTPSensor(task_id='unit_test', path=watched_path)
    poke_result = sensor.poke({'ds': '1970-01-01'})

    get_mod_time.assert_called_once_with(watched_path)
    self.assertFalse(poke_result)
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
def test_sftp_failure(self, sftp_hook_mock):
    """poke() lets an OSError carrying SFTP_FAILURE propagate to the caller."""
    watched_path = '/path/to/file/1970-01-01.txt'
    get_mod_time = sftp_hook_mock.return_value.get_mod_time
    get_mod_time.side_effect = OSError(SFTP_FAILURE, 'SFTP failure')

    sensor = SFTPSensor(task_id='unit_test', path=watched_path)
    with self.assertRaises(OSError):
        sensor.poke({'ds': '1970-01-01'})
    # Checked after the context manager exits so the assertion actually runs.
    get_mod_time.assert_called_once_with(watched_path)
def test_hook_not_created_during_init(self):
    """Constructing the sensor alone leaves its ``hook`` attribute as None."""
    sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
    self.assertIsNone(sensor.hook)

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,247 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import shutil
import unittest
from unittest import mock
import pysftp
from parameterized import parameterized
from airflow.models import Connection
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
from airflow.utils.db import provide_session
TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'
SFTP_CONNECTION_USER = "root"
class TestSFTPHook(unittest.TestCase):
    """Tests for ``SFTPHook`` that use the ``sftp_default`` connection.

    ``setUp`` switches the connection login to ``SFTP_CONNECTION_USER`` and
    creates fixture files below ``TMP_PATH``; both are undone after each test.
    """

    @provide_session
    def update_connection(self, login, session=None):
        """Set the login of the ``sftp_default`` connection.

        :param login: login to store on the connection
        :param session: database session; expected to be injected by
            ``provide_session`` when omitted
        :return: the login the connection had before the update
        """
        connection = (session.query(Connection).
                      filter(Connection.conn_id == "sftp_default")
                      .first())
        old_login = connection.login
        connection.login = login
        session.commit()
        return old_login

    def setUp(self):
        """Point ``sftp_default`` at the test user and create the fixture tree."""
        self.old_login = self.update_connection(SFTP_CONNECTION_USER)
        # Registered as a cleanup so the original login is restored even when
        # the rest of setUp, the test itself, or tearDown raises.
        self.addCleanup(self.update_connection, self.old_login)
        self.hook = SFTPHook()
        os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

        with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
            file.write('Test file')
        with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
            file.write('Test file')

    def test_get_conn(self):
        """get_conn() returns a pysftp connection object."""
        output = self.hook.get_conn()
        self.assertIsInstance(output, pysftp.Connection)

    def test_close_conn(self):
        """close_conn() resets the hook's cached connection to None."""
        self.hook.conn = self.hook.get_conn()
        self.assertIsNotNone(self.hook.conn)
        self.hook.close_conn()
        self.assertIsNone(self.hook.conn)

    def test_describe_directory(self):
        """describe_directory() on TMP_PATH includes the fixture directory."""
        output = self.hook.describe_directory(TMP_PATH)
        self.assertIn(TMP_DIR_FOR_TESTS, output)

    def test_list_directory(self):
        """list_directory() on the fixture directory returns only the sub-directory."""
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR])

    def test_create_and_delete_directory(self):
        """A created directory appears in the listing and disappears once deleted."""
        new_dir_name = 'new_dir'
        self.hook.create_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertIn(new_dir_name, output)
        self.hook.delete_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertNotIn(new_dir_name, output)

    def test_create_and_delete_directories(self):
        """create_directory() with a nested path creates both levels."""
        base_dir = "base_dir"
        sub_dir = "sub_dir"
        new_dir_path = os.path.join(base_dir, sub_dir)
        self.hook.create_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertIn(base_dir, output)
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
        self.assertIn(sub_dir, output)
        # Delete innermost first, then its parent.
        self.hook.delete_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
        self.hook.delete_directory(os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertNotIn(new_dir_path, output)
        self.assertNotIn(base_dir, output)

    def test_store_retrieve_and_delete_file(self):
        """A stored file can be listed, retrieved to a local path, and deleted."""
        self.hook.store_file(
            remote_full_path=os.path.join(
                TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
        )
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
        retrieved_file_name = 'retrieved.txt'
        self.hook.retrieve_file(
            remote_full_path=os.path.join(
                TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
        )
        self.assertIn(retrieved_file_name, os.listdir(TMP_PATH))
        os.remove(os.path.join(TMP_PATH, retrieved_file_name))
        self.hook.delete_file(path=os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR])

    def test_get_mod_time(self):
        """get_mod_time() returns a 14-character value for a stored file."""
        self.hook.store_file(
            remote_full_path=os.path.join(
                TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
        )
        output = self.hook.get_mod_time(path=os.path.join(
            TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
        self.assertEqual(len(output), 14)

    @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_default(self, get_connection):
        """Without extras, no_host_key_check is False."""
        connection = Connection(login='login', host='host')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_enabled(self, get_connection):
        """The extra ``no_host_key_check: true`` sets the flag to True."""
        connection = Connection(
            login='login', host='host',
            extra='{"no_host_key_check": true}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, True)

    @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled(self, get_connection):
        """The extra ``no_host_key_check: false`` leaves the flag False."""
        connection = Connection(
            login='login', host='host',
            extra='{"no_host_key_check": false}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
        """A non-boolean ``no_host_key_check`` extra leaves the flag False."""
        connection = Connection(
            login='login', host='host',
            extra='{"no_host_key_check": "foo"}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_ignore(self, get_connection):
        """The extra ``ignore_hostkey_verification: true`` sets the flag to True."""
        connection = Connection(
            login='login', host='host',
            extra='{"ignore_hostkey_verification": true}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, True)

    @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_no_ignore(self, get_connection):
        """The extra ``ignore_hostkey_verification: false`` leaves the flag False."""
        connection = Connection(
            login='login', host='host',
            extra='{"ignore_hostkey_verification": false}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @parameterized.expand([
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
    ])
    def test_path_exists(self, path, exists):
        """path_exists() reports the fixture directory and file, not the made-up paths."""
        result = self.hook.path_exists(path)
        self.assertEqual(result, exists)

    @parameterized.expand([
        ("test/path/file.bin", None, None, True),
        ("test/path/file.bin", "test", None, True),
        ("test/path/file.bin", "test/", None, True),
        ("test/path/file.bin", None, "bin", True),
        ("test/path/file.bin", "test", "bin", True),
        ("test/path/file.bin", "test/", "file.bin", True),
        ("test/path/file.bin", None, "file.bin", True),
        ("test/path/file.bin", "diff", None, False),
        ("test/path/file.bin", "test//", None, False),
        ("test/path/file.bin", None, ".txt", False),
        ("test/path/file.bin", "diff", ".txt", False),
    ])
    def test_path_match(self, path, prefix, delimiter, match):
        """_is_path_match() gives the expected result for each prefix/delimiter pair."""
        result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
        self.assertEqual(result, match)

    def test_get_tree_map(self):
        """get_tree_map() returns the nested file, the sub-directory and no unknowns."""
        tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        files, dirs, unknowns = tree_map

        self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
        self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
        self.assertEqual(unknowns, [])

    def tearDown(self):
        """Delete the fixture tree.

        The connection login is restored by the cleanup registered in ``setUp``,
        which unittest runs after this method.
        """
        shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,452 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import unittest
from base64 import b64encode
from unittest import mock
from airflow import AirflowException
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.models import DAG, TaskInstance
from airflow.providers.sftp.operators.sftp_operator import SFTPOperation, SFTPOperator
from airflow.utils import timezone
from airflow.utils.timezone import datetime
from tests.test_utils.config import conf_vars
TEST_DAG_ID = 'unit_tests_sftp_op'
DEFAULT_DATE = datetime(2017, 1, 1)
TEST_CONN_ID = "conn_id_for_testing"
class TestSFTPOperator(unittest.TestCase):
    """Tests for ``SFTPOperator`` PUT/GET transfers and its ssh_hook/ssh_conn_id handling.

    The transfer tests run real tasks through an ``SSHHook`` built from the
    ``ssh_default`` connection. The steps they share (writing the local file,
    creating the remote file, running a task, reading results back) live in
    private helpers.
    """

    # Payload written to the local file before every PUT test.
    _LOCAL_FILE_CONTENT = (
        b"This is local file content \n which is multiline "
        b"continuing....with other character\nanother line here \n this is last line"
    )
    # Payload echoed into the remote file before every GET test.
    _REMOTE_FILE_CONTENT = (
        "This is remote file content \n which is also multiline "
        "another line here \n this is last line. EOF"
    )

    def setUp(self):
        """Build the SSH hook, a fresh DAG and the local/remote fixture paths."""
        from airflow.contrib.hooks.ssh_hook import SSHHook
        hook = SSHHook(ssh_conn_id='ssh_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'
        self.hook = hook
        self.dag = dag
        self.test_dir = "/tmp"
        self.test_local_dir = "/tmp/tmp2"
        self.test_remote_dir = "/tmp/tmp1"
        self.test_local_filename = 'test_local_file'
        self.test_remote_filename = 'test_remote_file'
        self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
                                                    self.test_local_filename)
        # Local Filepath with Intermediate Directory
        self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
                                                            self.test_local_filename)
        self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
                                                     self.test_remote_filename)
        # Remote Filepath with Intermediate Directory
        self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
                                                             self.test_remote_filename)

    def _run_task(self, task):
        """Check the task was built, run it as a TaskInstance and return that instance."""
        self.assertIsNotNone(task)
        task_instance = TaskInstance(task=task, execution_date=timezone.utcnow())
        task_instance.run()
        return task_instance

    def _write_local_file(self):
        """Write the local payload to ``test_local_filepath``."""
        with open(self.test_local_filepath, 'wb') as file:
            file.write(self._LOCAL_FILE_CONTENT)

    def _create_remote_file(self):
        """Create ``test_remote_filepath`` holding the remote payload via an SSH ``echo``."""
        create_file_task = SSHOperator(
            task_id="test_create_file",
            ssh_hook=self.hook,
            command="echo '{0}' > {1}".format(self._REMOTE_FILE_CONTENT,
                                              self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self._run_task(create_file_task)

    def _sftp_task(self, operation, local_filepath, remote_filepath, **kwargs):
        """Build an ``SFTPOperator`` on the test hook and DAG; extra kwargs are passed through."""
        return SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=local_filepath,
            remote_filepath=remote_filepath,
            operation=operation,
            dag=self.dag,
            **kwargs
        )

    def _cat_remote_file(self, remote_filepath):
        """Run ``cat`` on the remote file over SSH and return the stripped XCom result."""
        check_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="cat {0}".format(remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        task_instance = self._run_task(check_file_task)
        return task_instance.xcom_pull(task_ids='test_check_file', key='return_value').strip()

    @staticmethod
    def _read_local_file(path):
        """Return the stripped text content of a local file."""
        with open(path, 'r') as file:
            return file.read().strip()

    @staticmethod
    def _execute_ignoring_errors(task):
        """Execute the task and deliberately ignore any failure.

        ``test_arg_checking`` only inspects which hook the operator ends up
        with after ``execute``, not whether the transfer itself succeeds.
        """
        try:
            task.execute(None)
        except Exception:  # pylint: disable=broad-except
            pass

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_pickle_file_transfer_put(self):
        """PUT with pickled XComs: the remote file holds the local bytes."""
        self._write_local_file()
        self._run_task(self._sftp_task(
            SFTPOperation.PUT,
            self.test_local_filepath,
            self.test_remote_filepath,
            create_intermediate_dirs=True,
        ))
        self.assertEqual(
            self._cat_remote_file(self.test_remote_filepath),
            self._LOCAL_FILE_CONTENT)

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_no_intermediate_dir_error_put(self):
        """PUT into a missing remote directory fails with 'No such file'."""
        self._write_local_file()
        # The remote intermediate directory does not exist and is not created.
        with self.assertRaises(Exception) as error:
            self._run_task(self._sftp_task(
                SFTPOperation.PUT,
                self.test_local_filepath,
                self.test_remote_filepath_int_dir,
                create_intermediate_dirs=False,
            ))
        self.assertIn('No such file', str(error.exception))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_with_intermediate_dir_put(self):
        """PUT with create_intermediate_dirs=True creates the remote directory and file."""
        self._write_local_file()
        self._run_task(self._sftp_task(
            SFTPOperation.PUT,
            self.test_local_filepath,
            self.test_remote_filepath_int_dir,
            create_intermediate_dirs=True,
        ))
        self.assertEqual(
            self._cat_remote_file(self.test_remote_filepath_int_dir),
            self._LOCAL_FILE_CONTENT)

    @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
    def test_json_file_transfer_put(self):
        """PUT with JSON XComs: the check task's result is the base64 of the local bytes."""
        self._write_local_file()
        self._run_task(self._sftp_task(
            SFTPOperation.PUT,
            self.test_local_filepath,
            self.test_remote_filepath,
        ))
        self.assertEqual(
            self._cat_remote_file(self.test_remote_filepath),
            b64encode(self._LOCAL_FILE_CONTENT).decode('utf-8'))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_pickle_file_transfer_get(self):
        """GET with pickled XComs: the local file holds the remote text."""
        self._create_remote_file()
        self._run_task(self._sftp_task(
            SFTPOperation.GET,
            self.test_local_filepath,
            self.test_remote_filepath,
        ))
        self.assertEqual(
            self._read_local_file(self.test_local_filepath),
            self._REMOTE_FILE_CONTENT)

    @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
    def test_json_file_transfer_get(self):
        """GET with JSON XComs: the local file holds the remote text."""
        self._create_remote_file()
        self._run_task(self._sftp_task(
            SFTPOperation.GET,
            self.test_local_filepath,
            self.test_remote_filepath,
        ))
        self.assertEqual(
            self._read_local_file(self.test_local_filepath),
            self._REMOTE_FILE_CONTENT)

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_no_intermediate_dir_error_get(self):
        """GET into a missing local directory fails with 'No such file'."""
        self._create_remote_file()
        # The local intermediate directory does not exist and is not created.
        with self.assertRaises(Exception) as error:
            self._run_task(self._sftp_task(
                SFTPOperation.GET,
                self.test_local_filepath_int_dir,
                self.test_remote_filepath,
            ))
        self.assertIn('No such file', str(error.exception))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_with_intermediate_dir_error_get(self):
        """GET with create_intermediate_dirs=True creates the local directory and file."""
        self._create_remote_file()
        self._run_task(self._sftp_task(
            SFTPOperation.GET,
            self.test_local_filepath_int_dir,
            self.test_remote_filepath,
            create_intermediate_dirs=True,
        ))
        self.assertEqual(
            self._read_local_file(self.test_local_filepath_int_dir),
            self._REMOTE_FILE_CONTENT)

    @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
    def test_arg_checking(self):
        """The operator picks its hook from ssh_hook / ssh_conn_id as expected."""
        # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
        with self.assertRaisesRegex(AirflowException,
                                    "Cannot operate without ssh_hook or ssh_conn_id."):
            task_0 = SFTPOperator(
                task_id="test_sftp",
                local_filepath=self.test_local_filepath,
                remote_filepath=self.test_remote_filepath,
                operation=SFTPOperation.PUT,
                dag=self.dag
            )
            task_0.execute(None)

        # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
        task_1 = SFTPOperator(
            task_id="test_sftp",
            ssh_hook="string_rather_than_SSHHook",  # invalid ssh_hook
            ssh_conn_id=TEST_CONN_ID,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        self._execute_ignoring_errors(task_1)
        self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)

        task_2 = SFTPOperator(
            task_id="test_sftp",
            ssh_conn_id=TEST_CONN_ID,  # no ssh_hook provided
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        self._execute_ignoring_errors(task_2)
        self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)

        # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
        task_3 = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            ssh_conn_id=TEST_CONN_ID,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        self._execute_ignoring_errors(task_3)
        self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)

    def delete_local_resource(self):
        """Remove the local fixture files and the intermediate local directory, if present."""
        if os.path.exists(self.test_local_filepath):
            os.remove(self.test_local_filepath)
        if os.path.exists(self.test_local_filepath_int_dir):
            os.remove(self.test_local_filepath_int_dir)
        if os.path.exists(self.test_local_dir):
            os.rmdir(self.test_local_dir)

    def delete_remote_resource(self):
        """Remove the remote fixture files and the intermediate remote directory, if present.

        Existence is probed with ``os.path`` on the local filesystem; the plain
        remote file is deleted over SSH, the nested file and directory with ``os``.
        """
        if os.path.exists(self.test_remote_filepath):
            # remove the remote file over SSH
            remove_file_task = SSHOperator(
                task_id="test_check_file",
                ssh_hook=self.hook,
                command="rm {0}".format(self.test_remote_filepath),
                do_xcom_push=True,
                dag=self.dag
            )
            self._run_task(remove_file_task)
        if os.path.exists(self.test_remote_filepath_int_dir):
            os.remove(self.test_remote_filepath_int_dir)
        if os.path.exists(self.test_remote_dir):
            os.rmdir(self.test_remote_dir)

    def tearDown(self):
        """Remove local and remote fixtures; the remote cleanup runs even if the local one raises."""
        try:
            self.delete_local_resource()
        finally:
            self.delete_remote_resource()
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from unittest.mock import patch
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor
class TestSFTPSensor(unittest.TestCase):
    """Unit tests for ``SFTPSensor.poke`` with ``SFTPHook`` replaced by a mock."""

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_file_present(self, sftp_hook_mock):
        """poke() is truthy when the hook returns a modification time."""
        watched_path = '/path/to/file/1970-01-01.txt'
        get_mod_time = sftp_hook_mock.return_value.get_mod_time
        get_mod_time.return_value = '19700101000000'

        sensor = SFTPSensor(task_id='unit_test', path=watched_path)
        poke_result = sensor.poke({'ds': '1970-01-01'})

        get_mod_time.assert_called_once_with(watched_path)
        self.assertTrue(poke_result)

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_file_absent(self, sftp_hook_mock):
        """poke() is falsy when the hook raises OSError with SFTP_NO_SUCH_FILE."""
        watched_path = '/path/to/file/1970-01-01.txt'
        get_mod_time = sftp_hook_mock.return_value.get_mod_time
        get_mod_time.side_effect = OSError(SFTP_NO_SUCH_FILE, 'File missing')

        sensor = SFTPSensor(task_id='unit_test', path=watched_path)
        poke_result = sensor.poke({'ds': '1970-01-01'})

        get_mod_time.assert_called_once_with(watched_path)
        self.assertFalse(poke_result)

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_sftp_failure(self, sftp_hook_mock):
        """poke() lets an OSError carrying SFTP_FAILURE propagate to the caller."""
        watched_path = '/path/to/file/1970-01-01.txt'
        get_mod_time = sftp_hook_mock.return_value.get_mod_time
        get_mod_time.side_effect = OSError(SFTP_FAILURE, 'SFTP failure')

        sensor = SFTPSensor(task_id='unit_test', path=watched_path)
        with self.assertRaises(OSError):
            sensor.poke({'ds': '1970-01-01'})
        # Checked after the context manager exits so the assertion actually runs.
        get_mod_time.assert_called_once_with(watched_path)

    def test_hook_not_created_during_init(self):
        """Constructing the sensor alone leaves its ``hook`` attribute as None."""
        sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
        self.assertIsNone(sensor.hook)