[AIRFLOW-5126] Read aws_session_token in extra_config of the aws hook (#6303)

This commit is contained in:
JohannesGuenther 2019-10-16 17:22:07 +02:00 коммит произвёл Ash Berlin-Taylor
Родитель 8b0c9cbb55
Коммит d4e282d9b2
3 изменённых файлов: 41 добавлений и 9 удалений

Просмотреть файл

@ -126,17 +126,21 @@ class AwsHook(BaseHook):
external_id = extra_config.get('external_id')
aws_account_id = extra_config.get('aws_account_id')
aws_iam_role = extra_config.get('aws_iam_role')
if 'aws_session_token' in extra_config and aws_session_token is None:
aws_session_token = extra_config['aws_session_token']
if role_arn is None and aws_account_id is not None and \
aws_iam_role is not None:
if role_arn is None and aws_account_id is not None and aws_iam_role is not None:
role_arn = "arn:aws:iam::{}:role/{}" \
.format(aws_account_id, aws_iam_role)
if role_arn is not None:
sts_session = boto3.session.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name)
region_name=region_name,
aws_session_token=aws_session_token
)
sts_client = sts_session.client('sts')

Просмотреть файл

@ -57,6 +57,7 @@ Extra (optional)
* ``host``: Endpoint URL for the connection
* ``region_name``: AWS region for the connection
* ``role_arn``: AWS role ARN for the connection
* ``aws_session_token``: AWS session token if you use external credentials. You are responsible for renewing these.
Example "extras" field:

Просмотреть файл

@ -17,7 +17,6 @@
# specific language governing permissions and limitations
# under the License.
#
import unittest
import boto3
@ -51,7 +50,6 @@ class TestAwsHook(unittest.TestCase):
@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
hook = AwsHook(aws_conn_id='aws_default')
resource_from_hook = hook.get_resource_type('dynamodb')
@ -113,9 +111,24 @@ class TestAwsHook(unittest.TestCase):
self.assertEqual(table.item_count, 0)
@mock.patch.object(AwsHook, 'get_connection')
def test_get_credentials_from_login(self, mock_get_connection):
def test_get_credentials_from_login_with_token(self, mock_get_connection):
mock_connection = Connection(login='aws_access_key_id',
password='aws_secret_access_key')
password='aws_secret_access_key',
extra='{"aws_session_token": "test_token"}'
)
mock_get_connection.return_value = mock_connection
hook = AwsHook()
credentials_from_hook = hook.get_credentials()
self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
self.assertEqual(credentials_from_hook.token, 'test_token')
@mock.patch.object(AwsHook, 'get_connection')
def test_get_credentials_from_login_without_token(self, mock_get_connection):
mock_connection = Connection(login='aws_access_key_id',
password='aws_secret_access_key',
)
mock_get_connection.return_value = mock_connection
hook = AwsHook()
credentials_from_hook = hook.get_credentials()
@ -124,10 +137,24 @@ class TestAwsHook(unittest.TestCase):
self.assertIsNone(credentials_from_hook.token)
@mock.patch.object(AwsHook, 'get_connection')
def test_get_credentials_from_extra(self, mock_get_connection):
def test_get_credentials_from_extra_with_token(self, mock_get_connection):
mock_connection = Connection(
extra='{"aws_access_key_id": "aws_access_key_id",'
'"aws_secret_access_key": "aws_secret_access_key"}'
'"aws_secret_access_key": "aws_secret_access_key",'
' "aws_session_token": "session_token"}'
)
mock_get_connection.return_value = mock_connection
hook = AwsHook()
credentials_from_hook = hook.get_credentials()
self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
self.assertEquals(credentials_from_hook.token, 'session_token')
@mock.patch.object(AwsHook, 'get_connection')
def test_get_credentials_from_extra_without_token(self, mock_get_connection):
mock_connection = Connection(
extra='{"aws_access_key_id": "aws_access_key_id",'
'"aws_secret_access_key": "aws_secret_access_key"}'
)
mock_get_connection.return_value = mock_connection
hook = AwsHook()